From 2afcc2a85fb4a26dfb3551b00cf951db0d8aa039 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 1 Nov 2023 11:22:25 +0800 Subject: [PATCH 1/6] clearcode --- .../api_accuracy_checker/common/base_api.py | 6 +++--- .../api_accuracy_checker/common/config.py | 3 ++- .../api_accuracy_checker/common/utils.py | 17 +++++++++++++---- .../api_accuracy_checker/compare/algorithm.py | 15 ++++++++++++++- .../api_accuracy_checker/compare/compare.py | 9 +++++---- .../api_accuracy_checker/dump/api_info.py | 3 ++- .../api_accuracy_checker/dump/dump.py | 2 ++ .../api_accuracy_checker/dump/dump_scope.py | 1 + .../api_accuracy_checker/dump/info_dump.py | 4 +++- .../api_accuracy_checker/run_ut/run_ut.py | 4 +++- 10 files changed, 48 insertions(+), 16 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 2c3086184..64ce9e717 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -48,8 +48,8 @@ class BaseAPIInfo: single_arg.update({'type' : 'torch.Tensor'}) single_arg.update({'dtype' : str(arg.dtype)}) single_arg.update({'shape' : arg.shape}) - single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) - single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) + single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) + single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) single_arg.update({'requires_grad': arg.requires_grad}) else: @@ -87,7 +87,7 @@ class BaseAPIInfo: return float(data) def is_builtin_class(self, element): - if element is None or isinstance(element, (bool,int,float,str,slice)): + if element is None or isinstance(element, (bool, int, float, str, slice)): return True return False diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index c47911e21..36df4bb01 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -1,8 +1,9 @@ -import yaml import os +import yaml from api_accuracy_checker.common.utils import check_file_or_directory_path from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + class Config: def __init__(self, yaml_file): check_file_or_directory_path(yaml_file, False) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index b031b92a1..8c0cceebe 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -29,9 +29,6 @@ import numpy as np import torch import csv -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, FileOpen -from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util - try: import torch_npu except ImportError: @@ -39,6 +36,9 @@ except ImportError: else: IS_GPU = False +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, FileOpen +from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util + torch_without_guard_version_list = ['2.1'] for version in torch_without_guard_version_list: if torch.__version__.startswith(version): @@ -65,7 +65,7 @@ class Const: DOT = "." DUMP_RATIO_MAX = 100 SUMMERY_DATA_NUMS = 256 - ONE_HUNDRED_MB = 100*1024*1024 + ONE_HUNDRED_MB = 100 * 1024 * 1024 FLOAT_EPSILON = np.finfo(float).eps SUPPORT_DUMP_MODE = ['api', 'acl'] ON = 'ON' @@ -103,6 +103,7 @@ class Const: "int32_to_int64": ["cross_entropy"] } + class CompareConst: """ Class for compare module const @@ -191,19 +192,23 @@ class CompareException(Exception): def __str__(self): return self.error_info + class DumpException(CompareException): pass + def read_json(file): with FileOpen(file, 'r') as f: obj = json.load(f) return obj + def write_csv(data, filepath): with FileOpen(filepath, 'a') as f: writer = csv.writer(f) writer.writerows(data) + def _print_log(level, msg): current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) pid = os.getgid() @@ -296,6 +301,7 @@ def check_file_or_directory_path(path, isdir=False): 'The path {} does not have permission to read. Please check the path permission'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) + def _check_pkl(pkl_file_handle, file_name): tensor_line = pkl_file_handle.readline() if len(tensor_line) == 0: @@ -573,6 +579,7 @@ def check_need_convert(api_name): convert_type = key return convert_type + def api_info_preprocess(api_name, api_info_dict): """ Function Description: @@ -589,6 +596,7 @@ def api_info_preprocess(api_name, api_info_dict): api_info_dict = cross_entropy_process(api_info_dict) return convert_type, api_info_dict + def cross_entropy_process(api_info_dict): """ Function Description: @@ -603,6 +611,7 @@ def cross_entropy_process(api_info_dict): api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. return api_info_dict + def initialize_save_path(save_path, dir_name): data_path = os.path.join(save_path, dir_name) if os.path.exists(data_path): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index cca521a29..a8792225d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -1,5 +1,4 @@ # 定义比对算法及比对标准 - import torch import numpy as np from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable @@ -14,6 +13,7 @@ def compare_torch_tensor(cpu_output, npu_output, compare_alg): return compare_bool_tensor(cpu_output, npu_output) return compare_alg(cpu_output, npu_output) + def compare_bool_tensor(cpu_output, npu_output): cpu_shape = cpu_output.shape npu_shape = npu_output.shape @@ -23,6 +23,7 @@ def compare_bool_tensor(cpu_output, npu_output): error_rate = float(error_nums / cpu_output.size) return error_rate, error_rate == 0, "" + def get_msg_and_handle_value(b_value, n_value): msg = "" if not isinstance(b_value, np.ndarray) or not isinstance(n_value, np.ndarray): @@ -46,6 +47,7 @@ def get_msg_and_handle_value(b_value, n_value): b_value[zero_mask] += np.finfo(float).eps return b_value, n_value, msg + def get_max_rel_err(b_value, n_value): b_value, n_value, msg = get_msg_and_handle_value(b_value, n_value) rel_err = np.abs((n_value - b_value) / b_value).max() @@ -55,15 +57,18 @@ def get_max_rel_err(b_value, n_value): bool_result = rel_err < 0.001 return rel_err, bool_result, msg + def get_max_abs_err(b_value, n_value): b_value, n_value, msg = get_msg_and_handle_value(b_value, n_value) abs_err = np.abs(b_value - n_value).max() bool_result = abs_err < 0.001 return abs_err, bool_result, msg + def get_rel_err_ratio_thousandth(b_value, n_value): return get_rel_err_ratio(b_value, n_value, 0.001) + def get_rel_err_ratio_ten_thousandth(b_value, n_value): ratio, bool_result, msg = get_rel_err_ratio(b_value, n_value, 0.0001) if n_value.dtype == np.float16: @@ -71,6 +76,7 @@ def get_rel_err_ratio_ten_thousandth(b_value, n_value): return ratio, True, msg return ratio, bool_result, msg + def get_rel_err_ratio(b_value, n_value, thresholding): b_value, n_value, msg = get_msg_and_handle_value(b_value, n_value) rel_errs = np.abs((n_value - b_value) / b_value) @@ -78,14 +84,17 @@ def get_rel_err_ratio(b_value, n_value, thresholding): bool_result = ratio > (1 - thresholding) return ratio, bool_result, msg + def max_rel_err_standard(max_rel_errs): bool_result = np.array(max_rel_errs) < 0.001 return np.all(bool_result), bool_result + def cosine_standard(compare_result): bool_result = np.array(compare_result) > 0.99 return np.all(bool_result), bool_result + def cosine_sim(cpu_output, npu_output): msg = "" n_value = npu_output.reshape(-1) @@ -116,12 +125,14 @@ def cosine_sim(cpu_output, npu_output): msg = "Dump data has NaN when comparing with Cosine Similarity." return cos, cos > 0.99, msg + def compare_uint8_data(b_value, n_value): if (b_value == n_value).all(): return 1, True else: return 0, False + def compare_builtin_type(bench_out, npu_out): if not isinstance(bench_out, (bool, int, float, str)): return CompareConst.NA, True, "" @@ -129,6 +140,7 @@ def compare_builtin_type(bench_out, npu_out): return CompareConst.NAN, False, "" return True, True, "" + def flatten_compare_result(result): flatten_result = [] for result_i in result: @@ -138,6 +150,7 @@ def flatten_compare_result(result): flatten_result.append(result_i) return flatten_result + # 本函数用alg比对bench_out 和npu_out,返回详细比对结果compare_result和标志比对是否通过的布尔变量test_success def compare_core(bench_out, npu_out, alg): msg = "" diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index f3e8a4cf4..b33626321 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -8,6 +8,7 @@ from api_accuracy_checker.common.utils import get_json_contents, print_info_log, from api_accuracy_checker.compare.compare_utils import CompareConst from api_accuracy_checker.common.config import msCheckerConfig + class Comparator: TEST_FILE_NAME = "accuracy_checking_result.csv" DETAIL_TEST_FILE_NAME = "accuracy_checking_details.csv" @@ -174,14 +175,14 @@ class Comparator: if name == "Max Absolute Error": max_abs_error_success = test_success if detailed_result_total: - for i in range(len(detailed_result_total)): - detailed_result_total[i] += detailed_result[i] + for i, detailed_result_item in enumerate(detailed_result): + detailed_result_total[i] += detailed_result_item else: detailed_result_total = detailed_result test_success_total = test_success_total or max_abs_error_success # dtype加到所有指标的前面, 是否pass放到所有指标的后面 - for i in range(len(detailed_result_total)): - detailed_result = list(detailed_result_total[i]) + for i, detailed_result in enumerate(detailed_result_total): + detailed_result = list(detailed_result) detailed_result.insert(0, bench_dtype_total[i]) detailed_result.insert(1, npu_dtype_total[i]) detailed_result.insert(2, shape_total[i]) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 2a86699d8..ca0d4021a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -25,7 +25,8 @@ class ForwardAPIInfo(APIInfo): def analyze_api_call_stack(self): stack_str = [] for (_, path, line, func, code, _) in inspect.stack()[3:]: - if not code: continue + if not code: + continue stack_line = " ".join([ "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), " ".join(["\n", code[0].strip()])])]) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 2a69e226c..0120c10dd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -31,6 +31,7 @@ def set_dump_switch(switch): initialize_output_json() DumpUtil.set_dump_switch(switch) + class DumpUtil(object): dump_switch = None call_num = 0 @@ -74,6 +75,7 @@ def pretest_info_dump(name, out_feat, module, phase): write_api_info_json(api_info) + def pretest_hook(name, phase): def pretest_info_dump_hook(module, in_feat, out_feat): pretest_info_dump(name, out_feat, module, phase) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 17f94da19..85f555ed7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -4,6 +4,7 @@ from torch.utils.data.dataloader import _BaseDataLoaderIter from api_accuracy_checker.dump.dump import DumpUtil from api_accuracy_checker.common.config import msCheckerConfig + def iter_tracer(func): def func_wrapper(*args, **kwargs): DumpUtil.dump_switch = "OFF" diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 7eeeeb590..05354226f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -11,6 +11,7 @@ from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen lock = threading.Lock() + def write_api_info_json(api_info): dump_path = msCheckerConfig.dump_path rank = api_info.rank @@ -26,8 +27,9 @@ def write_api_info_json(api_info): else: raise ValueError(f"Invalid api_info type {type(api_info)}") + def write_json(file_path, data, indent=None): - check_file_or_directory_path(os.path.dirname(file_path),True) + check_file_or_directory_path(os.path.dirname(file_path), True) if not os.path.exists(file_path): with FileOpen(file_path, 'w') as f: f.write("{\n}") diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 536d607dd..5f9e27216 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -67,8 +67,10 @@ def generate_npu_params(input_args, input_kwargs, need_backward): npu_kwargs = {key: recursive_arg_to_npu(value) for key, value in input_kwargs.items()} return npu_args, npu_kwargs + def generate_cpu_params(input_args, input_kwargs, need_backward): first_dtype = None + def recursive_arg_to_cpu(arg_in): nonlocal first_dtype if isinstance(arg_in, (list, tuple)): @@ -99,6 +101,7 @@ def generate_cpu_params(input_args, input_kwargs, need_backward): cpu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()} return cpu_args, cpu_kwargs + def run_ut(forward_file, backward_file, out_path, save_error_data): print_info_log("start UT test") forward_content = get_json_contents(forward_file) @@ -140,7 +143,6 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out) - def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): in_fwd_data_list = [] [api_type, api_name, _] = api_full_name.split("*") -- Gitee From c7cfcd5ffcbb3b87929388020d2cee2721c0a907 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 1 Nov 2023 15:21:06 +0800 Subject: [PATCH 2/6] clearcode --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 64ce9e717..516c5ebf8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -49,7 +49,7 @@ class BaseAPIInfo: single_arg.update({'dtype' : str(arg.dtype)}) single_arg.update({'shape' : arg.shape}) single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) - single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) + single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) single_arg.update({'requires_grad': arg.requires_grad}) else: -- Gitee From d0b114d865f996808866bfc5252ebbabfb4d0545 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Thu, 2 Nov 2023 01:56:39 +0000 Subject: [PATCH 3/6] =?UTF-8?q?update=20debug/accuracy=5Ftools/api=5Faccur?= =?UTF-8?q?acy=5Fchecker/test/ut/run=5Fut/test=5Frun=5Fut.py.=20=E8=A7=A3?= =?UTF-8?q?=E9=99=A4npu=E4=BE=9D=E8=B5=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: sunyiming --- .../test/ut/run_ut/test_run_ut.py | 11 ----------- 1 file changed, 11 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py index 9f464cab4..1f45e6199 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py @@ -25,17 +25,6 @@ class TestRunUtMethods(unittest.TestCase): self.assertEqual(out.requires_grad, True) self.assertEqual(out.shape, torch.Size([2, 2560, 24, 24])) - def test_generate_npu_params(self): - api_info = copy.deepcopy(api_info_dict) - [api_type, api_name, _] = api_full_name.split("*") - args, kwargs, need_grad = get_api_info(api_info, api_name) - npu_args, npu_kwargs = generate_npu_params(args, kwargs, True) - self.assertEqual(len(npu_args), 1) - self.assertEqual(npu_args[0].dtype, torch.float16) - self.assertEqual(npu_args[0].requires_grad, True) - self.assertEqual(npu_args[0].shape, torch.Size([2, 2560, 24, 24])) - self.assertEqual(npu_kwargs, {'inplace': False}) - def test_generate_cpu_params(self): api_info = copy.deepcopy(api_info_dict) [api_type, api_name, _] = api_full_name.split("*") -- Gitee From 55edf6dd056e31ce06b1767b335d03f00b8633cd Mon Sep 17 00:00:00 2001 From: s30048155 Date: Thu, 2 Nov 2023 15:57:30 +0800 Subject: [PATCH 4/6] rename --- debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index b33626321..0c8829b62 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -181,8 +181,8 @@ class Comparator: detailed_result_total = detailed_result test_success_total = test_success_total or max_abs_error_success # dtype加到所有指标的前面, 是否pass放到所有指标的后面 - for i, detailed_result in enumerate(detailed_result_total): - detailed_result = list(detailed_result) + for i, detailed_tuple in enumerate(detailed_result_total): + detailed_result = list(detailed_tuple) detailed_result.insert(0, bench_dtype_total[i]) detailed_result.insert(1, npu_dtype_total[i]) detailed_result.insert(2, shape_total[i]) -- Gitee From 76a221ffb494681bc0b662fb325f5dc7c2f6522e Mon Sep 17 00:00:00 2001 From: s30048155 Date: Thu, 2 Nov 2023 17:44:55 +0800 Subject: [PATCH 5/6] revert --- .../test/ut/run_ut/test_run_ut.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py index 1f45e6199..ffe9fc52b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py @@ -24,7 +24,18 @@ class TestRunUtMethods(unittest.TestCase): self.assertEqual(out.dtype, torch.float32) self.assertEqual(out.requires_grad, True) self.assertEqual(out.shape, torch.Size([2, 2560, 24, 24])) - + + def test_generate_npu_params(self): + api_info = copy.deepcopy(api_info_dict) + [api_type, api_name, _] = api_full_name.split("*") + args, kwargs, need_grad = get_api_info(api_info, api_name) + npu_args, npu_kwargs = generate_npu_params(args, kwargs, True) + self.assertEqual(len(npu_args), 1) + self.assertEqual(npu_args[0].dtype, torch.float16) + self.assertEqual(npu_args[0].requires_grad, True) + self.assertEqual(npu_args[0].shape, torch.Size([2, 2560, 24, 24])) + self.assertEqual(npu_kwargs, {'inplace': False}) + def test_generate_cpu_params(self): api_info = copy.deepcopy(api_info_dict) [api_type, api_name, _] = api_full_name.split("*") -- Gitee From 85ae5b205fd0b60ed5561fa494e81300d69196be Mon Sep 17 00:00:00 2001 From: s30048155 Date: Thu, 2 Nov 2023 17:45:44 +0800 Subject: [PATCH 6/6] revert --- .../api_accuracy_checker/test/ut/run_ut/test_run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py index ffe9fc52b..9f464cab4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/run_ut/test_run_ut.py @@ -24,7 +24,7 @@ class TestRunUtMethods(unittest.TestCase): self.assertEqual(out.dtype, torch.float32) self.assertEqual(out.requires_grad, True) self.assertEqual(out.shape, torch.Size([2, 2560, 24, 24])) - + def test_generate_npu_params(self): api_info = copy.deepcopy(api_info_dict) [api_type, api_name, _] = api_full_name.split("*") @@ -35,7 +35,7 @@ class TestRunUtMethods(unittest.TestCase): self.assertEqual(npu_args[0].requires_grad, True) self.assertEqual(npu_args[0].shape, torch.Size([2, 2560, 24, 24])) self.assertEqual(npu_kwargs, {'inplace': False}) - + def test_generate_cpu_params(self): api_info = copy.deepcopy(api_info_dict) [api_type, api_name, _] = api_full_name.split("*") -- Gitee