diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index a46e819b8758c97af1db7bd19ca12a0866bd5f73..65091159a744e284a5fee7e710cdc57efe340a4e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -224,7 +224,7 @@ def gen_api_params(api_info, need_grad=True, convert_type=None): print_error_log(f"convert_type params not support {convert_type} ") raise CompareException.INVALID_PARAM_ERROR kwargs_params = gen_kwargs(api_info, convert_type) - if "inplace" in kwargs_params: + if kwargs_params.get("inplace"): need_grad = False if api_info.get("args"): args_params = gen_args(api_info.get("args"), need_grad, convert_type) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index b6bd65b8899fc20b3a45f89b8bc7b0c16e55f647..0c18da7d0e8f573260a68e94f89b25727a035b5a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -8,6 +8,8 @@ from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, check_need_convert from api_accuracy_checker.compare.compare import Comparator +NO_GRAD_APIS = ["hardtanh"] + cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "../hook_module/support_wrap_ops.yaml") with open(yaml_path, 'r') as f: @@ -77,10 +79,14 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") convert_type = check_need_convert(api_name) - args, kwargs = gen_api_params(value, api_name[-1] != "_", convert_type) + need_grad = True + if api_name[-1] == "_" or api_name in NO_GRAD_APIS: + need_grad = False + args, kwargs = gen_api_params(value, need_grad, convert_type) inplace = kwargs.get("inplace") if kwargs.get("inplace") else None need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True - if inplace or api_name[-1] == "_": + need_backward = need_backward and need_grad + if inplace or not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) grad_out, npu_grad_out = None, None