From 2a11d867dd4a68876cfcc9db072f9e9605935b4e Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 29 Aug 2023 21:32:23 +0800 Subject: [PATCH 1/6] =?UTF-8?q?=E7=9C=9F=E5=80=BC=E6=AF=94=E5=AF=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/common/utils.py | 7 ++-- .../api_accuracy_checker/compare/algorithm.py | 8 ++-- .../api_accuracy_checker/run_ut/run_ut.py | 42 +++++++++++++++---- 3 files changed, 43 insertions(+), 14 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index a7922eadf..9b70b154a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -81,12 +81,13 @@ class Const: WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR CONVERT = { - "fp16_to_fp32": ["torch.float16", "torch.float32"], - "int32_to_int64": ["torch.int32", "torch.int64"] + "int32_to_int64": ["torch.int32", "torch.int64"], + "torch.float16" : "torch.float32", + "torch.bfloat16" : "torch.float32", + "torch.float32" : "torch.float64" } CONVERT_API = { - "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d", "interpolate", "group_norm", "layer_norm", "bmm", "tanh", "cross_entropy", "linear", "numel"], "int32_to_int64": ["cross_entropy"] } diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 29b10608e..e3dabc432 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -150,11 +150,11 @@ def flatten_compare_result(result): def compare_core(bench_out, npu_out, alg): msg = "" if not isinstance(bench_out, type(npu_out)): - return CompareConst.NAN, False, "bench and npu output type is different.", CompareConst.NAN, CompareConst.NAN + return CompareConst.NAN, False, "bench and npu output type is different.", CompareConst.NAN if isinstance(bench_out, (list, tuple)): compare_result, test_success, bench_dtype, npu_dtype = [], True, [], [] if len(bench_out) != len(npu_out): - return CompareConst.NAN, False, "bench and npu output structure is different", CompareConst.NAN, CompareConst.NAN + return CompareConst.NAN, False, "bench and npu output structure is different", CompareConst.NAN for b_out_i, n_out_i in zip(bench_out, npu_out): compare_result_i, test_success_i, bench_dtype_i, npu_dtype_i = compare_core(b_out_i, n_out_i, alg) compare_result.append(compare_result_i) @@ -165,9 +165,11 @@ def compare_core(bench_out, npu_out, alg): b_keys, n_keys = set(bench_out.keys()), set(npu_out.keys()) if b_keys != n_keys: compare_result, test_success, msg = CompareConst.NAN, False, "bench and npu output dict keys are different", \ - CompareConst.NAN, CompareConst.NAN + CompareConst.NAN compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): + if bench_out.dtype in [torch.float32, torch.float64] and bench_out.dtype != npu_out.dtype: + npu_out = npu_out.type(bench_out.dtype) compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) bench_dtype = str(bench_out.dtype) npu_dtype = str(npu_out.dtype) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 813d0cb58..2cb15f3a4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -7,7 +7,7 @@ import torch from tqdm import tqdm from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ - print_error_log, check_file_or_directory_path, initialize_save_path + print_error_log, check_file_or_directory_path, initialize_save_path, Const from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate @@ -44,7 +44,7 @@ def exec_api(api_type, api_name, args, kwargs): return out -def generate_npu_params(cpu_args, cpu_kwargs, need_backward): +def generate_npu_params(input_args, input_kwargs, need_backward): def recursive_arg_to_npu(arg_in): if isinstance(arg_in, (list, tuple)): return type(arg_in)(recursive_arg_to_npu(arg) for arg in arg_in) @@ -60,10 +60,34 @@ def generate_npu_params(cpu_args, cpu_kwargs, need_backward): else: return arg_in - npu_args = recursive_arg_to_npu(cpu_args) - npu_kwargs = {key: recursive_arg_to_npu(value) for key, value in cpu_kwargs.items()} + npu_args = recursive_arg_to_npu(input_args) + npu_kwargs = {key: recursive_arg_to_npu(value) for key, value in input_kwargs.items()} return npu_args, npu_kwargs +def generate_cpu_params(input_args, input_kwargs, need_backward): + def recursive_arg_to_cpu(arg_in): + if isinstance(arg_in, (list, tuple)): + return type(arg_in)(recursive_arg_to_cpu(arg) for arg in arg_in) + elif isinstance(arg_in, torch.Tensor): + if need_backward and arg_in.requires_grad: + if str(arg_in.dtype) in Const.CONVERT.keys(): + arg_in = arg_in.clone().type(eval(Const.CONVERT[str(arg_in.dtype)])).detach().requires_grad_() + else: + arg_in = arg_in.clone().detach().requires_grad_() + temp_arg_in = arg_in * 1 + arg_in = temp_arg_in.type_as(arg_in) + arg_in.retain_grad() + return arg_in + else: + if str(arg_in.dtype) in Const.CONVERT.keys(): + return arg_in.clone().type(eval(Const.CONVERT[str(arg_in.dtype)])).detach() + return arg_in.clone().detach() + else: + return arg_in + + npu_args = recursive_arg_to_cpu(input_args) + npu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()} + return npu_args, npu_kwargs def run_ut(forward_file, backward_file, out_path, save_error_data): print_info_log("start UT test") @@ -118,11 +142,12 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di need_backward = need_backward and need_grad if inplace or not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) + cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) grad_out, npu_grad_out = None, None if kwargs.get("device"): del kwargs["device"] - out = exec_api(api_type, api_name, args, kwargs) + out = exec_api(api_type, api_name, cpu_args, cpu_kwargs) npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) grad_input_index = api_setting_dict.get(api_name) grad_index = None @@ -131,7 +156,7 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di grad_index = grad_input_index.get('grad_index') if need_backward: - grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, args, backward_content, grad_index, npu_args, + grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, cpu_args, backward_content, grad_index, npu_args, npu_out, out) if grad_index is not None: return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], grad, in_fwd_data_list) @@ -153,12 +178,13 @@ def get_api_info(api_info_dict, api_name): def run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out): backward_args = backward_content[api_full_name] grad = gen_args(backward_args)[0] + cpu_grad, _ = generate_cpu_params(grad, {}, False) if grad_index is not None: - out[grad_index].backward(grad) + out[grad_index].backward(cpu_grad) elif isinstance(out, (list, tuple)): raise NotImplementedError("Multiple backward is not supported.") else: - out.backward(grad) + out.backward(cpu_grad) args_grad = [] for arg in args: if isinstance(arg, torch.Tensor): -- Gitee From 09195bee7f819d1909333c7d93fd756e9440291e Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 30 Aug 2023 15:18:07 +0800 Subject: [PATCH 2/6] fix --- debug/accuracy_tools/api_accuracy_checker/common/utils.py | 8 +++++--- .../api_accuracy_checker/compare/algorithm.py | 4 ++-- .../accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 8 ++++---- 3 files changed, 11 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 9b70b154a..94907f014 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -79,13 +79,15 @@ class Const: API_PATTERN = r"^[A-Za-z0-9]+[_]+([A-Za-z0-9]+[_]*[A-Za-z0-9]+)[_]+[0-9]+[_]+[A-Za-z0-9]+" WRITE_FLAGS = os.O_WRONLY | os.O_CREAT WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR - - CONVERT = { - "int32_to_int64": ["torch.int32", "torch.int64"], + + RAISE_PRECISION = { "torch.float16" : "torch.float32", "torch.bfloat16" : "torch.float32", "torch.float32" : "torch.float64" } + CONVERT = { + "int32_to_int64": ["torch.int32", "torch.int64"], + } CONVERT_API = { "int32_to_int64": ["cross_entropy"] diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index e3dabc432..30b8d9403 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -168,11 +168,11 @@ def compare_core(bench_out, npu_out, alg): CompareConst.NAN compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): + bench_dtype = str(bench_out.dtype) + npu_dtype = str(npu_out.dtype) if bench_out.dtype in [torch.float32, torch.float64] and bench_out.dtype != npu_out.dtype: npu_out = npu_out.type(bench_out.dtype) compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) - bench_dtype = str(bench_out.dtype) - npu_dtype = str(npu_out.dtype) elif isinstance(bench_out, (bool, int, float, str)): compare_result, test_success, msg = compare_builtin_type(bench_out, npu_out) bench_dtype = str(type(bench_out)) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 2cb15f3a4..637013d3c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -70,8 +70,8 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): return type(arg_in)(recursive_arg_to_cpu(arg) for arg in arg_in) elif isinstance(arg_in, torch.Tensor): if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.CONVERT.keys(): - arg_in = arg_in.clone().type(eval(Const.CONVERT[str(arg_in.dtype)])).detach().requires_grad_() + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys(): + arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach().requires_grad_() else: arg_in = arg_in.clone().detach().requires_grad_() temp_arg_in = arg_in * 1 @@ -79,8 +79,8 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): arg_in.retain_grad() return arg_in else: - if str(arg_in.dtype) in Const.CONVERT.keys(): - return arg_in.clone().type(eval(Const.CONVERT[str(arg_in.dtype)])).detach() + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys(): + return arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() return arg_in.clone().detach() else: return arg_in -- Gitee From f75b0e6d2f7d7725232db35c8756ad1120a5ff20 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 30 Aug 2023 17:01:57 +0800 Subject: [PATCH 3/6] fix --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 637013d3c..9b3c66559 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -85,9 +85,9 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): else: return arg_in - npu_args = recursive_arg_to_cpu(input_args) - npu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()} - return npu_args, npu_kwargs + cpu_args = recursive_arg_to_cpu(input_args) + cpu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()} + return cpu_args, cpu_kwargs def run_ut(forward_file, backward_file, out_path, save_error_data): print_info_log("start UT test") -- Gitee From 1badf08d1d1c77095c32b9794ccd307cd979f8d9 Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 30 Aug 2023 20:04:54 +0800 Subject: [PATCH 4/6] fix --- .../api_accuracy_checker/compare/algorithm.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 30b8d9403..7a2a0ad63 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -154,7 +154,7 @@ def compare_core(bench_out, npu_out, alg): if isinstance(bench_out, (list, tuple)): compare_result, test_success, bench_dtype, npu_dtype = [], True, [], [] if len(bench_out) != len(npu_out): - return CompareConst.NAN, False, "bench and npu output structure is different", CompareConst.NAN + return [(CompareConst.NAN, "bench and npu output structure is different")], False, CompareConst.NA, CompareConst.NA for b_out_i, n_out_i in zip(bench_out, npu_out): compare_result_i, test_success_i, bench_dtype_i, npu_dtype_i = compare_core(b_out_i, n_out_i, alg) compare_result.append(compare_result_i) @@ -164,8 +164,9 @@ def compare_core(bench_out, npu_out, alg): elif isinstance(bench_out, dict): b_keys, n_keys = set(bench_out.keys()), set(npu_out.keys()) if b_keys != n_keys: - compare_result, test_success, msg = CompareConst.NAN, False, "bench and npu output dict keys are different", \ - CompareConst.NAN + compare_result, test_success, bench_dtype, npu_dtype = [(CompareConst.NAN, "bench and npu output dict keys are different")], False, \ + CompareConst.NA, CompareConst.NA + return compare_result, test_success, bench_dtype, npu_dtype compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): bench_dtype = str(bench_out.dtype) -- Gitee From a1c6c77ff7272a0420fcc6dca9d162f74db87a8f Mon Sep 17 00:00:00 2001 From: gitee Date: Wed, 30 Aug 2023 20:12:40 +0800 Subject: [PATCH 5/6] fix --- debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 7a2a0ad63..3fffddc2d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -150,7 +150,7 @@ def flatten_compare_result(result): def compare_core(bench_out, npu_out, alg): msg = "" if not isinstance(bench_out, type(npu_out)): - return CompareConst.NAN, False, "bench and npu output type is different.", CompareConst.NAN + return [(CompareConst.NAN, "bench and npu output type is different.")], False, CompareConst.NA, CompareConst.NA if isinstance(bench_out, (list, tuple)): compare_result, test_success, bench_dtype, npu_dtype = [], True, [], [] if len(bench_out) != len(npu_out): -- Gitee From 5dae117bd4c517e34054104322eaf46ef00f6733 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 31 Aug 2023 11:51:16 +0800 Subject: [PATCH 6/6] fix --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 1 - debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 5 +---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 3fffddc2d..b0b1aaf60 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -166,7 +166,6 @@ def compare_core(bench_out, npu_out, alg): if b_keys != n_keys: compare_result, test_success, bench_dtype, npu_dtype = [(CompareConst.NAN, "bench and npu output dict keys are different")], False, \ CompareConst.NA, CompareConst.NA - return compare_result, test_success, bench_dtype, npu_dtype compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): bench_dtype = str(bench_out.dtype) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 9b3c66559..27efa4bd3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -105,10 +105,7 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: [_, api_name, _] = api_full_name.split("*") - if "not implemented for 'Half'" in str(err): - print_warn_log(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API " - f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.") - elif "expected scalar type Long" in str(err): + if "expected scalar type Long" in str(err): print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") else: -- Gitee