From 8946c548586a78428d5591ed728aa03224740044 Mon Sep 17 00:00:00 2001
From: curry3 <485078529@qq.com>
Date: Tue, 26 Aug 2025 20:10:21 +0800
Subject: [PATCH] [bugfix] Fix missing attributes after wrapping APIs
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Previously only __name__ was set on the wrapper returned by ApiWrapper, so
attributes such as "default" were lost once an API was wrapped. Copy every
attribute listed in Const.API_ATTR_LIST from the original function onto the
wrapper instead. Also rename "backlist" to "blacklist", read dtype/shape and
the MD5 input from common_tensor, use untyped_storage() when checking
data_ptr, and include the current api/module name in the debug logs.
---
 .../msprobe/core/common/const.py              |  2 ++
 .../msprobe/core/data_dump/api_registry.py    | 18 +++++++++------
 .../data_processor/pytorch_processor.py       | 22 ++++++++++++-------
 .../mindspore/dump/hook_cell/api_register.py  |  8 +++----
 4 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 2fa89beb97..1aaf7d055f 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -277,6 +277,8 @@ class Const:
 
     SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
 
+    API_ATTR_LIST = ["__name__", "default"]
+
     PT_API_TYPE_FUNCTIONAL = "functional"
     PT_API_TYPE_TENSOR = "tensor"
     PT_API_TYPE_TORCH = "torch"
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
index 8d676dd2bb..cc90aaa94d 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
@@ -35,7 +35,7 @@ class ApiWrapper:
     def __init__(
         self, api_types: Dict[str, Dict[str, Any]],
         api_list_paths: Union[str, List[str], Tuple[str]],
-        backlist: Union[List[str], Tuple[str]] = None
+        blacklist: Union[List[str], Tuple[str]] = None
     ):
         self.api_types = api_types
         if not isinstance(api_list_paths, (list, tuple)):
@@ -44,7 +44,7 @@ class ApiWrapper:
            raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                               "when api_list_paths is a list or tuple.")
        self.api_list_paths = api_list_paths
-       self.backlist = backlist if backlist else []
+       self.blacklist = blacklist if blacklist else []
        self.api_names = self._get_api_names()
        self.wrapped_api_functions = dict()
 
@@ -93,7 +93,11 @@ class ApiWrapper:
                return api_func(*args, **kwargs)
            return api_instance(*args, **kwargs)
 
-       api_function.__name__ = api_name
+       for attr_name in Const.API_ATTR_LIST:
+           if hasattr(api_func, attr_name):
+               attr_value = getattr(api_func, attr_name)
+               setattr(api_function, attr_name, attr_value)
+
        return api_function
 
    def wrap_api(
@@ -142,7 +146,7 @@ class ApiWrapper:
            api_from_file = api_list.get(key_in_file, [])
            names = set()
            for api_name in api_from_file:
-               if f'{key_in_file}.{api_name}' in self.backlist:
+               if f'{key_in_file}.{api_name}' in self.blacklist:
                    continue
                target_attr = api_name
                for module in api_modules[0]:
@@ -164,7 +168,7 @@ class ApiRegistry:
    Base class for api registry.
""" - def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None): + def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, blacklist=None): self.ori_api_attr = dict() self.wrapped_api_attr = dict() self.inner_used_ori_attr = dict() @@ -173,7 +177,7 @@ class ApiRegistry: self.inner_used_api = inner_used_api self.supported_api_list_path = supported_api_list_path self.api_templates = api_templates - self.backlist = backlist if backlist else [] + self.blacklist = blacklist if blacklist else [] self.all_api_registered = False @staticmethod @@ -232,7 +236,7 @@ class ApiRegistry: self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {})) def initialize_hook(self, hook_build_func): - api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist) + api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.blacklist) wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func) for framework, api_types in self.api_types.items(): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index cfa974b9aa..9e81ac5691 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -279,8 +279,8 @@ class PytorchDataProcessor(BaseDataProcessor): tensor_stat = self.get_stat_info(common_tensor, self.config.async_dump, self.config.precision) tensor_json = {} tensor_json.update({'type': self.tensor_handler.get_tensor_type(tensor)}) - tensor_json.update({'dtype': str(tensor.dtype)}) - tensor_json.update({"shape": tensor.shape}) + tensor_json.update({'dtype': str(common_tensor.dtype)}) + tensor_json.update({"shape": common_tensor.shape}) stat_values = [ tensor_stat.max, @@ -299,11 +299,10 @@ class PytorchDataProcessor(BaseDataProcessor): if self.config.summary_mode == Const.MD5 and not self.config.async_dump: tensor_md5 = None if not self.tensor_handler.is_empty_data(tensor): - logger.debug("Calculating the md5 value of fake tensor or meta tensor is not supported.") # 拷贝并搬到 CPU - if tensor.dtype == torch.bfloat16: - tensor = tensor.float() - tensor_bytes = tensor.cpu().detach().numpy() + if common_tensor.dtype == torch.bfloat16: + common_tensor = common_tensor.float() + tensor_bytes = common_tensor.cpu().detach().numpy() future = self._crc_executor.submit( PytorchDataProcessor.compute_crc32_bytes, @@ -313,14 +312,21 @@ class PytorchDataProcessor(BaseDataProcessor): crc_placeholder = self.data_writer.append_crc32_to_buffer(future) tensor_json[Const.MD5_INDEX] = crc_placeholder else: + logger.debug( + "Calculating the md5 value of fake tensor or meta tensor is not supported, " + f"the current api/module name is {self.current_api_or_module_name}." 
+               )
                tensor_json.update({Const.MD5: tensor_md5})
        return tensor_json
 
    def _analyze_and_save_tensor(self, tensor, suffix):
        dump_data_name, file_path = self.get_save_file_path(suffix)
        single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
-       if self.tensor_handler.is_empty_data(tensor) or tensor.storage().data_ptr() == 0:
-           logger.debug("Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0.")
+       if self.tensor_handler.is_empty_data(tensor) or tensor.untyped_storage().data_ptr() == 0:
+           logger.debug(
+               "Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0, "
+               f"the current api/module name is {self.current_api_or_module_name}."
+           )
            return single_arg
 
        single_arg.update({"data_name": dump_data_name})
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
index d10db82015..0ff9c56fb5 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
@@ -54,7 +54,7 @@ if not is_mindtorch():
    )
 
    _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
-   _backlist = []
+   _blacklist = []
 else:
    import torch
    import torch_npu
@@ -69,7 +69,7 @@ else:
    }
    _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module',
                                             MsConst.SUPPORTED_API_LIST_FILE),)
-   _backlist = []
+   _blacklist = []
 
 _inner_used_api = {
    Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
@@ -160,7 +160,7 @@ def get_api_register(return_new=False):
        _inner_used_api,
        _supported_api_list_path,
        ApiTemplate,
-       _backlist
+       _blacklist
    )
 
    global api_register
@@ -170,6 +170,6 @@ def get_api_register(return_new=False):
        _inner_used_api,
        _supported_api_list_path,
        ApiTemplate,
-       _backlist
+       _blacklist
    )
    return api_register
--
Gitee
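
Reviewer note: the core of this fix is that a plain wrapper function does not
inherit attributes from the callable it wraps. The sketch below is a minimal,
standalone illustration of that behaviour and of the attribute-copy loop added
in api_registry.py; ATTR_LIST mirrors the new Const.API_ATTR_LIST, and
original_api with its "default" attribute is a hypothetical stand-in for a
framework API, not code from msprobe.

# Standalone sketch: why the attribute-copy loop is needed (hypothetical API).
ATTR_LIST = ["__name__", "default"]


def original_api(x):
    """Stand-in for a framework API function."""
    return x + 1


original_api.default = 0  # extra attribute carried by the real API


def wrap(api_func):
    def api_function(*args, **kwargs):
        # dump/hook logic would run here before delegating to the original API
        return api_func(*args, **kwargs)

    # Without this loop, api_function.__name__ stays "api_function" and .default
    # is missing entirely, so callers reading those attributes off the wrapped
    # API break; copying them over restores the original interface.
    for attr_name in ATTR_LIST:
        if hasattr(api_func, attr_name):
            setattr(api_function, attr_name, getattr(api_func, attr_name))
    return api_function


wrapped = wrap(original_api)
assert wrapped(1) == 2
assert wrapped.__name__ == "original_api"
assert wrapped.default == 0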
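
A similar standalone sketch of the checksum path touched in
pytorch_processor.py: NumPy has no bfloat16 dtype, so a bfloat16 tensor is
upcast to float32 before .numpy() is called and the bytes are checksummed.
zlib.crc32 is used here only as a stand-in for
PytorchDataProcessor.compute_crc32_bytes, whose implementation is not part of
this patch; the tensor values are illustrative and torch is assumed to be
installed.

import zlib

import torch

t = torch.randn(4, dtype=torch.bfloat16)
if t.dtype == torch.bfloat16:
    t = t.float()  # bf16 has no NumPy equivalent, upcast to float32 first
tensor_bytes = t.cpu().detach().numpy()
print(f"crc32: {zlib.crc32(tensor_bytes.tobytes()):#010x}")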
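
Finally, the switch from tensor.storage() to tensor.untyped_storage() in
_analyze_and_save_tensor tracks the deprecation of TypedStorage
(tensor.storage()) in recent PyTorch releases; both expose the data_ptr()
used to skip tensors that own no real storage. A tiny sketch, again assuming
torch is available:

import torch

t = torch.ones(2, 3)
# A normal dense tensor owns allocated storage, so data_ptr() is non-zero;
# fake/meta or empty tensors are the cases the patch skips before dumping.
print(t.untyped_storage().data_ptr() != 0)  # True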