diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 5a1e7569d8fc3423da7664dbd7582858e75ad062..c33ca941abb7954c7dc3066598bffba60ff4cdb9 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -28,6 +28,7 @@ from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy from msprobe.mindspore.common.log import logger from msprobe.mindspore.dump.hook_cell.api_register import get_api_register +from msprobe.mindspore.common.utils import is_mindtorch has_adump = True try: @@ -35,9 +36,15 @@ try: except ImportError: has_adump = False +if is_mindtorch(): + from torch import distributed as dist + class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) + if is_mindtorch(): + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp, dist.ProcessGroup]) + else: + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -114,6 +121,19 @@ class MindsporeDataProcessor(BaseDataProcessor): group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) return f"{group_ranks_hash:08x}" + @staticmethod + def _analyze_process_group(arg): + group_info = {"type": "mindspore.ProcessGroup"} + try: + group_ranks = dist.get_process_group_ranks(arg) + group_info.update({"group_ranks": group_ranks}) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + group_id = f"{group_ranks_hash:08x}" + group_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to get process group ranks info with error info: {e}.") + return group_info + @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type @@ -149,6 +169,8 @@ class MindsporeDataProcessor(BaseDataProcessor): (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)), (distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)) ] + if is_mindtorch(): + type_analyzer.append((dist.ProcessGroup, self._analyze_process_group)) for type_key, analyze_fn in type_analyzer: if isinstance(element, type_key): return analyze_fn(element)