From 4045d1c8f63862077417506e861e2279955d7eeb Mon Sep 17 00:00:00 2001 From: louyujing Date: Thu, 19 Oct 2023 07:42:53 +0000 Subject: [PATCH 1/2] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 034586cc1..4cc27880c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -143,7 +143,7 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di args, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) - need_backward = api_full_name in backward_content and api_name[-1] != "_" + need_backward = api_full_name in backward_content need_backward = need_backward and need_grad if not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) @@ -173,7 +173,7 @@ def get_api_info(api_info_dict, api_name): need_grad = True if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False - if api_name[-1] == "_" or api_name in NO_GRAD_APIS: + if api_name in NO_GRAD_APIS: need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) return args, kwargs, need_grad -- Gitee From 6b71849aee320aa9359e16485f3534f36207f31b Mon Sep 17 00:00:00 2001 From: louyujing Date: Thu, 19 Oct 2023 07:44:39 +0000 Subject: [PATCH 2/2] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: louyujing --- debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index aa564a6e8..cca521a29 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -165,7 +165,7 @@ def compare_core(bench_out, npu_out, alg): copy_npu_out = npu_out.detach().clone() bench_dtype = str(copy_bench_out.dtype) npu_dtype = str(copy_npu_out.dtype) - shape = list(npu_out.shape) + shape = tuple(npu_out.shape) if copy_npu_out.dtype == torch.bfloat16: copy_bench_out = copy_bench_out.to(torch.float32) copy_npu_out = copy_npu_out.to(torch.float32) -- Gitee