From 2346c0024af570036b64f4975ea2160c4294eacc Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 9 Sep 2023 16:58:47 +0800 Subject: [PATCH] =?UTF-8?q?=E9=92=88=E5=AF=B9inplace=E8=BF=90=E7=AE=97?= =?UTF-8?q?=E9=80=82=E9=85=8D=E5=8F=8D=E5=90=91=E8=BF=87=E7=A8=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/run_ut/data_generate.py | 2 -- .../accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 9 ++++----- 2 files changed, 4 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 28ca86793..f7f151632 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -236,8 +236,6 @@ def gen_api_params(api_info, need_grad=True, convert_type=None): print_error_log(f"convert_type params not support {convert_type} ") raise CompareException.INVALID_PARAM_ERROR kwargs_params = gen_kwargs(api_info, convert_type) - if kwargs_params.get("inplace"): - need_grad = False if api_info.get("args"): args_params = gen_args(api_info.get("args"), need_grad, convert_type) else: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 707d6cbed..97679796a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -137,12 +137,12 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): in_fwd_data_list = [] [api_type, api_name, _] = api_full_name.split("*") - args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) - need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True + need_backward = api_full_name in backward_content and api_name[-1] != "_" need_backward = need_backward and need_grad - if inplace or not need_grad: + if not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) @@ -173,8 +173,7 @@ def get_api_info(api_info_dict, api_name): if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) - inplace = kwargs.get("inplace") if kwargs.get("inplace") else None - return args, inplace, kwargs, need_grad + return args, kwargs, need_grad def run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out): -- Gitee