From b271cd938cc2a02b772ef2cbf28067749ea30e5e Mon Sep 17 00:00:00 2001 From: wangchao Date: Mon, 7 Aug 2023 18:47:35 +0800 Subject: [PATCH 1/4] add api has out, no grad --- .../api_accuracy_checker/run_ut/data_generate.py | 4 ++-- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 4200b3b3b1..aa48ce5edd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -239,6 +239,6 @@ def gen_api_params(api_info, need_grad=True, convert_type=None): if api_info.get("args"): args_params = gen_args(api_info.get("args"), need_grad, convert_type) else: - print_error_log(f'Warning: No args in {api_info} ') - raise NotImplementedError() + print_warn_log(f'Warning: No args in {api_info} ') + args_params = [] return args_params, kwargs_params diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index a5c96b207e..8a48d1fcf5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -73,6 +73,8 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, value): [api_type, api_name, _] = api_full_name.split("*") convert_type = check_need_convert(api_name) need_grad = True + if value.get("kwargs") and "out" in value.get("kwargs"): + need_grad = False if api_name[-1] == "_" or api_name in NO_GRAD_APIS: need_grad = False args, kwargs = gen_api_params(value, need_grad, convert_type) -- Gitee From 2f1a644b1a4786967d22bec65bcf70ebb0b78058 Mon Sep 17 00:00:00 2001 From: wangchao Date: Mon, 7 Aug 2023 18:58:55 +0800 Subject: [PATCH 2/4] add api has out, no grad --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 8a48d1fcf5..ec091e19b2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -139,7 +139,7 @@ def _run_ut_parser(parser): parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", help=" Save compare failed api output.", required=False) parser.add_argument("-c", "--jit_compile", dest="jit_compile", help=" whether to turn on jit compile", - default=True, required=False) + default=False, required=False) parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set NPU device id to run ut", default=0, required=False) @@ -148,8 +148,8 @@ def _run_ut(): parser = argparse.ArgumentParser() _run_ut_parser(parser) args = parser.parse_args(sys.argv[1:]) - if not args.jit_compile: - torch.npu.set_compile_mode(jit_compile=False) + if args.jit_compile: + torch.npu.set_compile_mode(jit_compile=True) npu_device = "npu:" + str(args.device_id) try: torch.npu.set_device(npu_device) -- Gitee From 925be82028d565c1d837e76f845b15a6f10af6a5 Mon Sep 17 00:00:00 2001 From: wangchao Date: Mon, 7 Aug 2023 20:29:22 +0800 Subject: [PATCH 3/4] add api has out, no grad --- debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 0bb30cc72d..51d2c0fbfb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -52,6 +52,9 @@ def cosine_sim(cpu_output, npu_output): if len(n_value) == 1: print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") return get_max_rel_err(n_value, b_value) + if len(n_value) == len(b_value) == 0: + print_warn_log("The npu dump data and bench dump data is empty.") + return cos, True if n_value.dtype == np.uint8: return compare_uint8_data(n_value, b_value) n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(n_value.dtype).eps) -- Gitee From 6be076e9349f159daaea9325d7df22210631931f Mon Sep 17 00:00:00 2001 From: wangchao Date: Tue, 8 Aug 2023 10:23:34 +0800 Subject: [PATCH 4/4] add api has out, no grad --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index ec091e19b2..6ec357f62e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -148,8 +148,7 @@ def _run_ut(): parser = argparse.ArgumentParser() _run_ut_parser(parser) args = parser.parse_args(sys.argv[1:]) - if args.jit_compile: - torch.npu.set_compile_mode(jit_compile=True) + torch.npu.set_compile_mode(jit_compile=args.jit_compile) npu_device = "npu:" + str(args.device_id) try: torch.npu.set_device(npu_device) -- Gitee