From 9dac692e8538ce19f66b2840f74c5bd57b450dd2 Mon Sep 17 00:00:00 2001 From: wugengjun <451676383@qq.com> Date: Mon, 30 Jun 2025 10:26:07 +0800 Subject: [PATCH 1/2] =?UTF-8?q?mindtorch=E5=9C=BA=E6=99=AF=E4=B8=8B?= =?UTF-8?q?=E9=87=87=E9=9B=86group=E4=BF=A1=E6=81=AF=E4=B8=BAnone=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit 08dfefb45d83e13b275596fdd6859e38418a3736) --- .../data_processor/mindspore_processor.py | 24 ++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 5a1e7569d..42d63be3f 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -28,6 +28,7 @@ from msprobe.core.common.file_utils import path_len_exceeds_limit from msprobe.mindspore.common.utils import convert_bf16_to_fp32, save_tensor_as_npy from msprobe.mindspore.common.log import logger from msprobe.mindspore.dump.hook_cell.api_register import get_api_register +from msprobe.mindspore.common.utils import is_mindtorch has_adump = True try: @@ -35,9 +36,15 @@ try: except ImportError: has_adump = False +if is_mindtorch(): + from torch import distributed as dist + class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) + if is_mindtorch(): + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp, dist.ProcessGroup]) + else: + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -108,6 +115,19 @@ class MindsporeDataProcessor(BaseDataProcessor): def is_hookable_element(element): return hasattr(element, "register_hook") and callable(element.register_hook) + @staticmethod + def _analyze_process_group(arg): + group_info = {"type": "mindspore.ProcessGroup"} + try: + group_ranks = dist.get_process_group_ranks(arg) + group_info.update({"group_ranks": group_ranks}) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + group_id = f"{group_ranks_hash:08x}" + group_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to get process group ranks info with error info: {e}.") + return group_info + @staticmethod def process_group_hash(arg): group_ranks = distributed.get_process_group_ranks(arg) @@ -149,6 +169,8 @@ class MindsporeDataProcessor(BaseDataProcessor): (np.ndarray, lambda e: self._analyze_ndarray(e, suffix_str)), (distributed.P2POp, lambda e: self._analyze_p2pop(e, suffix_str)) ] + if is_mindtorch(): + type_analyzer.append((dist.ProcessGroup, self._analyze_process_group)) for type_key, analyze_fn in type_analyzer: if isinstance(element, type_key): return analyze_fn(element) -- Gitee From 483aafc7b19eaf13427a3985ffb1d43b94cf64bf Mon Sep 17 00:00:00 2001 From: wugengjun <451676383@qq.com> Date: Mon, 30 Jun 2025 10:41:26 +0800 Subject: [PATCH 2/2] =?UTF-8?q?mindtorch=E5=9C=BA=E6=99=AF=E4=B8=8B?= =?UTF-8?q?=E9=87=87=E9=9B=86group=E4=BF=A1=E6=81=AF=E4=B8=BAnone=E9=97=AE?= =?UTF-8?q?=E9=A2=98=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit (cherry picked from commit ab7ac54eeb7e3232e00f7c244c96d112d34a25d0) --- .../data_dump/data_processor/mindspore_processor.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 42d63be3f..c33ca941a 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -115,6 +115,12 @@ class MindsporeDataProcessor(BaseDataProcessor): def is_hookable_element(element): return hasattr(element, "register_hook") and callable(element.register_hook) + @staticmethod + def process_group_hash(arg): + group_ranks = distributed.get_process_group_ranks(arg) + group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) + return f"{group_ranks_hash:08x}" + @staticmethod def _analyze_process_group(arg): group_info = {"type": "mindspore.ProcessGroup"} @@ -128,12 +134,6 @@ class MindsporeDataProcessor(BaseDataProcessor): logger.warning(f"Failed to get process group ranks info with error info: {e}.") return group_info - @staticmethod - def process_group_hash(arg): - group_ranks = distributed.get_process_group_ranks(arg) - group_ranks_hash = zlib.crc32(str(group_ranks).encode('utf-8')) - return f"{group_ranks_hash:08x}" - @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type -- Gitee