From 89ed7ad369aaea74e279208dab40fdb0ef1ca885 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Wed, 26 Mar 2025 11:04:45 +0800 Subject: [PATCH 1/2] support mint distributed --- .../msprobe/core/common/const.py | 8 ++++-- .../core/data_dump/data_processor/base.py | 6 +++++ .../data_processor/mindspore_processor.py | 26 ++++++++++++++++++- .../data_processor/pytorch_processor.py | 9 ------- .../msprobe/docs/06.data_dump_MindSpore.md | 1 + .../mindspore/dump/hook_cell/api_register.py | 13 ++++++++-- .../dump/hook_cell/support_wrap_ops.yaml | 19 ++++++++++++++ 7 files changed, 68 insertions(+), 14 deletions(-) diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 65b099b22d..7b5d0f5e0e 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -250,6 +250,7 @@ class Const: MS_API_TYPE_MINT = "mint.ops" MS_API_TYPE_MINT_FUNC = "mint.nn.functional" MS_API_TYPE_COM = "communication.comm_func" + MS_API_TYPE_MINT_DIST = "mint.distributed" FUNCTIONAL_API_TYPE_PREFIX = "Functional" TENSOR_API_TYPE_PREFIX = "Tensor" @@ -262,6 +263,7 @@ class Const: MINT_API_TYPE_PREFIX = "Mint" MINT_FUNC_API_TYPE_PREFIX = "MintFunctional" + MINT_DIST_API_TYPE_PREFIX = "MintDistributed" SUPPORT_API_DICT_KEY_MAP = { PT_FRAMEWORK: { @@ -280,7 +282,8 @@ class Const: MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR, MS_API_TYPE_MINT: MS_API_TYPE_MINT, MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC, - MS_API_TYPE_COM: MS_API_TYPE_COM + MS_API_TYPE_COM: MS_API_TYPE_COM, + MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST }, MT_FRAMEWORK: { PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, @@ -308,7 +311,8 @@ class Const: MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX, MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX, MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX, - MS_API_TYPE_COM: DIST_API_TYPE_PREFIX + MS_API_TYPE_COM: DIST_API_TYPE_PREFIX, + MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX }, MT_FRAMEWORK: { PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 282ff5946c..44061f9acd 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -176,6 +176,10 @@ class BaseDataProcessor: else: raise ValueError("set_value_into_nested_structure failed: " "invalid data_structure type or invalid index") + + @staticmethod + def is_distributed_op(module): + return getattr(module, "op_is_distributed", False) @staticmethod def _convert_numpy_to_builtin(arg): @@ -350,6 +354,8 @@ class BaseDataProcessor: return api_info_struct def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): + if self.is_distributed_op(module): + module_input_output.update_output_with_args_and_kwargs() api_info_struct = {} # check whether data_mode contains forward or input if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c6ab0293cf..ef6f049c85 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -13,10 +13,12 @@ # limitations under the License. 
# ============================================================================ +import hashlib import zlib import mindspore as ms from mindspore import mint, ops, hal +from mindspore.mint import distributed from mindspore._c_expression.typing import Number import numpy as np @@ -36,7 +38,7 @@ except ImportError: class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number]) + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -104,6 +106,12 @@ class MindsporeDataProcessor(BaseDataProcessor): def is_hookable_element(element): return hasattr(element, "register_hook") and callable(element.register_hook) + @staticmethod + def process_group_hash(arg): + group_ranks = distributed.get_process_group_ranks(arg) + group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() + return group_ranks_hash + @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type @@ -136,8 +144,24 @@ class MindsporeDataProcessor(BaseDataProcessor): return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) + if isinstance(element, distributed.P2POp): + return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) return {} + def _analyze_p2pop(self, arg, suffix): + p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"} + try: + tensor_info = self._analyze_tensor(arg.tensor, suffix) + p2pop_info.update({"tensor": tensor_info}) + p2pop_info.update({"op": arg.op}) + p2pop_info.update({"peer": arg.peer}) + p2pop_info.update({"tag": arg.tag}) + group_id = self.process_group_hash(arg.group) if arg.group else None + p2pop_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to parse the P2POp content with error info: {e}.") + return p2pop_info + def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) tensor_json = { diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 973bfd981e..ea52de12fe 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -177,10 +177,6 @@ class PytorchDataProcessor(BaseDataProcessor): group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() return group_ranks_hash - @staticmethod - def is_distributed_op(module): - return getattr(module, "op_is_distributed", False) - @staticmethod def is_hookable_element(element): return (hasattr(element, "register_hook") and callable(element.register_hook)) and \ @@ -257,11 +253,6 @@ class PytorchDataProcessor(BaseDataProcessor): return self._analyze_builtin(element) return {} - def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): - if self.is_distributed_op(module): - module_input_output.update_output_with_args_and_kwargs() - return super().analyze_forward_output(name, module, module_input_output) - def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md 
b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index f8670c93c3..aabd1de3a3 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -414,6 +414,7 @@ npy文件名的前缀含义如下: | Primitive | mindspore.ops.Primitive API数据 | | Mint | mindspore.mint API数据 | | MintFunctional | mindspore.mint.nn.functional API数据 | +| MintDistributed | mindspore.mint.distributed API数据 | | Distributed | mindspore.communication.comm_func API数据 | | Jit | 被"jit"装饰的模块或函数数据 | | Cell | mindspore.nn.Cell 类(模块)数据 | diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py index 7a5737662d..2b93df899e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -16,6 +16,7 @@ import os from mindspore import Tensor, ops, mint +from mindspore.mint import distributed from mindspore.mint.nn import functional from mindspore.communication import comm_func @@ -41,7 +42,8 @@ if not is_mindtorch(): Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)), Const.MS_API_TYPE_MINT: (mint, (mint,)), Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)), - Const.MS_API_TYPE_COM: (comm_func, (comm_func,)) + Const.MS_API_TYPE_COM: (comm_func, (comm_func,)), + Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,)) } } if stub_tensor_existed: @@ -84,6 +86,8 @@ class ApiTemplate(HOOKCell): self.api_func = api_func self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP super().__init__(hook_build_func) + if prefix == Const.MINT_DIST_API_TYPE_PREFIX: + self.op_is_distributed = True @staticmethod def async_to_sync(output): @@ -103,9 +107,14 @@ class ApiTemplate(HOOKCell): output = self.api_func(*args, **kwargs) - if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): + if self.prefix_api_name.startswith( + (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX) + ): if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: output = self.async_to_sync(output) + if self.api_name == "batch_isend_irecv" and isinstance(output, list): + output = [self.async_to_sync(handle) for handle in output] + return output def forward(self, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 364062b464..d16a69d973 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -1025,3 +1025,22 @@ communication.comm_func: - recv - isend - irecv + +mint.distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - all_to_all_single + - all_to_all + - all_gather_into_tensor + - reduce_scatter_tensor + - batch_isend_irecv + \ No newline at end of file -- Gitee From 974ab21a309ebd14d0bc12712330a52ff4849215 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Thu, 27 Mar 2025 19:52:51 +0800 Subject: [PATCH 2/2] modify view type replace clone --- .../dump/module_dump/module_processer.py | 27 ++++++++++--------- 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py 
b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index d9b67c9317..4f148008eb 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -47,30 +47,33 @@ class ModuleProcesser: def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None - BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) - BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + BackwardHook.setup_input_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_input_hook) + BackwardHook.setup_output_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_output_hook) replace_checkpoint() @staticmethod - def clone_return_value(func): + def modify_view_type_return_value(func): @wraps(func) - def clone_return_value_func(*args, **kwargs): + def modify_view_type_return_value_func(*args, **kwargs): result = func(*args, **kwargs) - return ModuleProcesser.clone_if_tensor(result) + return ModuleProcesser.modify_view_type(result) - return clone_return_value_func + return modify_view_type_return_value_func @staticmethod - @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH) - def clone_if_tensor(result): + @recursion_depth_decorator("ModuleDump: ModuleProcesser.modify_view_type", max_depth=Const.DUMP_MAX_DEPTH) + def modify_view_type(result): if isinstance(result, torch.Tensor) and not is_float8_tensor(result): - return result.clone() + if hasattr(result, "_base") and result._base is not None: + if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0): + torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0)) + return result elif type(result) is tuple: - return tuple(ModuleProcesser.clone_if_tensor(x) for x in result) + return tuple(ModuleProcesser.modify_view_type(x) for x in result) elif type(result) is list: - return list(ModuleProcesser.clone_if_tensor(x) for x in result) + return list(ModuleProcesser.modify_view_type(x) for x in result) elif type(result) is dict: - return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()} + return {k: ModuleProcesser.modify_view_type(v) for k, v in result.items()} else: return result -- Gitee
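
Usage sketch for PATCH 1/2. This is a hedged, illustrative sketch, not part of the
patch: it assumes a two-rank distributed launch (e.g. via msrun) and an initialized
communication backend; the shapes and peer ids are made up for illustration.

    import mindspore as ms
    from mindspore.mint import distributed

    # With msprobe dump enabled, every call into mindspore.mint.distributed
    # below is hooked and recorded under the "MintDistributed" prefix that
    # this patch registers in api_register.py / support_wrap_ops.yaml.
    distributed.init_process_group()
    rank = distributed.get_rank()

    t = ms.ops.ones((2, 2), ms.float32)
    distributed.all_reduce(t)  # dumped like any other hooked API

    # P2POp arguments are serialized by the new _analyze_p2pop: each dump
    # entry holds the tensor stats plus op/peer/tag and a group_id that is
    # the md5 hash of the group's rank list (None for the default group).
    op = distributed.P2POp(distributed.isend if rank == 0 else distributed.irecv,
                           t, 1 - rank)
    handles = distributed.batch_isend_irecv([op])
    # batch_isend_irecv returns a list of handles; per this patch the hook
    # converts each one to a synchronous result (async_to_sync) before
    # dumping the outputs.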
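
Mechanism sketch for PATCH 2/2. Also a hedged, standalone illustration:
_get_creation_meta/_set_creation_meta are the private torch._C._autograd APIs the
patch itself calls, and they may change between PyTorch releases.

    import torch

    x = torch.randn(2, 3, requires_grad=True)
    v = x.view(3, 2)  # a view: v._base is x, so v._base is not None

    # Instead of returning v.clone() from the BackwardHook setup wrappers
    # (the old clone_if_tensor), the patch resets a view's autograd creation
    # metadata to CreationMeta(0) (DEFAULT). That relaxes the view-specific
    # autograd checks (e.g. for views created under no_grad or inside custom
    # Functions), so the original tensor can be returned without the extra
    # memory copy a clone would make.
    if v._base is not None:
        meta = torch._C._autograd._get_creation_meta(v)
        if meta != torch._C._autograd.CreationMeta(0):
            torch._C._autograd._set_creation_meta(
                v, torch._C._autograd.CreationMeta(0))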