diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 0bb30cc72d7d7f0a2387385084ec0722dc1cbffe..51d2c0fbfb3c3c2079f522f81a6b518eb2ed024f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -52,6 +52,9 @@ def cosine_sim(cpu_output, npu_output): if len(n_value) == 1: print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") return get_max_rel_err(n_value, b_value) + if len(n_value) == len(b_value) == 0: + print_warn_log("The npu dump data and bench dump data is empty.") + return cos, True if n_value.dtype == np.uint8: return compare_uint8_data(n_value, b_value) n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(n_value.dtype).eps) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 4200b3b3b11aae14a969b686a3aa303e38a323b5..aa48ce5edd2c7fbd7e21cb3f9ca9048116bc0abf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -239,6 +239,6 @@ def gen_api_params(api_info, need_grad=True, convert_type=None): if api_info.get("args"): args_params = gen_args(api_info.get("args"), need_grad, convert_type) else: - print_error_log(f'Warning: No args in {api_info} ') - raise NotImplementedError() + print_warn_log(f'Warning: No args in {api_info} ') + args_params = [] return args_params, kwargs_params diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index a5c96b207eaa37efca76b16e0bcec1db165f66ef..6ec357f62ebdd6ac66d9b86f7ac925b04118c136 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -73,6 +73,8 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") convert_type = check_need_convert(api_name) need_grad = True + if value.get("kwargs") and "out" in value.get("kwargs"): + need_grad = False if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False args, kwargs = gen_api_params(value, need_grad, convert_type) @@ -137,7 +139,7 @@ def _run_ut_parser(parser): parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", help=" Save compare failed api output.", required=False) parser.add_argument("-c", "--jit_compile", dest="jit_compile", help=" whether to turn on jit compile", - default=True, required=False) + default=False, required=False) parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set NPU device id to run ut", default=0, required=False) @@ -146,8 +148,7 @@ def _run_ut(): parser = argparse.ArgumentParser() _run_ut_parser(parser) args = parser.parse_args(sys.argv[1:]) - if not args.jit_compile: - torch.npu.set_compile_mode(jit_compile=False) + torch.npu.set_compile_mode(jit_compile=args.jit_compile) npu_device = "npu:" + str(args.device_id) try: torch.npu.set_device(npu_device)