From d00a2c7a401b710ffa8211b5d7d90ad1453a622e Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Dec 2023 11:30:51 +0800 Subject: [PATCH 1/2] add real_data_list --- .../api_accuracy_checker/common/base_api.py | 17 +++++++++-------- .../api_accuracy_checker/common/config.py | 8 +++++++- .../api_accuracy_checker/config.yaml | 1 + 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 5bcc0e78af..9ed60dd33f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -2,6 +2,7 @@ import os import torch from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create +from api_accuracy_checker.common.config import msCheckerConfig class BaseAPIInfo: @@ -44,13 +45,13 @@ class BaseAPIInfo: def analyze_tensor(self, arg): single_arg = {} - if not self.is_save_data: + if not self.is_save_data and self.api_name not in msCheckerConfig.real_data_list: - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'dtype' : str(arg.dtype)}) - single_arg.update({'shape' : arg.shape}) - single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) - single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) + single_arg.update({'type': 'torch.Tensor'}) + single_arg.update({'dtype': str(arg.dtype)}) + single_arg.update({'shape': arg.shape}) + single_arg.update({'Max': self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) + single_arg.update({'Min': self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) single_arg.update({'requires_grad': arg.requires_grad}) else: @@ -68,8 +69,8 @@ class BaseAPIInfo: file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'datapath' : pt_path}) + single_arg.update({'type': 'torch.Tensor'}) + single_arg.update({'datapath': pt_path}) single_arg.update({'requires_grad': arg.requires_grad}) return single_arg diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index f9b882f47b..921a03e6fa 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -23,7 +23,8 @@ class Config: 'error_data_path': str, 'target_iter': list, 'precision': int, - 'white_list': list + 'white_list': list, + 'real_data_list': list } if not isinstance(value, validators.get(key)): raise ValueError(f"{key} must be {validators[key].__name__} type") @@ -46,6 +47,11 @@ class Config: invalid_api = [i for i in value if i not in WrapApi] if invalid_api: raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") + if key == 'real_data_list': + if not isinstance(value, list): + raise ValueError("real_data_list must be a list type") + if not all(isinstance(i, str) for i in value): + raise ValueError("All elements in real_data_list must be of str type") return value def __getattr__(self, item): diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 0bd145893e..ece20e79e1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -6,4 +6,5 @@ error_data_path: './' target_iter: [1] precision: 14 white_list: [] +real_data_list: [] \ No newline at end of file -- Gitee From 17a66bb3df8c4bf7701ca7ba2f8a155596251bcf Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 13 Dec 2023 18:07:17 +0800 Subject: [PATCH 2/2] update --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 978ea1aab3..d23bfec437 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -45,7 +45,7 @@ class BaseAPIInfo: def analyze_tensor(self, arg): single_arg = {} - if not self.is_save_data and self.api_name not in msCheckerConfig.real_data_list: + if not self.is_save_data and all(name not in self.api_name for name in msCheckerConfig.real_data_list): single_arg.update({'type': 'torch.Tensor'}) single_arg.update({'dtype': str(arg.dtype)}) -- Gitee