diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index ff8d58dc0a30f3c2dbab2f7a1f0a8a62149f46c9..d90e9e66f6cd07732f4e3eccd4c63cd9ab3c0e0b 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -254,6 +254,7 @@ class Const: MS_API_TYPE_MINT = "mint.ops" MS_API_TYPE_MINT_FUNC = "mint.nn.functional" MS_API_TYPE_COM = "communication.comm_func" + MS_API_TYPE_MINT_DIST = "mint.distributed" FUNCTIONAL_API_TYPE_PREFIX = "Functional" TENSOR_API_TYPE_PREFIX = "Tensor" @@ -266,6 +267,7 @@ class Const: MINT_API_TYPE_PREFIX = "Mint" MINT_FUNC_API_TYPE_PREFIX = "MintFunctional" + MINT_DIST_API_TYPE_PREFIX = "MintDistributed" SUPPORT_API_DICT_KEY_MAP = { PT_FRAMEWORK: { @@ -284,7 +286,8 @@ class Const: MS_API_TYPE_STUB_TENSOR: MS_API_TYPE_TENSOR, MS_API_TYPE_MINT: MS_API_TYPE_MINT, MS_API_TYPE_MINT_FUNC: MS_API_TYPE_MINT_FUNC, - MS_API_TYPE_COM: MS_API_TYPE_COM + MS_API_TYPE_COM: MS_API_TYPE_COM, + MS_API_TYPE_MINT_DIST: MS_API_TYPE_MINT_DIST }, MT_FRAMEWORK: { PT_API_TYPE_FUNCTIONAL: PT_API_TYPE_FUNCTIONAL, @@ -312,7 +315,8 @@ class Const: MS_API_TYPE_STUB_TENSOR: TENSOR_API_TYPE_PREFIX, MS_API_TYPE_MINT: MINT_API_TYPE_PREFIX, MS_API_TYPE_MINT_FUNC: MINT_FUNC_API_TYPE_PREFIX, - MS_API_TYPE_COM: DIST_API_TYPE_PREFIX + MS_API_TYPE_COM: DIST_API_TYPE_PREFIX, + MS_API_TYPE_MINT_DIST: MINT_DIST_API_TYPE_PREFIX }, MT_FRAMEWORK: { PT_API_TYPE_FUNCTIONAL: FUNCTIONAL_API_TYPE_PREFIX, diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 775a80b2418ef356867228b4ca09fad8c86cce25..d1e43adc2c978e36633913d0c9b5b93d7b90ae89 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -176,6 +176,10 @@ class BaseDataProcessor: else: raise ValueError("set_value_into_nested_structure failed: " "invalid data_structure type or invalid index") + + @staticmethod + def is_distributed_op(module): + return getattr(module, "op_is_distributed", False) @staticmethod def _convert_numpy_to_builtin(arg): @@ -350,6 +354,8 @@ class BaseDataProcessor: return api_info_struct def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): + if self.is_distributed_op(module): + module_input_output.update_output_with_args_and_kwargs() api_info_struct = {} # check whether data_mode contains forward or input if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index 782deec51c9e33de16836d10b35fae835b1fcd66..34c248f5c4a03349672b302ecb48086761d29458 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ - +import hashlib import zlib import mindspore as ms from mindspore import mint, ops, hal +from mindspore.mint import distributed from mindspore._c_expression.typing import Number import numpy as np @@ -36,7 +37,7 @@ except ImportError: class MindsporeDataProcessor(BaseDataProcessor): - mindspore_special_type = tuple([ms.Tensor, Number]) + mindspore_special_type = tuple([ms.Tensor, Number, distributed.P2POp]) def __init__(self, config, data_writer): super().__init__(config, data_writer) @@ -103,6 +104,12 @@ class MindsporeDataProcessor(BaseDataProcessor): @staticmethod def is_hookable_element(element): return hasattr(element, "register_hook") and callable(element.register_hook) + + @staticmethod + def process_group_hash(arg): + group_ranks = distributed.get_process_group_ranks(arg) + group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() + return group_ranks_hash @classmethod def get_special_types(cls): @@ -136,8 +143,24 @@ class MindsporeDataProcessor(BaseDataProcessor): return self._analyze_numpy(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) if isinstance(element, (bool, int, float, str, slice, type(Ellipsis))): return self._analyze_builtin(element) + if isinstance(element, distributed.P2POp): + return self._analyze_p2pop(element, Const.SEP.join([str(suffix) for suffix in suffix_stack])) return {} + def _analyze_p2pop(self, arg, suffix): + p2pop_info = {"class_type": "mindspore.mint.distributed.P2POp"} + try: + tensor_info = self._analyze_tensor(arg.tensor, suffix) + p2pop_info.update({"tensor": tensor_info}) + p2pop_info.update({"op": arg.op}) + p2pop_info.update({"peer": arg.peer}) + p2pop_info.update({"tag": arg.tag}) + group_id = self.process_group_hash(arg.group) if arg.group else None + p2pop_info.update({"group_id": group_id}) + except Exception as e: + logger.warning(f"Failed to parse the P2POp content with error info: {e}.") + return p2pop_info + def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) tensor_json = { diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py index 66523da9c5576df9d447954473bb725ce6c07c19..8aadcbd935f104ee1803d7bb8eb51eb5bbe7d3e0 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/pytorch_processor.py @@ -174,10 +174,6 @@ class PytorchDataProcessor(BaseDataProcessor): group_ranks_hash = hashlib.md5(str(group_ranks).encode('utf-8')).hexdigest() return group_ranks_hash - @staticmethod - def is_distributed_op(module): - return getattr(module, "op_is_distributed", False) - @staticmethod def is_hookable_element(element): return (hasattr(element, "register_hook") and callable(element.register_hook)) and \ @@ -242,11 +238,6 @@ class PytorchDataProcessor(BaseDataProcessor): return self._analyze_builtin(element) return {} - def analyze_forward_output(self, name, module, module_input_output: ModuleForwardInputsOutputs): - if self.is_distributed_op(module): - module_input_output.update_output_with_args_and_kwargs() - return super().analyze_forward_output(name, module, module_input_output) - def _analyze_p2pop(self, arg, suffix): p2pop_info = {"class_type": "torch.distributed.P2POp"} try: diff --git a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md index fbbbd387644fa9d02236af4c4703913e2484c1ad..b51c3d75dfc2322cc4ffa7349aeb4b4719527e22 100644 --- a/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md +++ b/debug/accuracy_tools/msprobe/docs/06.data_dump_MindSpore.md @@ -467,6 +467,7 @@ npy文件名的前缀含义如下: | Primitive | mindspore.ops.Primitive API数据 | | Mint | mindspore.mint API数据 | | MintFunctional | mindspore.mint.nn.functional API数据 | +| MintDistributed | mindspore.mint.distributed API数据 | | Distributed | mindspore.communication.comm_func API数据 | | Jit | 被"jit"装饰的模块或函数数据 | | Cell | mindspore.nn.Cell 类(模块)数据 | diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py index 53271ff07bea418db66b3d3724a84eda5b52c296..38637b8fece6f31a9323881137fb713a72a35174 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_register.py @@ -16,6 +16,7 @@ import os from mindspore import Tensor, ops, mint +from mindspore.mint import distributed from mindspore.mint.nn import functional from mindspore.common._stub_tensor import StubTensor from mindspore.communication import comm_func @@ -37,7 +38,8 @@ if not is_mindtorch(): Const.MS_API_TYPE_STUB_TENSOR: (StubTensor, (StubTensor,)), Const.MS_API_TYPE_MINT: (mint, (mint,)), Const.MS_API_TYPE_MINT_FUNC: (functional, (functional,)), - Const.MS_API_TYPE_COM: (comm_func, (comm_func,)) + Const.MS_API_TYPE_COM: (comm_func, (comm_func,)), + Const.MS_API_TYPE_MINT_DIST: (distributed, (distributed,)) } } _supported_api_list_path = (os.path.join(cur_path, MsConst.SUPPORTED_API_LIST_FILE),) @@ -75,6 +77,8 @@ class ApiTemplate(HOOKCell): self.api_func = api_func self.prefix_api_name = prefix + Const.SEP + str(api_name.split(Const.SEP)[-1]) + Const.SEP super().__init__(hook_build_func) + if prefix == Const.MINT_DIST_API_TYPE_PREFIX: + self.op_is_distributed = True @staticmethod def async_to_sync(output): @@ -94,9 +98,14 @@ class ApiTemplate(HOOKCell): output = self.api_func(*args, **kwargs) - if self.prefix_api_name.startswith(MsConst.DISTRIBUTED_DATA_PREFIX): + if self.prefix_api_name.startswith( + (MsConst.DISTRIBUTED_DATA_PREFIX, Const.MINT_DIST_API_TYPE_PREFIX) + ): if kwargs.get("async_op") or self.api_name in ["isend", "irecv"]: output = self.async_to_sync(output) + if self.api_name == "batch_isend_irecv" and isinstance(output, list): + output = [self.async_to_sync(handle) for handle in output] + return output def forward(self, *args, **kwargs): diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml index 723b0cbc93f78d50f703838eb488de6733008906..b4f8b114c7e77e2f9f2e377e4c57ab7bedd6a610 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/support_wrap_ops.yaml @@ -1027,3 +1027,21 @@ communication.comm_func: - recv - isend - irecv + +mint.distributed: + - send + - recv + - broadcast + - all_reduce + - reduce + - all_gather + - gather + - isend + - irecv + - scatter + - reduce_scatter + - all_to_all_single + - all_to_all + - all_gather_into_tensor + - reduce_scatter_tensor + - batch_isend_irecv diff --git a/debug/accuracy_tools/msprobe/test/run_ut.py b/debug/accuracy_tools/msprobe/test/run_ut.py index c5ebc6e3f052b8ef7d16694c31c22d16f8ec930a..06671c3d0d0e4440736712cb0718873280482781 100644 --- a/debug/accuracy_tools/msprobe/test/run_ut.py +++ b/debug/accuracy_tools/msprobe/test/run_ut.py @@ -2,6 +2,7 @@ import os import shutil import subprocess import sys +import tempfile from msprobe.core.common.log import logger @@ -20,6 +21,23 @@ def run_ut(): shutil.rmtree(report_dir) os.makedirs(report_dir) + tmpdir = tempfile.mkdtemp() + sitecustomize_path = os.path.join(tmpdir, "sitecustomize.py") + + with open(sitecustomize_path, "w") as f: + f.write(""" +import mindspore + +class Distributed: + P2POp = None + +if not hasattr(mindspore.mint, 'distributed'): + setattr(mindspore.mint, 'distributed', Distributed()) + """) + + env = os.environ.copy() + env["PYTHONPATH"] = f"{tmpdir}:{env.get('PYTHONPATH', '')}" + pytest_cmd = [ "python3", "-m", "pytest", ut_path,