diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 55e06a7be13282d9c1acd178f12fd874aa87fec4..d23bfec437bf1a7ce8e4b06e4c7426c3bc4ffe01 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -45,13 +45,13 @@ class BaseAPIInfo: def analyze_tensor(self, arg): single_arg = {} - if not self.is_save_data: + if not self.is_save_data and all(name not in self.api_name for name in msCheckerConfig.real_data_list): - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'dtype' : str(arg.dtype)}) - single_arg.update({'shape' : arg.shape}) - single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) - single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) + single_arg.update({'type': 'torch.Tensor'}) + single_arg.update({'dtype': str(arg.dtype)}) + single_arg.update({'shape': arg.shape}) + single_arg.update({'Max': self.transfer_types(self.get_tensor_extremum(arg, 'max'), str(arg.dtype))}) + single_arg.update({'Min': self.transfer_types(self.get_tensor_extremum(arg, 'min'), str(arg.dtype))}) single_arg.update({'requires_grad': arg.requires_grad}) else: @@ -69,8 +69,8 @@ class BaseAPIInfo: file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'datapath' : pt_path}) + single_arg.update({'type': 'torch.Tensor'}) + single_arg.update({'datapath': pt_path}) single_arg.update({'requires_grad': arg.requires_grad}) return single_arg diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 1d9eda41052af523618183468bf61cdae334c0a5..d8ba427aa229e1d942713f4dc06c84a3a0516779 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -24,6 +24,7 @@ class Config: 'target_iter': list, 'precision': int, 'white_list': list, + 'real_data_list': list, 'enable_dataloader': bool } if not isinstance(value, validators.get(key)): @@ -47,6 +48,11 @@ class Config: invalid_api = [i for i in value if i not in WrapApi] if invalid_api: raise ValueError(f"{', '.join(invalid_api)} is not in support_wrap_ops.yaml, please check the white_list") + if key == 'real_data_list': + if not isinstance(value, list): + raise ValueError("real_data_list must be a list type") + if not all(isinstance(i, str) for i in value): + raise ValueError("All elements in real_data_list must be of str type") return value def __getattr__(self, item): diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index ece957347ae8073973e67237ff75572271e55077..39b9d2944a4492d0636f273350df38d105b3948a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -6,5 +6,6 @@ error_data_path: './' target_iter: [1] precision: 14 white_list: [] +real_data_list: [] enable_dataloader: True \ No newline at end of file