diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index f7ba9f90d0e27df9a40ee37b838fbd2919f0d25b..977f9310b772ceae96055c057a9cb3f57ab54ffb 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -247,6 +247,7 @@ class Const:
     MS_API_TYPE_MINT = "mint.ops"
     MS_API_TYPE_MINT_FUNC = "mint.nn.functional"
     MS_API_TYPE_COM = "communication.comm_func"
+    MS_API_TYPE_MINT_DIST = "mint.distributed"
 
     FUNCTIONAL_API_TYPE_PREFIX = "Functional"
     TENSOR_API_TYPE_PREFIX = "Tensor"
@@ -259,6 +260,7 @@ class Const:
 
     MINT_API_TYPE_PREFIX = "Mint"
     MINT_FUNC_API_TYPE_PREFIX = "MintFunctional"
+    MINT_DIST_API_TYPE_PREFIX = "MintDistributed"
 
     SUPPORT_API_DICT_KEY_MAP = {
         PT_FRAMEWORK: {
@@ -277,7 +279,8 @@ class Const:
             MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR,
             MS_API_TYPE_MINT: MS_API_TYPE_MINT,
             MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC,
-            MS_API_TYPE_COM: MS_API_TYPE_COM
+            MS_API_TYPE_COM: MS_API_TYPE_COM,
+            MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST
         },
         MT_FRAMEWORK: {
             PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL,
@@ -305,7 +308,8 @@ class Const:
             MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX,
             MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX,
             MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX,
-            MS_API_TYPE_COM: DIST_API_TYPE_PREFIX
+            MS_API_TYPE_COM: DIST_API_TYPE_PREFIX,
+            MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX
         },
         MT_FRAMEWORK: {
             PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX,
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 775a80b2418ef356867228b4ca09fad8c86cce25..d1e43adc2c978e36633913d0c9b5b93d7b90ae89 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -176,6 +176,10 @@ class BaseDataProcessor:
         else:
             raise ValueError("set_value_into_nested_structure failed: "
                              "invalid data_structure type or invalid index")
+
+    @staticmethod
+    def is_distributed_op(module):
+        return getattr(module, "op_is_distributed", False)
 
     @staticmethod
     def _convert_numpy_to_builtin(arg):
@@ -350,6 +354,8 @@ class BaseDataProcessor:
         return api_info_struct
 
     def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
+        if self.is_distributed_op(module):
+            module_input_output.update_output_with_args_and_kwargs()
         api_info_struct = {}
         # check whether data_mode contains forward or input
         if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT):
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
index c6ab0293cf3edafab06a5bf03e1a429d86e92720..1d06d0e256626030e4520081ee5291b4fe497b7a 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-
+import hashlib
 import zlib
 
 import mindspore as ms
 from mindspore import mint, ops, hal
+from mindspore.mint import distributed
 from mindspore._c_expression.typing import Number
 import numpy as np
 
@@ -36,7 +37,7 @@ except ImportError:
 
 
 class MindsporeDataProcessor(BaseDataProcessor):
-    mindspore_special_type = tuple([ms.Tensor, Number])
+    mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp])
 
     def __init__(self, config, data_writer):
         super().__init__(config, data_writer)
@@ -103,6 +104,12 @@ class MindsporeDataProcessor(BaseDataProcessor):
     @staticmethod
     def is_hookable_element(element):
         return hasattr(element, "register_hook") and callable(element.register_hook)
+
+    @staticmethod
+    def process_group_hash(arg):
+        group_ranks = distributed.get_process_group_ranks(arg)
+        group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
+        return group_ranks_hash
 
     @classmethod
     def get_special_types(cls):
@@ -136,8 +143,24 @@ class MindsporeDataProcessor(BaseDataProcessor):
             return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
         if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))):
             return self._analyze_builtin(element)
+        if isinstance(element, distributed.P2POp):
+            return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack]))
         return {}
 
+    def _analyze_p2pop(self, arg, suffix):
+        p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"}
+        try:
+            tensor_info = self._analyze_tensor(arg.tensor, suffix)
+            p2pop_info.update({"tensor": tensor_info})
+            p2pop_info.update({"op": arg.op})
+            p2pop_info.update({"peer": arg.peer})
+            p2pop_info.update({"tag": arg.tag})
+            group_id = self.process_group_hash(arg.group) if arg.group else None
+            p2pop_info.update({"group_id": group_id})
+        except Exception as e:
+            logger.warning(f"Failed to parse the P2POp content with error info: {e}.")
+        return p2pop_info
+
     def _analyze_tensor(self, tensor, suffix):
         tensor_stat = self.get_stat_info(tensor)
         tensor_json = {
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
index 4c56419dcb17b78918e4d46a3aaa50b12ef32777..bc50992d3edc5cafbe4df859718a4f4ef4ed26f3 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py
@@ -176,10 +176,6 @@ class PytorchDataProcessor(BaseDataProcessor):
         group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest()
         return group_ranks_hash
 
-    @staticmethod
-    def is_distributed_op(module):
-        return getattr(module, "op_is_distributed", False)
-
     @staticmethod
     def is_hookable_element(element):
         return (hasattr(element, "register_hook") and callable(element.register_hook)) and \
@@ -244,11 +240,6 @@ class PytorchDataProcessor(BaseDataProcessor):
             return self._analyze_builtin(element)
         return {}
 
-    def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs):
-        if self.is_distributed_op(module):
-            module_input_output.update_output_with_args_and_kwargs()
-        return super().analyze_forward_output(name, module, module_input_output)
-
     def _analyze_p2pop(self, arg, suffix):
         p2pop_info = {"class_type": "torch.distributed.P2POp"}
         try:
diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md
index d99f61ba40e2978ec950b375382bb58d7379f37c..6257c4b68db6f82c1a637f0b06678b0dcd5a39c0 100644
--- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md
+++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md
@@ -414,6 +414,7 @@ npy文件名的前缀含义如下:
 | Primitive | mindspore.ops.Primitive API数据 |
 | Mint | mindspore.mint API数据 |
 | MintFunctional | mindspore.mint.nn.functional API数据 |
+| MintDistributed | mindspore.mint.distributed API数据 |
 | Distributed | mindspore.communication.comm_func API数据 |
 | Jit | 被"jit"装饰的模块或函数数据 |
 | Cell | mindspore.nn.Cell 类(模块)数据 |
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
index 53271ff07bea418db66b3d3724a84eda5b52c296..38637b8fece6f31a9323881137fb713a72a35174 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py
@@ -16,6 +16,7 @@
 import os
 
 from mindspore import Tensor, ops, mint
+from mindspore.mint import distributed
 from mindspore.mint.nn import functional
 from mindspore.common._stub_tensor import StubTensor
 from mindspore.communication import comm_func
@@ -37,7 +38,8 @@ if not is_mindtorch():
             Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,)),
             Const.MS_API_TYPE_MINT: (mint, (mint,)),
             Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)),
-            Const.MS_API_TYPE_COM: (comm_func, (comm_func,))
+            Const.MS_API_TYPE_COM: (comm_func, (comm_func,)),
+            Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,))
         }
     }
     _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),)
@@ -75,6 +77,8 @@ class ApiTemplate(HOOKCell):
         self.api_func = api_func
         self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP
         super().__init__(hook_build_func)
+        if prefix == Const.MINT_DIST_API_TYPE_PREFIX:
+            self.op_is_distributed = True
 
     @staticmethod
     def async_to_sync(output):
@@ -94,9 +98,14 @@ class ApiTemplate(HOOKCell):
 
         output = self.api_func(*args, **kwargs)
 
-        if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX):
+        if self.prefix_api_name.startswith(
+            (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX)
+        ):
             if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]:
                 output = self.async_to_sync(output)
+            if self.api_name == "batch_isend_irecv" and isinstance(output, list):
+                output = [self.async_to_sync(handle) for handle in output]
+
         return output
 
     def forward(self, *args, **kwargs):
diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
index 364062b46478b63369269c2470ea526eec59a3d3..eae8f85a87fb2b0986cefb2e6faae7399a86f367 100644
--- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml
@@ -1025,3 +1025,21 @@ communication.comm_func:
   - recv
   - isend
   - irecv
+
+mint.distributed:
+  - send
+  - recv
+  - broadcast
+  - all_reduce
+  - reduce
+  - all_gather
+  - gather
+  - isend
+  - irecv
+  - scatter
+  - reduce_scatter
+  - all_to_all_single
+  - all_to_all
+  - all_gather_into_tensor
+  - reduce_scatter_tensor
+  - batch_isend_irecv
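
Note (illustrative, not part of the patch above): a minimal sketch of how the newly hooked mint.distributed APIs might be exercised so the dump contains MintDistributed entries. It assumes msprobe's PrecisionDebugger front end from msprobe.mindspore, a config.json that enables dump, an msrun launch with at least two ranks, and that init_process_group/get_rank/get_world_size and the string form of P2POp are available in mindspore.mint.distributed; only API names added to support_wrap_ops.yaml above are called.

# Illustrative sketch only -- adjust config_path and launch settings to your environment.
import mindspore as ms
from mindspore import Tensor
from mindspore.mint import distributed as dist

from msprobe.mindspore import PrecisionDebugger  # assumed msprobe entry point

ms.set_context(mode=ms.PYNATIVE_MODE)
dist.init_process_group()                        # assumed available in mint.distributed
rank = dist.get_rank()                           # assumed available
world_size = dist.get_world_size()               # assumed available

debugger = PrecisionDebugger(config_path="./config.json")  # hypothetical config path
debugger.start()

x = Tensor([float(rank)] * 4, ms.float32)
dist.all_reduce(x)                               # expected to dump as MintDistributed.all_reduce.*

# P2POp arguments are serialized by the new _analyze_p2pop (tensor stats, op, peer,
# tag, group_id hash), and batch_isend_irecv handles are synchronized by the hook.
peer = (rank + 1) % world_size
p2p_ops = [
    dist.P2POp("isend", x, peer),                # string op form assumed
    dist.P2POp("irecv", Tensor([0.0] * 4, ms.float32), peer),
]
for handle in dist.batch_isend_irecv(p2p_ops):
    handle.wait()                                # the hook may already have waited; waiting again is harmless

debugger.stop()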