diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index aa564a6e8e9019ee390be5599a59eb15fd9f2564..cca521a296d136a95038b6f452ea3f4a01341909 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -165,7 +165,7 @@ def compare_core(bench_out, npu_out, alg): copy_npu_out = npu_out.detach().clone() bench_dtype = str(copy_bench_out.dtype) npu_dtype = str(copy_npu_out.dtype) - shape = list(npu_out.shape) + shape = tuple(npu_out.shape) if copy_npu_out.dtype == torch.bfloat16: copy_bench_out = copy_bench_out.to(torch.float32) copy_npu_out = copy_npu_out.to(torch.float32) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 034586cc1d29064eebc9c3cd09534caf46287b01..4cc27880c791ba069c4ef7ca06a062dc618d49f3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -143,7 +143,7 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di args, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) - need_backward = api_full_name in backward_content and api_name[-1] != "_" + need_backward = api_full_name in backward_content need_backward = need_backward and need_grad if not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) @@ -173,7 +173,7 @@ def get_api_info(api_info_dict, api_name): need_grad = True if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): need_grad = False - if api_name[-1] == "_" or api_name in NO_GRAD_APIS: + if api_name in NO_GRAD_APIS: need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) return args, kwargs, need_grad