diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index d7cab76d063100618cf2398818906170c3992f1f..7c0fa0f6a6b73f1cd86b256ec791262d53638745 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -76,11 +76,10 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di torch.npu.clear_npu_overflow_flag() api_type = api_full_name.split("_")[0] api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] - args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) - need_backward = api_full_name.replace("forward", "backward") in backward_content and api_name[-1] != "_" and \ - inplace is not True + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) + need_backward = api_full_name.replace("forward", "backward") in backward_content and api_name[-1] != "_" need_backward = need_backward and need_grad - if inplace or not need_grad: + if not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) if kwargs.get("device"):