From aeba77659ed9654b01c2d9b609ba302046a91a11 Mon Sep 17 00:00:00 2001 From: louyujing Date: Thu, 21 Sep 2023 12:08:38 +0000 Subject: [PATCH] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py. Signed-off-by: louyujing --- .../api_accuracy_checker/run_ut/run_overflow_check.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index d7cab76d06..7c0fa0f6a6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -76,11 +76,10 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di torch.npu.clear_npu_overflow_flag() api_type = api_full_name.split("_")[0] api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] - args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) - need_backward = api_full_name.replace("forward", "backward") in backward_content and api_name[-1] != "_" and \ - inplace is not True + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) + need_backward = api_full_name.replace("forward", "backward") in backward_content and api_name[-1] != "_" need_backward = need_backward and need_grad - if inplace or not need_grad: + if not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) if kwargs.get("device"): -- Gitee