diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 28ca86793f1db03647df525ea5037b64623dde35..f7f1516323117885f6c4c15ae9918bc82fb37569 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -236,8 +236,6 @@ def gen_api_params(api_info, need_grad=True, convert_type=None): print_error_log(f"convert_type params not support {convert_type} ") raise CompareException.INVALID_PARAM_ERROR kwargs_params = gen_kwargs(api_info, convert_type) - if kwargs_params.get("inplace"): - need_grad = False if api_info.get("args"): args_params = gen_args(api_info.get("args"), need_grad, convert_type) else: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 707d6cbed9dbc19885470e916b4c49615356c657..97679796affda41554faddee373bcdd6a4b704d4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -137,12 +137,12 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): in_fwd_data_list = [] [api_type, api_name, _] = api_full_name.split("*") - args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) - need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True + need_backward = api_full_name in backward_content and api_name[-1] != "_" need_backward = need_backward and need_grad - if inplace or not need_grad: + if not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) @@ -173,8 +173,7 @@ def get_api_info(api_info_dict, api_name): if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) - inplace = kwargs.get("inplace") if kwargs.get("inplace") else None - return args, inplace, kwargs, need_grad + return args, kwargs, need_grad def run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out):