diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index 2fa89beb9751012f77df748386cd6b3bf59159cd..1aaf7d055fdcaab86028d8c7c2d19317ea050fa2 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -277,6 +277,8 @@ class Const:
 
     SUPPORT_API_FILE_NAME = "support_wrap_ops.yaml"
 
+    API_ATTR_LIST = ["__name__", "default"]
+
     PT_API_TYPE_FUNCTIONAL = "functional"
     PT_API_TYPE_TENSOR = "tensor"
     PT_API_TYPE_TORCH = "torch"
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
index 8d676dd2bb61ae9b4bd59dd4ed2d0ca9f41eb326..cc90aaa94ddbcd539919385b58ac4253b0120e09 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/api_registry.py
@@ -35,7 +35,7 @@ class ApiWrapper:
 
     def __init__(
             self, api_types: Dict[str, Dict[str, Any]], api_list_paths: Union[str, List[str], Tuple[str]],
-            backlist: Union[List[str], Tuple[str]] = None
+            blacklist: Union[List[str], Tuple[str]] = None
     ):
         self.api_types = api_types
         if not isinstance(api_list_paths, (list, tuple)):
@@ -44,7 +44,7 @@ class ApiWrapper:
             raise RuntimeError("The number of api_list_paths must be equal to the number of frameworks in 'api_types', "
                                "when api_list_paths is a list or tuple.")
         self.api_list_paths = api_list_paths
-        self.backlist = backlist if backlist else []
+        self.blacklist = blacklist if blacklist else []
         self.api_names = self._get_api_names()
         self.wrapped_api_functions = dict()
 
@@ -93,7 +93,11 @@ class ApiWrapper:
                 return api_func(*args, **kwargs)
             return api_instance(*args, **kwargs)
 
-        api_function.__name__ = api_name
+        for attr_name in Const.API_ATTR_LIST:
+            if hasattr(api_func, attr_name):
+                attr_value = getattr(api_func, attr_name)
+                setattr(api_function, attr_name, attr_value)
+
         return api_function
 
     def wrap_api(
@@ -142,7 +146,7 @@ class ApiWrapper:
             api_from_file = api_list.get(key_in_file, [])
             names = set()
             for api_name in api_from_file:
-                if f'{key_in_file}.{api_name}' in self.backlist:
+                if f'{key_in_file}.{api_name}' in self.blacklist:
                     continue
                 target_attr = api_name
                 for module in api_modules[0]:
@@ -164,7 +168,7 @@ class ApiRegistry:
    Base class for api registry.
    """

-    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, backlist=None):
+    def __init__(self, api_types, inner_used_api, supported_api_list_path, api_templates, blacklist=None):
         self.ori_api_attr = dict()
         self.wrapped_api_attr = dict()
         self.inner_used_ori_attr = dict()
@@ -173,7 +177,7 @@ class ApiRegistry:
         self.inner_used_api = inner_used_api
         self.supported_api_list_path = supported_api_list_path
         self.api_templates = api_templates
-        self.backlist = backlist if backlist else []
+        self.blacklist = blacklist if blacklist else []
         self.all_api_registered = False
 
     @staticmethod
@@ -232,7 +236,7 @@ class ApiRegistry:
             self.set_api_attr(self.inner_used_api.get(api_type)[0], self.inner_used_ori_attr.get(api_type, {}))
 
     def initialize_hook(self, hook_build_func):
-        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.backlist)
+        api_wrapper = ApiWrapper(self.api_types, self.supported_api_list_path, self.blacklist)
         wrapped_api_functions = api_wrapper.wrap_api(self.api_templates, hook_build_func)
 
         for framework, api_types in self.api_types.items():
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index cfa974b9aa82c03a2f65e6c43a385fbd9cd16e2f..9e81ac56914c689940e2acfbb4cb5a5e152a337d 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -279,8 +279,8 @@ class PytorchDataProcessor(BaseDataProcessor):
         tensor_stat = self.get_stat_info(common_tensor, self.config.async_dump, self.config.precision)
         tensor_json = {}
         tensor_json.update({'type': self.tensor_handler.get_tensor_type(tensor)})
-        tensor_json.update({'dtype': str(tensor.dtype)})
-        tensor_json.update({"shape": tensor.shape})
+        tensor_json.update({'dtype': str(common_tensor.dtype)})
+        tensor_json.update({"shape": common_tensor.shape})
 
         stat_values = [
             tensor_stat.max,
@@ -299,11 +299,10 @@ class PytorchDataProcessor(BaseDataProcessor):
         if self.config.summary_mode == Const.MD5 and not self.config.async_dump:
             tensor_md5 = None
             if not self.tensor_handler.is_empty_data(tensor):
-                logger.debug("Calculating the md5 value of fake tensor or meta tensor is not supported.")
                 # 拷贝并搬到 CPU
-                if tensor.dtype == torch.bfloat16:
-                    tensor = tensor.float()
-                tensor_bytes = tensor.cpu().detach().numpy()
+                if common_tensor.dtype == torch.bfloat16:
+                    common_tensor = common_tensor.float()
+                tensor_bytes = common_tensor.cpu().detach().numpy()
 
                 future = self._crc_executor.submit(
                     PytorchDataProcessor.compute_crc32_bytes,
@@ -313,14 +312,21 @@
                 crc_placeholder = self.data_writer.append_crc32_to_buffer(future)
                 tensor_json[Const.MD5_INDEX] = crc_placeholder
             else:
+                logger.debug(
+                    "Calculating the md5 value of fake tensor or meta tensor is not supported, "
+                    f"the current api/module name is {self.current_api_or_module_name}."
+                )
                 tensor_json.update({Const.MD5: tensor_md5})
         return tensor_json
 
     def _analyze_and_save_tensor(self, tensor, suffix):
         dump_data_name, file_path = self.get_save_file_path(suffix)
         single_arg = PytorchDataProcessor._analyze_tensor(self, tensor, suffix)
-        if self.tensor_handler.is_empty_data(tensor) or tensor.storage().data_ptr() == 0:
-            logger.debug("Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0.")
+        if self.tensor_handler.is_empty_data(tensor) or tensor.untyped_storage().data_ptr() == 0:
+            logger.debug(
+                "Collecting real data of fake tensor or meta tensor is not supported or data_ptr is 0, "
+                f"the current api/module name is {self.current_api_or_module_name}."
+            )
             return single_arg
         single_arg.update({"data_name": dump_data_name})
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
index d10db82015a9ca3d7972296c3bbf1b424c2355ae..0ff9c56fb5a8fa7c2f044b2c93b69a831bd408b1 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
@@ -54,7 +54,7 @@ if not is_mindtorch():
     )
 
     _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
-    _backlist = []
+    _blacklist = []
 else:
     import torch
     import torch_npu
@@ -69,7 +69,7 @@
     }
     _supported_api_list_path = (os.path.join(cur_path, '../../../pytorch/hook_module', MsConst.SUPPORTED_API_LIST_FILE),)
-    _backlist = []
+    _blacklist = []
 
 _inner_used_api = {
     Const.MS_FRAMEWORK + Const.SEP + Const.MS_API_TYPE_OPS: (
@@ -160,7 +160,7 @@ def get_api_register(return_new=False):
             _inner_used_api,
             _supported_api_list_path,
             ApiTemplate,
-            _backlist
+            _blacklist
         )
 
     global api_register
@@ -170,6 +170,6 @@
             _inner_used_api,
             _supported_api_list_path,
             ApiTemplate,
-            _backlist
+            _blacklist
         )
     return api_register
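
For reference, a minimal standalone sketch of the attribute-propagation pattern this patch introduces in ApiWrapper (the Const.API_ATTR_LIST loop): instead of unconditionally setting api_function.__name__, the wrapper copies "__name__" and, when present, "default" from the wrapped callable onto the wrapper. The make_wrapper/add names below are illustrative only and not part of msprobe.

# Standalone illustration of the attribute-copy pattern; assumes nothing from msprobe.
API_ATTR_LIST = ["__name__", "default"]  # mirrors the constant added to const.py


def make_wrapper(api_func):
    def api_function(*args, **kwargs):
        # hook / dump logic would go here in the real tool
        return api_func(*args, **kwargs)

    # Only propagate attributes that actually exist on the original callable.
    for attr_name in API_ATTR_LIST:
        if hasattr(api_func, attr_name):
            setattr(api_function, attr_name, getattr(api_func, attr_name))
    return api_function


def add(x, y):
    return x + y


wrapped = make_wrapper(add)
print(wrapped.__name__)  # -> "add"
print(wrapped(1, 2))     # -> 3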
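A second small sketch of the dtype handling the MD5/CRC path relies on: NumPy has no bfloat16 dtype, so a bfloat16 tensor must be upcast to float32 before it can be turned into bytes for hashing, which is what the patch does with common_tensor.float() before .cpu().detach().numpy(). The tensor_to_bytes helper and the direct zlib.crc32 call are illustrative stand-ins (the real code routes through PytorchDataProcessor.compute_crc32_bytes and an executor) and assume PyTorch is installed.

import zlib

import torch


def tensor_to_bytes(tensor: torch.Tensor) -> bytes:
    # NumPy cannot represent bfloat16, so upcast first.
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.cpu().detach().numpy().tobytes()


t = torch.randn(4, 4, dtype=torch.bfloat16)
print(hex(zlib.crc32(tensor_to_bytes(t))))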