From 8fcedc756a270c79a4910a73931c1aba5b82f1f5 Mon Sep 17 00:00:00 2001 From: wangchao Date: Tue, 29 Aug 2023 17:14:39 +0800 Subject: [PATCH] replace * to . --- .../api_accuracy_checker/common/base_api.py | 18 +++++++++--------- .../api_accuracy_checker/run_ut/run_ut.py | 13 +++++++------ 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 3689a4c1d1..0db84b45bf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -14,7 +14,7 @@ class BaseAPIInfo: self.save_path = save_path self.forward_path = forward_path self.backward_path = backward_path - + def analyze_element(self, element): if isinstance(element, (list, tuple)): out = [] @@ -45,16 +45,16 @@ class BaseAPIInfo: def analyze_tensor(self, arg): single_arg = {} if not self.is_save_data: - + single_arg.update({'type' : 'torch.Tensor'}) single_arg.update({'dtype' : str(arg.dtype)}) single_arg.update({'shape' : arg.shape}) single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) single_arg.update({'requires_grad': arg.requires_grad}) - + else: - api_args = self.api_name + '*' + str(self.args_num) + api_args = self.api_name + '.' + str(self.args_num) if self.is_forward: forward_real_data_path = os.path.join(self.save_path, self.forward_path) @@ -89,12 +89,12 @@ class BaseAPIInfo: if element is None or isinstance(element, (bool,int,float,str,slice)): return True return False - + def analyze_device_in_kwargs(self, element): single_arg = {} single_arg.update({'type' : 'torch.device'}) if not isinstance(element, str): - + if hasattr(element, "index"): device_value = element.type + ":" + str(element.index) single_arg.update({'value' : device_value}) @@ -103,13 +103,13 @@ class BaseAPIInfo: else: single_arg.update({'value' : element}) return single_arg - + def analyze_dtype_in_kwargs(self, element): single_arg = {} single_arg.update({'type' : 'torch.dtype'}) single_arg.update({'value' : str(element)}) return single_arg - + def get_tensor_extremum(self, data, operator): if data.dtype is torch.bool: if operator == 'max': @@ -120,7 +120,7 @@ class BaseAPIInfo: return torch._C._VariableFunctionsClass.max(data).item() else: return torch._C._VariableFunctionsClass.min(data).item() - + def get_type_name(self, name): left = name.index("'") diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 813d0cb586..ee962e0298 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -95,16 +95,17 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: + api_full_name = api_full_name.replace("*", ".") for element in data_info.in_fwd_data_list: - UtAPIInfo(api_full_name + '*forward*input', element) + UtAPIInfo(api_full_name + '.forward.input', element) if data_info.bench_out is not None: - UtAPIInfo(api_full_name + '*forward*output*bench', data_info.bench_out) - UtAPIInfo(api_full_name + '*forward*output*npu', data_info.npu_out) + UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) + UtAPIInfo(api_full_name + '.forward.output.npu', data_info.npu_out) if data_info.grad_in is not None: - UtAPIInfo(api_full_name + '*backward*input', data_info.grad_in) + UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) if data_info.bench_grad_out is not None: - UtAPIInfo(api_full_name + '*backward*output*bench', data_info.bench_grad_out) - UtAPIInfo(api_full_name + '*backward*output*npu', data_info.npu_grad_out) + UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) + UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out) -- Gitee