diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index 65b099b22d9b07b27b1115c0adddf465963c093a..7b5d0f5e0e4b41f8090963174e238900fe19625b 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -250,6 +250,7 @@ class Const: MS_API_TYPE_MINT = "mint.ops" MS_API_TYPE_MINT_FUNC = "mint.nn.functional" MS_API_TYPE_COM = "communication.comm_func" + MS_API_TYPE_MINT_DIST = "mint.distributed" FUNCTIONAL_API_TYPE_PREFIX = "Functional" TENSOR_API_TYPE_PREFIX = "Tensor" @@ -262,6 +263,7 @@ class Const: MINT_API_TYPE_PREFIX = "Mint" MINT_FUNC_API_TYPE_PREFIX = "MintFunctional" + MINT_DIST_API_TYPE_PREFIX = "MintDistributed" SUPPORT_API_DICT_KEY_MAP = { PT_FRAMEWORK: { @@ -280,7 +282,8 @@ class Const: MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR, MS_API_TYPE_MINT: MS_API_TYPE_MINT, MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC, - MS_API_TYPE_COM: MS_API_TYPE_COM + MS_API_TYPE_COM: MS_API_TYPE_COM, + MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST }, MT_FRAMEWORK: { PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, @@ -308,7 +311,8 @@ class Const: MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX, MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX, MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX, - MS_API_TYPE_COM: DIST_API_TYPE_PREFIX + MS_API_TYPE_COM: DIST_API_TYPE_PREFIX, + MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX }, MT_FRAMEWORK: { PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 282ff5946cb5cbaf03bc68bd6b643bf24558af1b..44061f9acdef08d763f0bd81af873267392c8915 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -176,6 +176,10 @@ class BaseDataProcessor: else: raise ValueError("set_value_into_nested_structure failed: " "invalid data_structure type or invalid index") + + @staticmethod + def is_distributed_op(module): + return getattr(module, "op_is_distributed", False) @staticmethod def _convert_numpy_to_builtin(arg): @@ -350,6 +354,8 @@ class BaseDataProcessor: return api_info_struct def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): + if self.is_distributed_op(module): + module_input_output.update_output_with_args_and_kwargs() api_info_struct = {} # check whether data_mode contains forward or input if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c6ab0293cf3edafab06a5bf03e1a429d86e92720..ef6f049c85191836a730c647b39dd94cee33ac2c 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -13,10 +13,12 @@ # limitations under the License. # ============================================================================ +import hashlib import zlib import mindspore as ms from mindspore import mint, ops, hal +from mindspore.mint import distributed from mindspore._c_expression.typing import Number import numpy as np @@ -36,7 +38,7 @@ except ImportError: class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number]) + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -104,6 +106,12 @@ class MindsporeDataProcessor(BaseDataProcessor): def is_hookable_element(element): return hasattr(element, "register_hook") and callable(element.register_hook) + @staticmethod + def process_group_hash(arg): + group_ranks = distributed.get_process_group_ranks(arg) + group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() + return group_ranks_hash + @classmethod def get_special_types(cls): return super().get_special_types() + cls.mindspore_special_type @@ -136,8 +144,24 @@ class MindsporeDataProcessor(BaseDataProcessor): return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) + if isinstance(element, distributed.P2POp): + return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) return {} + def _analyze_p2pop(self, arg, suffix): + p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"} + try: + tensor_info = self._analyze_tensor(arg.tensor, suffix) + p2pop_info.update({"tensor": tensor_info}) + p2pop_info.update({"op": arg.op}) + p2pop_info.update({"peer": arg.peer}) + p2pop_info.update({"tag": arg.tag}) + group_id = self.process_group_hash(arg.group) if arg.group else None + p2pop_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to parse the P2POp content with error info: {e}.") + return p2pop_info + def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) tensor_json = { diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 973bfd981eaeba8145ee304cab8a886b4aa95a70..ea52de12fe7494db4430a70c7b4c764847dd3053 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -177,10 +177,6 @@ class PytorchDataProcessor(BaseDataProcessor): group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() return group_ranks_hash - @staticmethod - def is_distributed_op(module): - return getattr(module, "op_is_distributed", False) - @staticmethod def is_hookable_element(element): return (hasattr(element, "register_hook") and callable(element.register_hook)) and \ @@ -257,11 +253,6 @@ class PytorchDataProcessor(BaseDataProcessor): return self._analyze_builtin(element) return {} - def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): - if self.is_distributed_op(module): - module_input_output.update_output_with_args_and_kwargs() - return super().analyze_forward_output(name, module, module_input_output) - def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index f8670c93c308b76bb2f177a3342d1a85f8e868fb..aabd1de3a3ed4aaaef14eb6af45712a38990507e 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -414,6 +414,7 @@ npy文件名的前缀含义如下: | Primitive | mindspore.ops.Primitive API数据 | | Mint | mindspore.mint API数据 | | MintFunctional | mindspore.mint.nn.functional API数据 | +| MintDistributed | mindspore.mint.distributed API数据 | | Distributed | mindspore.communication.comm_func API数据 | | Jit | 被"jit"装饰的模块或函数数据 | | Cell | mindspore.nn.Cell 类(模块)数据 | diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py index 7a5737662d4e6619d90a6744f975d49fe1784825..2b93df899efe84d983cd793fb20f1b58b0eaf303 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -16,6 +16,7 @@ import os from mindspore import Tensor, ops, mint +from mindspore.mint import distributed from mindspore.mint.nn import functional from mindspore.communication import comm_func @@ -41,7 +42,8 @@ if not is_mindtorch(): Const.MS_API_TYPE_TENSOR: (Tensor, (Tensor,)), Const.MS_API_TYPE_MINT: (mint, (mint,)), Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)), - Const.MS_API_TYPE_COM: (comm_func, (comm_func,)) + Const.MS_API_TYPE_COM: (comm_func, (comm_func,)), + Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,)) } } if stub_tensor_existed: @@ -84,6 +86,8 @@ class ApiTemplate(HOOKCell): self.api_func = api_func self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP super().__init__(hook_build_func) + if prefix == Const.MINT_DIST_API_TYPE_PREFIX: + self.op_is_distributed = True @staticmethod def async_to_sync(output): @@ -103,9 +107,14 @@ class ApiTemplate(HOOKCell): output = self.api_func(*args, **kwargs) - if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): + if self.prefix_api_name.startswith( + (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX) + ): if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: output = self.async_to_sync(output) + if self.api_name == "batch_isend_irecv" and isinstance(output, list): + output = [self.async_to_sync(handle) for handle in output] + return output def forward(self, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 364062b46478b63369269c2470ea526eec59a3d3..d16a69d97318faf51d34134ad5c06997ed2bc8f3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -1025,3 +1025,22 @@ communication.comm_func: - recv - isend - irecv + +mint.distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - all_to_all_single + - all_to_all + - all_gather_into_tensor + - reduce_scatter_tensor + - batch_isend_irecv + \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index d9b67c93175641cbd0009c1f58562787626d375f..4f148008ebf9bc40b76535d1adae476d6a8bf7c1 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -47,30 +47,33 @@ class ModuleProcesser: def __init__(self, scope): self.scope = scope if isinstance(scope, (ModuleRangeScope, MixRangeScope)) else None - BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) - BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + BackwardHook.setup_input_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_input_hook) + BackwardHook.setup_output_hook = ModuleProcesser.modify_view_type_return_value(BackwardHook.setup_output_hook) replace_checkpoint() @staticmethod - def clone_return_value(func): + def modify_view_type_return_value(func): @wraps(func) - def clone_return_value_func(*args, **kwargs): + def modify_view_type_return_value_func(*args, **kwargs): result = func(*args, **kwargs) - return ModuleProcesser.clone_if_tensor(result) + return ModuleProcesser.modify_view_type(result) - return clone_return_value_func + return modify_view_type_return_value_func @staticmethod - @recursion_depth_decorator("ModuleDump: ModuleProcesser.clone_if_tensor", max_depth=Const.DUMP_MAX_DEPTH) - def clone_if_tensor(result): + @recursion_depth_decorator("ModuleDump: ModuleProcesser.modify_view_type", max_depth=Const.DUMP_MAX_DEPTH) + def modify_view_type(result): if isinstance(result, torch.Tensor) and not is_float8_tensor(result): - return result.clone() + if hasattr(result, "_base") and result._base is not None: + if torch._C._autograd._get_creation_meta(result) != torch._C._autograd.CreationMeta(0): + torch._C._autograd._set_creation_meta(result, torch._C._autograd.CreationMeta(0)) + return result elif type(result) is tuple: - return tuple(ModuleProcesser.clone_if_tensor(x) for x in result) + return tuple(ModuleProcesser.modify_view_type(x) for x in result) elif type(result) is list: - return list(ModuleProcesser.clone_if_tensor(x) for x in result) + return list(ModuleProcesser.modify_view_type(x) for x in result) elif type(result) is dict: - return {k: ModuleProcesser.clone_if_tensor(v) for k, v in result.items()} + return {k: ModuleProcesser.modify_view_type(v) for k, v in result.items()} else: return result