From 98a585194a136476a9d30011311a07644ab52375 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Fri, 8 Sep 2023 17:15:35 +0800 Subject: [PATCH] Resolve inconsistency between saved data and calculated data --- .../api_accuracy_checker/compare/algorithm.py | 12 +++++++----- .../api_accuracy_checker/run_ut/run_ut.py | 11 +++++------ 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index b0b1aaf605..98d0a5585c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -168,11 +168,13 @@ def compare_core(bench_out, npu_out, alg): CompareConst.NA, CompareConst.NA compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): - bench_dtype = str(bench_out.dtype) - npu_dtype = str(npu_out.dtype) - if bench_out.dtype in [torch.float32, torch.float64] and bench_out.dtype != npu_out.dtype: - npu_out = npu_out.type(bench_out.dtype) - compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) + copy_bench_out = bench_out.detach().clone() + copy_npu_out = npu_out.detach().clone() + bench_dtype = str(copy_bench_out.dtype) + npu_dtype = str(copy_npu_out.dtype) + if copy_bench_out.dtype in [torch.float32, torch.float64] and copy_bench_out.dtype != copy_npu_out.dtype: + copy_npu_out = copy_npu_out.type(copy_bench_out.dtype) + compare_result, test_success, msg = compare_torch_tensor(copy_bench_out.numpy(), copy_npu_out.cpu().numpy(), alg) elif isinstance(bench_out, (bool, int, float, str)): compare_result, test_success, msg = compare_builtin_type(bench_out, npu_out) bench_dtype = str(type(bench_out)) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 707d6cbed9..20b80ae627 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -99,12 +99,11 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): for api_full_name, api_info_dict in tqdm(forward_content.items()): try: data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) - is_fwd_success, is_bwd_success = \ - compare.compare_output(api_full_name, - None if data_info.bench_out is None else data_info.bench_out.clone(), - None if data_info.npu_out is None else data_info.npu_out.clone(), - None if data_info.bench_grad_out is None else data_info.bench_grad_out.clone(), - None if data_info.npu_grad_out is None else data_info.npu_grad_out.clone()) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, + data_info.bench_out, + data_info.npu_out, + data_info.bench_grad_out, + data_info.npu_grad_out) if save_error_data: do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: -- Gitee