diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index fb885168d662006b3fa639c2322113d6f1f6b386..f16d84b3cf58efd181b0f22c09da73c53d8ecaf3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -36,6 +36,7 @@ RESULT_FILE_NAME = "accuracy_checking_result_" + current_time + ".csv" DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', 'save_error_data', 'is_continue_run_ut', 'test_result_cnt']) +not_backward_list = ['repeat_interleave'] def exec_api(api_type, api_name, args, kwargs): @@ -167,9 +168,14 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) need_backward = api_full_name in backward_content - need_backward = need_backward and need_grad if not need_grad: - print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) + print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." + % api_full_name) + if api_name in not_backward_list: + need_grad = False + print_warn_log( + "%s function backward result is None, skip backward." % api_full_name) + need_backward = need_backward and need_grad if kwargs.get("device"): del kwargs["device"] cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward)