diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 99fa4bd4ddaedf4b77c07876a9d2dbe51b85e32b..783ec7e4d6f154059fba039ad273eb758c087ab6 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -20,6 +20,7 @@ import mindspore as ms from mindspore import ops from mindspore.mint import nn +from mindspore._c_expression import Tensor as Tensor_ from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy from msprobe.core.common.log import logger @@ -54,7 +55,7 @@ def convert_to_int(value): return int(value) except Exception: return -1 - + def clean_input_kwargs(cell): if hasattr(cell, 'input_kwargs'): @@ -139,3 +140,33 @@ def remove_dropout(): ops.operations.Dropout3D = Dropout3D nn.Dropout = DropoutExt nn.functional.dropout = dropout_ext + + +mindtorch_check_result = None + + +def is_mindtorch(): + global mindtorch_check_result + if mindtorch_check_result is None: + try: + import torch + import torch_npu + except ImportError: + mindtorch_check_result = False + return mindtorch_check_result + + tensor_object = torch.tensor(0) + # 假设 Tensor_ 是 MindTorch 的张量类 + if isinstance(tensor_object, Tensor_) or isinstance(tensor_object.data, Tensor_): + mindtorch_check_result = True + else: + mindtorch_check_result = False + return mindtorch_check_result + + +# 定义一个通用的函数来获取模型的子模块 +def get_model_modules(nn_model): + if is_mindtorch(): + return nn_model.cells_and_names() + else: + return nn_model.named_modules() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py index 656e48c678956563a6f2d1d5f5ab8a4d03f074e7..d2448e827a22d2e01c3da0c403e0068a66031f99 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py @@ -17,6 +17,7 @@ import os from mindspore import ops from mindspore.common.tensor import Tensor +from mindspore.ops import Primitive from msprobe.core.common.utils import Const, DumpException from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs, @@ -24,6 +25,29 @@ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, Mo from msprobe.mindspore.common.log import logger +def find_primitives(module_or_cell, primitives_set): + """ + 递归遍历模块,查找所有 Primitive 的实例,并将其名称和实例添加到 primitives_set 中。 + + :param module_or_cell: 需要遍历的 MindSpore 与 Mindtorch模块 + :param primitives_set: 用于存储 (pname, instance) 的集合 + """ + + # 遍历所有其他属性,检查是否为 Primitive 实例 + print(f"遍历子模块: {module_or_cell}, 类型: {type(module_or_cell)}") + for attr_name, attr_value in vars(module_or_cell).items(): + + if isinstance(attr_value, Primitive): + if (attr_name, attr_value) not in primitives_set: + primitives_set.add((attr_name, attr_value)) + print(f"找到 Primitive (属性): {attr_name}, 实例: {attr_value}") + else: + print(f"Primitive 已存在于集合中: {attr_name}") # 打印已存在的 Primitive + else: + print(f"属性 {attr_name} 不是 Primitive 类型") # 打印属性不是 Primitive 类型的提示 + print("-" * 50) # 分隔线,便于阅读输出 + + class PrimitiveHookService: def __init__(self, service_instance): self.primitive_counters = {} @@ -58,7 +82,9 @@ class PrimitiveHookService: def backward_hook(grad): captured_grads.extend(grad) backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" - + if self.service_instance.inner_switch: + return + self.service_instance.inner_switch = True try: if hook_type == Const.INPUT: self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name) @@ -77,6 +103,7 @@ class PrimitiveHookService: logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, " f"updated_primitive_name: {updated_primitive_name}") raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception + self.service_instance.inner_switch = False return backward_hook @@ -96,11 +123,20 @@ class PrimitiveHookService: num_tensors = sum(isinstance(arg, Tensor) for arg in args) input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, Const.INPUT) - for arg in args: + print(f"[hook_primitive_inputs] (args) 类型: {type((args))}") + for idx, arg in enumerate(args): if isinstance(arg, Tensor): - arg_hooked = ops.HookBackward(input_backward_hook)(arg) - hooked_inputs.append(arg_hooked) + # print(f"[hook_primitive_inputs] arg: {arg}") + print(f"[hook_primitive_inputs] arg 类型: {type(arg)}") + print(f"[hook_primitive_inputs] 为第 {idx} 个 Tensor 输入添加 hook。") + # arg_hooked = ops.HookBackward(input_backward_hook)(arg) + arg_hooked = arg.register_hook(input_backward_hook) + hooked_inputs.append(arg) + print(f"[hook_primitive_inputs] Tensor 输入 {idx} 已添加 hook。") + # print(f"[hook_primitive_inputs] arg_hooked: {arg_hooked}") + print(f"[hook_primitive_inputs] arg_hooked 类型: {type(arg_hooked)}") else: + print(f"[hook_primitive_inputs] 第 {idx} 个输入不是 Tensor,跳过 hook。") hooked_inputs.append(arg) return tuple(hooked_inputs) @@ -124,12 +160,14 @@ class PrimitiveHookService: updated_primitive_name, Const.OUTPUT) if isinstance(out, Tensor): - return ops.HookBackward(output_backward_hook)(out) + out.register_hook(output_backward_hook) + return out elif isinstance(out, tuple): hooked_outputs = [] for tensor in out: if isinstance(tensor, Tensor): - hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor)) + tensor.register_hook(output_backward_hook) + hooked_outputs.append(tensor) else: hooked_outputs.append(tensor) return tuple(hooked_outputs) @@ -137,6 +175,9 @@ class PrimitiveHookService: def pre_forward_hook(primitive_name, primitive_instance, args, kwargs): module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) + if self.service_instance.inner_switch: + return + self.service_instance.inner_switch = True try: self.service_instance.data_collector.forward_input_data_collect( primitive_name, @@ -148,9 +189,13 @@ class PrimitiveHookService: logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, " f"primitive_name: {primitive_name}") raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception + self.service_instance.inner_switch = False def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output): module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + if self.service_instance.inner_switch: + return + self.service_instance.inner_switch = True try: self.service_instance.data_collector.forward_output_data_collect( primitive_name, @@ -162,6 +207,7 @@ class PrimitiveHookService: logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, " f"primitive_name: {primitive_name}") raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception + self.service_instance.inner_switch = False def wrapped_primitive_call(instance_self, *args, **kwargs): """ @@ -175,13 +221,14 @@ class PrimitiveHookService: Returns: Tensor/tuple: primitive 的返回值。 """ + if not self.service_instance.primitive_switch: + return origin_func(*args, **kwargs) + if self.service_instance.inner_switch: + return origin_func(*args, **kwargs) self.update_primitive_counters(primitive_name) current_count = self.primitive_counters.get(primitive_name, 0) updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}" - if not self.service_instance.primitive_switch: - return origin_func(*args, **kwargs) - captured_grads_input, captured_grads_output = [], [] try: @@ -191,8 +238,13 @@ class PrimitiveHookService: f"primitive_name: {primitive_name}") raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception + if self.service_instance.inner_switch: + return origin_func(*args, **kwargs) + self.service_instance.inner_switch = True forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}" + self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name) + self.service_instance.inner_switch = False pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs) try: diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 3b173ac26408512d9d2ee77c666b5d86e66cc01e..583e54f13076e1d693f05be9dfd96eccfde6cdb1 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -20,6 +20,8 @@ from collections import defaultdict import mindspore as ms from mindspore import nn +import torch +from mindspore.ops import Primitive from mindspore.common.api import _no_grad try: from mindspore.common._pijit_context import PIJitCaptureContext @@ -36,9 +38,9 @@ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutpu from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import get_rank_if_initialized, clean_input_kwargs +from msprobe.mindspore.common.utils import get_rank_if_initialized, clean_input_kwargs, is_mindtorch, get_model_modules from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService +from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService, find_primitives from msprobe.mindspore.dump.jit_dump import JitDump from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell @@ -64,11 +66,11 @@ class Service: @staticmethod def check_model_valid(model): - if not model or isinstance(model, nn.Cell): - return model - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。" - ) + # if not model or isinstance(model, nn.Cell) : + return model + # raise MsprobeException( + # MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。" + # ) @staticmethod def prepare_module_input_output(target_type, cell, input_data, output): @@ -202,17 +204,24 @@ class Service: return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook - def update_primitive_counters(self, primitive_name): - if primitive_name not in self.primitive_counters: - self.primitive_counters[primitive_name] = 0 - else: - self.primitive_counters[primitive_name] += 1 def register_primitive_hooks(self): primitive_set = set() - for _, cell in self.model.cells_and_names(): - for pname, primitive in cell._primitives.items(): - primitive_set.add((pname, primitive)) + print("=== Debugging: 在 register_primitive_hooks 中调用 find_primitives 前的 self.model ===") + # print(self.model) + # print(type(self.model)) + # # 使用 find_primitives 递归查找所有 Primitive 实例 + # self.find_primitives(self.model, primitive_set) + # for _, cell in self.model.cells_and_names(): + # for pname, primitive in cell._primitives.items(): + # primitive_set.add((pname, primitive)) + for name, module in get_model_modules(self.model): + # full_name = f"{prefix}.{name}" if prefix else name + print(f"遍历子模块: {name}, 类型: {type(module)}") + # if isinstance(sub_module, torch.nn.Module): + find_primitives(module, primitive_set) + + for pname, primitive in primitive_set: primitive_class_name = primitive.__class__.__name__