From 8fc1d08f5825ab91b97d2ba615a707ccbc2b2918 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 11 Jan 2024 19:09:47 +0800 Subject: [PATCH 1/5] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9numpy=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=B1=BB=E5=9E=8B=E5=88=86=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/dump/api_info.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index adb0c4b0f3..25b87dee3a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -2,6 +2,7 @@ import os import inspect import torch +import numpy as np from api_accuracy_checker.common.config import msCheckerConfig from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory, DumpException from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create @@ -92,6 +93,10 @@ class APIInfo: else: out_dict[key] = self.analyze_element(value) return out_dict + + converted_numpy = self._convert_numpy_to_builtin(element) + if converted_numpy is not element: + return self._analyze_builtin(converted_numpy) if isinstance(element, torch.Tensor): return self._analyze_tensor(element) @@ -135,6 +140,21 @@ class APIInfo: single_arg.update({'type': get_type_name(str(type(arg)))}) single_arg.update({'value': arg}) return single_arg + + def _convert_numpy_to_builtin(self, arg): + type_mapping = { + np.integer: int, + np.floating: float, + np.bool_: bool, + np.complexfloating: complex, + np.str_: str, + np.bytes_: bytes, + np.unicode_: str + } + for numpy_type, native_type in type_mapping.items(): + if isinstance(arg, numpy_type): + return native_type(arg) + return arg class ForwardAPIInfo(APIInfo): -- Gitee From d09300b5d043ab9725d22deec9f38520d533400a Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 11 Jan 2024 19:11:45 +0800 Subject: [PATCH 2/5] =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E5=AF=B9numpy=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E7=B1=BB=E5=9E=8B=E5=88=86=E6=9E=90?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 25b87dee3a..ddf058d3d2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -151,9 +151,9 @@ class APIInfo: np.bytes_: bytes, np.unicode_: str } - for numpy_type, native_type in type_mapping.items(): + for numpy_type, builtin_type in type_mapping.items(): if isinstance(arg, numpy_type): - return native_type(arg) + return builtin_type(arg) return arg -- Gitee From 38161952de8d637d47f49fde08af8e4a3e2645f0 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 18 Jan 2024 15:08:44 +0800 Subject: [PATCH 3/5] =?UTF-8?q?=E6=96=B9=E6=A1=88=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/dump/api_info.py | 16 ++++++++++++---- .../api_accuracy_checker/run_ut/data_generate.py | 9 ++++++--- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index ddf058d3d2..12d59820c6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -94,9 +94,9 @@ class APIInfo: out_dict[key] = self.analyze_element(value) return out_dict - converted_numpy = self._convert_numpy_to_builtin(element) + converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) if converted_numpy is not element: - return self._analyze_builtin(converted_numpy) + return self._analyze_numpy(converted_numpy, numpy_type) if isinstance(element, torch.Tensor): return self._analyze_tensor(element) @@ -141,6 +141,14 @@ class APIInfo: single_arg.update({'value': arg}) return single_arg + def _analyze_numpy(self, value, numpy_type): + single_arg = {} + if self.is_save_data: + self.args_num += 1 + single_arg.update({'type': numpy_type}) + single_arg.update({'value': value}) + return single_arg + def _convert_numpy_to_builtin(self, arg): type_mapping = { np.integer: int, @@ -153,8 +161,8 @@ class APIInfo: } for numpy_type, builtin_type in type_mapping.items(): if isinstance(arg, numpy_type): - return builtin_type(arg) - return arg + return builtin_type(arg), get_type_name(str(type(arg))) + return arg, '' class ForwardAPIInfo(APIInfo): diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 5765f980d2..ead6b55179 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -17,7 +17,7 @@ import os import torch -import numpy as np +import numpy from api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, print_error_log, \ CompareException @@ -50,6 +50,9 @@ def gen_data(info, need_grad, convert_type): temp_data = data * 1 data = temp_data.type_as(data) data.retain_grad() + elif data_type.startswith("numpy"): + data = info.get("value") + data = eval(data_type)(data) else: data = info.get('value') if info.get("type") == "slice": @@ -73,7 +76,7 @@ def gen_real_tensor(data_path, convert_type): if data_path.endswith('.pt'): data = torch.load(data_path) else: - data_np = np.load(data_path) + data_np = numpy.load(data_path) data = torch.from_numpy(data_np) if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] @@ -193,7 +196,7 @@ def gen_kwargs(api_info, convert_type=None): for key, value in kwargs_params.items(): if isinstance(value, (list, tuple)): kwargs_params[key] = gen_list_kwargs(value, convert_type) - elif value.get('type') in TENSOR_DATA_LIST: + elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"): kwargs_params[key] = gen_data(value, False, convert_type) elif value.get('type') in TORCH_TYPE: gen_torch_kwargs(kwargs_params, key, value) -- Gitee From 869d5c68f9f40e699b22fd5a66a47bbdfabe0b2e Mon Sep 17 00:00:00 2001 From: gitee Date: Mon, 22 Jan 2024 16:04:02 +0800 Subject: [PATCH 4/5] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=A3=80=E8=A7=86?= =?UTF-8?q?=E6=84=8F=E8=A7=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/run_ut/data_generate.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index ead6b55179..3aa1496e63 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -26,6 +26,9 @@ TORCH_TYPE = ["torch.device", "torch.dtype"] TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] FLOAT_TYPE = ['torch.float32', 'torch.float', 'torch.float64', 'torch.double', 'torch.float16', 'torch.half', 'torch.bfloat16'] +NUMPY_TYPE = ["numpy.int8", "numpy.int16", "numpy.int32", "numpy.int64", "numpy.uint8", "numpy.uint16", "numpy.uint32", + "numpy.uint64", "numpy.float16", "numpy.float32", "numpy.float64", "numpy.float128", "numpy.complex64", + "numpy.complex128", "numpy.complex256", "numpy.bool_", "numpy.string_", "numpy.bytes_", "numpy.unicode_"] def gen_data(info, need_grad, convert_type): @@ -52,7 +55,10 @@ def gen_data(info, need_grad, convert_type): data.retain_grad() elif data_type.startswith("numpy"): data = info.get("value") - data = eval(data_type)(data) + try: + data = eval(data_type)(data) + except Exception as err: + print_error_log("Failed to convert the type to numpy: %s" % str(err)) else: data = info.get('value') if info.get("type") == "slice": -- Gitee From 2998950c43945ccbeee53cd7c85456dbb2cbda61 Mon Sep 17 00:00:00 2001 From: gitee Date: Mon, 22 Jan 2024 16:09:38 +0800 Subject: [PATCH 5/5] fix --- .../accuracy_tools/api_accuracy_checker/run_ut/data_generate.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 3aa1496e63..21bc23cfb8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -54,6 +54,8 @@ def gen_data(info, need_grad, convert_type): data = temp_data.type_as(data) data.retain_grad() elif data_type.startswith("numpy"): + if data_type not in NUMPY_TYPE: + raise Exception("{} is not supported now".format(data_type)) data = info.get("value") try: data = eval(data_type)(data) -- Gitee