From 96f3fb378c1746cc7e2c7d4009065c04b55f91c4 Mon Sep 17 00:00:00 2001 From: wangchao Date: Wed, 9 Aug 2023 03:56:56 +0000 Subject: [PATCH] =?UTF-8?q?=E9=92=88=E5=AF=B9cpu=E4=B8=8D=E6=94=AF?= =?UTF-8?q?=E6=8C=81half=E7=9A=84api=E5=81=9A=E6=8F=90=E7=A4=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: wangchao --- .../api_accuracy_checker/run_ut/run_ut.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 6ec357f62..e6162ca8b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -61,9 +61,17 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): api_setting_dict = get_json_contents("torch_ut_setting.json") compare = Comparator(out_path) for api_full_name, api_info_dict in forward_content.items(): - grad_out, npu_grad_out, npu_out, out = run_torch_api(api_full_name, api_setting_dict, backward_content, - api_info_dict) - compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + try: + grad_out, npu_grad_out, npu_out, out = run_torch_api(api_full_name, api_setting_dict, backward_content, + api_info_dict) + compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + except Exception as err: + if "not implemented for 'Half'" in str(err): + [_, api_name, _] = api_full_name.split("*") + print_warn_log(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API " + f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.") + else: + print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) compare.print_pretest_result() compare.write_compare_csv() -- Gitee