From 7c4218b975372c9dbe90d1a11d46039699f63b1f Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Thu, 27 Jul 2023 08:42:07 +0000 Subject: [PATCH 1/2] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=8A=BD=E5=8F=96api?= =?UTF-8?q?=E4=BF=A1=E6=81=AF=E5=9F=BA=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: jiangchangting1 --- precision/api_ut_tools/dump/api_info.py | 89 +++++++++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/precision/api_ut_tools/dump/api_info.py b/precision/api_ut_tools/dump/api_info.py index 5c6138f6e4..8e8adf1a05 100644 --- a/precision/api_ut_tools/dump/api_info.py +++ b/precision/api_ut_tools/dump/api_info.py @@ -1,6 +1,95 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 +import inspect +import torch +from .utils import DumpUtil, DumpConst, write_npy +from ..common.utils import print_error_log class APIInfo: def __init__(self, api_name): self.api_name = api_name + self.save_real_data = DumpUtil.save_real_data + self.buildin_class = [] + + def analyze_element(self, element): + if isinstance(element, (list, tuple)): + out = [] + for item in element: + out.append(self.analyze_element(item)) + elif isinstance(element, dict): + out = {} + for key, value in element.items(): + out[key] = self.analyze_element(value) + elif isinstance(element, torch.Tensor): + out = self.analyze_tensor(element, self.save_real_data) + + elif self.is_builtin_class(element): + out = self.analyze_builtin(element) + else: + msg = f"Type {type(element)} is unsupported at analyze_element" + print_error_log(msg) + return None + # raise NotImplementedError(msg) + return out + + def analyze_tensor(self, arg, save_real_data): + single_arg = {} + if not save_real_data: + + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'dtype' : str(arg.dtype)}) + single_arg.update({'shape' : arg.shape}) + single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) + single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) + single_arg.update({'requires_grad': arg.requires_grad}) + + else: + npy_path = write_npy(self.api_name, arg.contiguous().cpu().detach().numpy()) + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'datapath' : npy_path}) + single_arg.update({'requires_grad': arg.requires_grad}) + return single_arg + + def analyze_builtin(self, arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({'type' : "slice"}) + single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({'type' : self.get_type_name(str(type(arg)))}) + single_arg.update({'value' : arg}) + return single_arg + + def transfer_types(self, data, dtype): + if 'bool' in dtype: + breakpoint() + if 'int' in dtype or 'bool' in dtype: + return int(data) + else: + return float(data) + + def is_builtin_class(self, element): + if element is None or isinstance(element, (bool,int,float,str,slice)): + return True + return False + + + def get_tensor_extremum(self, data, operator): + if data.dtype is torch.bool: + if operator == 'max': + return True in data + elif operator == 'min': + if False in data: + return False + else: + return True + if operator == 'max': + return torch._C._VariableFunctionsClass.max(data).item() + else: + return torch._C._VariableFunctionsClass.min(data).item() + + def get_type_name(self, name): + + left = name.index("'") + right = name.rindex("'") + return name[left + 1 : right] -- Gitee From 50d15d7308c3db89ac68f2e0969c3236d732af83 Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Thu, 27 Jul 2023 11:19:01 +0000 Subject: [PATCH 2/2] update precision/api_ut_tools/dump/api_info.py. Signed-off-by: jiangchangting1 --- precision/api_ut_tools/dump/api_info.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/precision/api_ut_tools/dump/api_info.py b/precision/api_ut_tools/dump/api_info.py index 8e8adf1a05..911bba77ce 100644 --- a/precision/api_ut_tools/dump/api_info.py +++ b/precision/api_ut_tools/dump/api_info.py @@ -28,8 +28,8 @@ class APIInfo: else: msg = f"Type {type(element)} is unsupported at analyze_element" print_error_log(msg) - return None - # raise NotImplementedError(msg) + + raise NotImplementedError(msg) return out def analyze_tensor(self, arg, save_real_data): -- Gitee