From 8ff8fdda7009b9da8b2ab3c8b9772ca376606c6e Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Mon, 31 Jul 2023 01:44:26 +0000 Subject: [PATCH 1/6] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- .../api_accuracy_checker/dump/api_info.py | 89 +++++++++++++++++++ 1 file changed, 89 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 5c6138f6e..911bba77c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,6 +1,95 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 +import inspect +import torch +from .utils import DumpUtil, DumpConst, write_npy +from ..common.utils import print_error_log class APIInfo: def __init__(self, api_name): self.api_name = api_name + self.save_real_data = DumpUtil.save_real_data + self.buildin_class = [] + + def analyze_element(self, element): + if isinstance(element, (list, tuple)): + out = [] + for item in element: + out.append(self.analyze_element(item)) + elif isinstance(element, dict): + out = {} + for key, value in element.items(): + out[key] = self.analyze_element(value) + elif isinstance(element, torch.Tensor): + out = self.analyze_tensor(element, self.save_real_data) + + elif self.is_builtin_class(element): + out = self.analyze_builtin(element) + else: + msg = f"Type {type(element)} is unsupported at analyze_element" + print_error_log(msg) + + raise NotImplementedError(msg) + return out + + def analyze_tensor(self, arg, save_real_data): + single_arg = {} + if not save_real_data: + + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'dtype' : str(arg.dtype)}) + single_arg.update({'shape' : arg.shape}) + single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) + single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) + single_arg.update({'requires_grad': arg.requires_grad}) + + else: + npy_path = write_npy(self.api_name, arg.contiguous().cpu().detach().numpy()) + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'datapath' : npy_path}) + single_arg.update({'requires_grad': arg.requires_grad}) + return single_arg + + def analyze_builtin(self, arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({'type' : "slice"}) + single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({'type' : self.get_type_name(str(type(arg)))}) + single_arg.update({'value' : arg}) + return single_arg + + def transfer_types(self, data, dtype): + if 'bool' in dtype: + breakpoint() + if 'int' in dtype or 'bool' in dtype: + return int(data) + else: + return float(data) + + def is_builtin_class(self, element): + if element is None or isinstance(element, (bool,int,float,str,slice)): + return True + return False + + + def get_tensor_extremum(self, data, operator): + if data.dtype is torch.bool: + if operator == 'max': + return True in data + elif operator == 'min': + if False in data: + return False + else: + return True + if operator == 'max': + return torch._C._VariableFunctionsClass.max(data).item() + else: + return torch._C._VariableFunctionsClass.min(data).item() + + def get_type_name(self, name): + + left = name.index("'") + right = name.rindex("'") + return name[left + 1 : right] -- Gitee From e48e2c4d95e524c7a335dc89bba845789b785b96 Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Mon, 31 Jul 2023 01:46:16 +0000 Subject: [PATCH 2/6] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 911bba77c..2cbba0c29 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -19,7 +19,10 @@ class APIInfo: elif isinstance(element, dict): out = {} for key, value in element.items(): - out[key] = self.analyze_element(value) + if isinstance(value, torch.Tensor): + out[key] = self.analyze_element(value) + else: + out[key] = value elif isinstance(element, torch.Tensor): out = self.analyze_tensor(element, self.save_real_data) -- Gitee From 5aee98c29e1665538be3c127a5c88e801c395709 Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Mon, 31 Jul 2023 02:20:09 +0000 Subject: [PATCH 3/6] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 2cbba0c29..2ca2d1947 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -19,10 +19,8 @@ class APIInfo: elif isinstance(element, dict): out = {} for key, value in element.items(): - if isinstance(value, torch.Tensor): - out[key] = self.analyze_element(value) - else: - out[key] = value + out[key] = self.analyze_element(value) + elif isinstance(element, torch.Tensor): out = self.analyze_tensor(element, self.save_real_data) -- Gitee From 497597787eaf8e4d857ce1de8c88710888041b9a Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Mon, 31 Jul 2023 12:28:59 +0000 Subject: [PATCH 4/6] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 2ca2d1947..7ced5cc2e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -62,8 +62,6 @@ class APIInfo: return single_arg def transfer_types(self, data, dtype): - if 'bool' in dtype: - breakpoint() if 'int' in dtype or 'bool' in dtype: return int(data) else: -- Gitee From c25319fb301c511495cd7f67e3f608f3e90a13ae Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Mon, 31 Jul 2023 12:34:02 +0000 Subject: [PATCH 5/6] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 1 - 1 file changed, 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 7ced5cc2e..11f7c83b0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -9,7 +9,6 @@ class APIInfo: def __init__(self, api_name): self.api_name = api_name self.save_real_data = DumpUtil.save_real_data - self.buildin_class = [] def analyze_element(self, element): if isinstance(element, (list, tuple)): -- Gitee From 2079d337b5d33ccf4bc5e25e9e09057f6af9929b Mon Sep 17 00:00:00 2001 From: jiangchangting1 Date: Tue, 1 Aug 2023 02:39:42 +0000 Subject: [PATCH 6/6] update debug/accuracy_tools/api_accuracy_checker/dump/api_info.py. Signed-off-by: jiangchangting1 --- debug/accuracy_tools/api_accuracy_checker/dump/api_info.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 11f7c83b0..0216e1c64 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -77,10 +77,7 @@ class APIInfo: if operator == 'max': return True in data elif operator == 'min': - if False in data: - return False - else: - return True + return False not in data if operator == 'max': return torch._C._VariableFunctionsClass.max(data).item() else: -- Gitee