From dcea5940514f7959e10cf46ad5166cee6d3e4327 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 13 Jan 2025 16:29:11 +0800 Subject: [PATCH 1/3] =?UTF-8?q?primitive=E6=94=AF=E6=8C=81mindtorch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/mindspore/common/utils.py | 33 ++++++- .../dump/hook_cell/primitive_hooks.py | 95 +++++++++++++++++-- .../msprobe/mindspore/service.py | 39 +++++--- 3 files changed, 142 insertions(+), 25 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 99fa4bd4d..6d8688883 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -20,6 +20,7 @@ import mindspore as ms from mindspore import ops from mindspore.mint import nn +from mindspore._c_expression import Tensor as Tensor_ from msprobe.core.common.exceptions import DistributedNotInitializedError from msprobe.core.common.file_utils import path_len_exceeds_limit, check_path_exists, save_npy from msprobe.core.common.log import logger @@ -54,7 +55,7 @@ def convert_to_int(value): return int(value) except Exception: return -1 - + def clean_input_kwargs(cell): if hasattr(cell, 'input_kwargs'): @@ -139,3 +140,33 @@ def remove_dropout(): ops.operations.Dropout3D = Dropout3D nn.Dropout = DropoutExt nn.functional.dropout = dropout_ext + + +mindtorch_check_result = None + + +def is_mindtorch(): + global mindtorch_check_result + if mindtorch_check_result is None: + try: + import torch + import torch_npu # 假设 torch_npu 是 MindTorch 所需的包 + except ImportError: + mindtorch_check_result = False + return mindtorch_check_result + + tensor_object = torch.tensor(0) + # 假设 Tensor_ 是 MindTorch 的张量类 + if isinstance(tensor_object, Tensor_) or isinstance(tensor_object.data, Tensor_): + mindtorch_check_result = True + else: + mindtorch_check_result = False + return mindtorch_check_result + + +# 然后,定义一个通用的函数来获取模型的子模块 +def get_model_modules(nn_model): + if is_mindtorch(): + return nn_model.cells_and_names() + else: + return nn_model.named_modules() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py index 656e48c67..3c97f4686 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py @@ -17,6 +17,7 @@ import os from mindspore import ops from mindspore.common.tensor import Tensor +from mindspore.ops import Primitive from msprobe.core.common.utils import Const, DumpException from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, ModuleBackwardOutputs, @@ -24,6 +25,54 @@ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, Mo from msprobe.mindspore.common.log import logger +def find_primitives(nn_model, primitives_set, visited=None): + """ + 递归遍历模块,查找所有 Primitive 的实例,并将其名称和实例添加到 primitives_set 中。 + + :param module: 需要遍历的 MindSpore 模块 + :param primitives_set: 用于存储 (pname, instance) 的集合 + :param visited: 用于跟踪已访问模块的集合 + """ + # if visited is None: + # visited = set() + + print("=== Debugging: find_primitives 被调用 ===") + print("module:", nn_model) + print("module 类型:", type(nn_model)) + print("===================================") + + # if id(nn_model) in visited: + # print(f"已访问模块: {prefix if prefix else 'root'},跳过") + # return + # visited.add(id(nn_model)) + + print(f"正在遍历模块类型: {type(nn_model)}") + + # 遍历 cells_and_names() + # for name, sub_module in nn_model.cells_and_names(): + # named_modules + # for name, sub_module in nn_model.named_modules(): + + # 遍历所有其他属性,检查是否为 Primitive 实例 + print(f"遍历子模块: {nn_model}, 类型: {type(nn_model)}") + for attr_name, attr_value in vars(nn_model).items(): + print(f"检查属性: {attr_name}, 类型: {type(attr_value)}") + # if attr_name.startswith('_'): + # continue # 跳过私有和特殊属性 + full_attr_name = f"{prefix}.{attr_name}" if prefix else attr_name + print(f"完整属性名: {full_attr_name}") # 打印完整的属性名称 + + if isinstance(attr_value, Primitive): + if (full_attr_name, attr_value) not in primitives_set: + primitives_set.add((attr_name, attr_value)) + print(f"找到 Primitive (属性): {full_attr_name}, 实例: {attr_value}") + else: + print(f"Primitive 已存在于集合中: {full_attr_name}") # 打印已存在的 Primitive + else: + print(f"属性 {full_attr_name} 不是 Primitive 类型") # 打印属性不是 Primitive 类型的提示 + print("-" * 50) # 分隔线,便于阅读输出 + + class PrimitiveHookService: def __init__(self, service_instance): self.primitive_counters = {} @@ -58,7 +107,9 @@ class PrimitiveHookService: def backward_hook(grad): captured_grads.extend(grad) backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" - + if self.service_instance.inner_switch: + return + self.service_instance.inner_switch = True try: if hook_type == Const.INPUT: self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name) @@ -77,6 +128,7 @@ class PrimitiveHookService: logger.error(f"This is a primitive op {hook_type}_backward dump error: {exception}, " f"updated_primitive_name: {updated_primitive_name}") raise DumpException(DumpException.BACKWARD_DATA_COLLECTION_ERROR) from exception + self.service_instance.inner_switch = False return backward_hook @@ -96,11 +148,20 @@ class PrimitiveHookService: num_tensors = sum(isinstance(arg, Tensor) for arg in args) input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, Const.INPUT) - for arg in args: + print(f"[hook_primitive_inputs] (args) 类型: {type((args))}") + for idx, arg in enumerate(args): if isinstance(arg, Tensor): - arg_hooked = ops.HookBackward(input_backward_hook)(arg) - hooked_inputs.append(arg_hooked) + # print(f"[hook_primitive_inputs] arg: {arg}") + print(f"[hook_primitive_inputs] arg 类型: {type(arg)}") + print(f"[hook_primitive_inputs] 为第 {idx} 个 Tensor 输入添加 hook。") + # arg_hooked = ops.HookBackward(input_backward_hook)(arg) + arg_hooked = arg.register_hook(input_backward_hook) + hooked_inputs.append(arg) + print(f"[hook_primitive_inputs] Tensor 输入 {idx} 已添加 hook。") + # print(f"[hook_primitive_inputs] arg_hooked: {arg_hooked}") + print(f"[hook_primitive_inputs] arg_hooked 类型: {type(arg_hooked)}") else: + print(f"[hook_primitive_inputs] 第 {idx} 个输入不是 Tensor,跳过 hook。") hooked_inputs.append(arg) return tuple(hooked_inputs) @@ -124,12 +185,14 @@ class PrimitiveHookService: updated_primitive_name, Const.OUTPUT) if isinstance(out, Tensor): - return ops.HookBackward(output_backward_hook)(out) + out.register_hook(output_backward_hook) + return out elif isinstance(out, tuple): hooked_outputs = [] for tensor in out: if isinstance(tensor, Tensor): - hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor)) + tensor.register_hook(output_backward_hook) + hooked_outputs.append(tensor) else: hooked_outputs.append(tensor) return tuple(hooked_outputs) @@ -137,6 +200,9 @@ class PrimitiveHookService: def pre_forward_hook(primitive_name, primitive_instance, args, kwargs): module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) + if self.service_instance.inner_switch: + return + self.service_instance.inner_switch = True try: self.service_instance.data_collector.forward_input_data_collect( primitive_name, @@ -148,9 +214,13 @@ class PrimitiveHookService: logger.error(f"This is a primitive op dump error during forward input data collection: {exception}, " f"primitive_name: {primitive_name}") raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception + self.service_instance.inner_switch = False def post_forward_hook(primitive_name, primitive_instance, args, kwargs, output): module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + if self.service_instance.inner_switch: + return + self.service_instance.inner_switch = True try: self.service_instance.data_collector.forward_output_data_collect( primitive_name, @@ -162,6 +232,7 @@ class PrimitiveHookService: logger.error(f"This is a primitive op dump error during forward output data collection: {exception}, " f"primitive_name: {primitive_name}") raise DumpException(DumpException.FORWARD_DATA_COLLECTION_ERROR) from exception + self.service_instance.inner_switch = False def wrapped_primitive_call(instance_self, *args, **kwargs): """ @@ -175,13 +246,14 @@ class PrimitiveHookService: Returns: Tensor/tuple: primitive 的返回值。 """ + if not self.service_instance.primitive_switch: + return origin_func(*args, **kwargs) + if self.service_instance.inner_switch: + return origin_func(*args, **kwargs) self.update_primitive_counters(primitive_name) current_count = self.primitive_counters.get(primitive_name, 0) updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}" - if not self.service_instance.primitive_switch: - return origin_func(*args, **kwargs) - captured_grads_input, captured_grads_output = [], [] try: @@ -191,8 +263,13 @@ class PrimitiveHookService: f"primitive_name: {primitive_name}") raise DumpException(DumpException.INPUT_HOOK_ERROR) from exception + if self.service_instance.inner_switch: + return origin_func(*args, **kwargs) + self.service_instance.inner_switch = True forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}" + self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name) + self.service_instance.inner_switch = False pre_forward_hook(forward_primitive_name, instance_self, hooked_inputs, kwargs) try: diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 3b173ac26..7018e0ca1 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -20,6 +20,8 @@ from collections import defaultdict import mindspore as ms from mindspore import nn +import torch +from mindspore.ops import Primitive from mindspore.common.api import _no_grad try: from mindspore.common._pijit_context import PIJitCaptureContext @@ -36,9 +38,9 @@ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutpu from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import get_rank_if_initialized, clean_input_kwargs +from msprobe.mindspore.common.utils import get_rank_if_initialized, clean_input_kwargs, is_mindtorch, get_model_modules from msprobe.mindspore.dump.hook_cell.api_registry import api_register -from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService +from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService, find_primitives from msprobe.mindspore.dump.jit_dump import JitDump from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell @@ -64,11 +66,11 @@ class Service: @staticmethod def check_model_valid(model): - if not model or isinstance(model, nn.Cell): - return model - raise MsprobeException( - MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。" - ) + # if not model or isinstance(model, nn.Cell) : + return model + # raise MsprobeException( + # MsprobeException.INVALID_PARAM_ERROR, "model 参数必须是 mindspore.nn.Cell 类型。" + # ) @staticmethod def prepare_module_input_output(target_type, cell, input_data, output): @@ -202,17 +204,24 @@ class Service: return wrap_pre_forward_hook, wrap_forward_hook, wrap_backward_hook - def update_primitive_counters(self, primitive_name): - if primitive_name not in self.primitive_counters: - self.primitive_counters[primitive_name] = 0 - else: - self.primitive_counters[primitive_name] += 1 def register_primitive_hooks(self): primitive_set = set() - for _, cell in self.model.cells_and_names(): - for pname, primitive in cell._primitives.items(): - primitive_set.add((pname, primitive)) + print("=== Debugging: 在 register_primitive_hooks 中调用 find_primitives 前的 self.model ===") + # print(self.model) + # print(type(self.model)) + # # 使用 find_primitives 递归查找所有 Primitive 实例 + # self.find_primitives(self.model, primitive_set) + # for _, cell in self.model.cells_and_names(): + # for pname, primitive in cell._primitives.items(): + # primitive_set.add((pname, primitive)) + for name, module in get_model_modules(self.model): + # full_name = f"{prefix}.{name}" if prefix else name + print(f"遍历子模块: {name}, 类型: {type(module)}") + # if isinstance(sub_module, torch.nn.Module): + find_primitives(module, primitive_set, name) + + for pname, primitive in primitive_set: primitive_class_name = primitive.__class__.__name__ -- Gitee From 61b0c9560da1f8ac1680dc52bae5f36c975de4a9 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Mon, 13 Jan 2025 17:16:55 +0800 Subject: [PATCH 2/3] =?UTF-8?q?primitive=E6=94=AF=E6=8C=81mindtorch?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../dump/hook_cell/primitive_hooks.py | 41 ++++--------------- 1 file changed, 8 insertions(+), 33 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py index 3c97f4686..d2448e827 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py @@ -25,51 +25,26 @@ from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputs, Mo from msprobe.mindspore.common.log import logger -def find_primitives(nn_model, primitives_set, visited=None): +def find_primitives(module_or_cell, primitives_set): """ 递归遍历模块,查找所有 Primitive 的实例,并将其名称和实例添加到 primitives_set 中。 - :param module: 需要遍历的 MindSpore 模块 + :param module_or_cell: 需要遍历的 MindSpore 与 Mindtorch模块 :param primitives_set: 用于存储 (pname, instance) 的集合 - :param visited: 用于跟踪已访问模块的集合 """ - # if visited is None: - # visited = set() - - print("=== Debugging: find_primitives 被调用 ===") - print("module:", nn_model) - print("module 类型:", type(nn_model)) - print("===================================") - - # if id(nn_model) in visited: - # print(f"已访问模块: {prefix if prefix else 'root'},跳过") - # return - # visited.add(id(nn_model)) - - print(f"正在遍历模块类型: {type(nn_model)}") - - # 遍历 cells_and_names() - # for name, sub_module in nn_model.cells_and_names(): - # named_modules - # for name, sub_module in nn_model.named_modules(): # 遍历所有其他属性,检查是否为 Primitive 实例 - print(f"遍历子模块: {nn_model}, 类型: {type(nn_model)}") - for attr_name, attr_value in vars(nn_model).items(): - print(f"检查属性: {attr_name}, 类型: {type(attr_value)}") - # if attr_name.startswith('_'): - # continue # 跳过私有和特殊属性 - full_attr_name = f"{prefix}.{attr_name}" if prefix else attr_name - print(f"完整属性名: {full_attr_name}") # 打印完整的属性名称 + print(f"遍历子模块: {module_or_cell}, 类型: {type(module_or_cell)}") + for attr_name, attr_value in vars(module_or_cell).items(): if isinstance(attr_value, Primitive): - if (full_attr_name, attr_value) not in primitives_set: + if (attr_name, attr_value) not in primitives_set: primitives_set.add((attr_name, attr_value)) - print(f"找到 Primitive (属性): {full_attr_name}, 实例: {attr_value}") + print(f"找到 Primitive (属性): {attr_name}, 实例: {attr_value}") else: - print(f"Primitive 已存在于集合中: {full_attr_name}") # 打印已存在的 Primitive + print(f"Primitive 已存在于集合中: {attr_name}") # 打印已存在的 Primitive else: - print(f"属性 {full_attr_name} 不是 Primitive 类型") # 打印属性不是 Primitive 类型的提示 + print(f"属性 {attr_name} 不是 Primitive 类型") # 打印属性不是 Primitive 类型的提示 print("-" * 50) # 分隔线,便于阅读输出 -- Gitee From 97b53c3b8d4ef36499348523953064f9416eb969 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Tue, 14 Jan 2025 17:21:10 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E4=BF=AE=E6=94=B9bug?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/msprobe/mindspore/common/utils.py | 4 ++-- debug/accuracy_tools/msprobe/mindspore/service.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index 6d8688883..783ec7e4d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -150,7 +150,7 @@ def is_mindtorch(): if mindtorch_check_result is None: try: import torch - import torch_npu # 假设 torch_npu 是 MindTorch 所需的包 + import torch_npu except ImportError: mindtorch_check_result = False return mindtorch_check_result @@ -164,7 +164,7 @@ def is_mindtorch(): return mindtorch_check_result -# 然后,定义一个通用的函数来获取模型的子模块 +# 定义一个通用的函数来获取模型的子模块 def get_model_modules(nn_model): if is_mindtorch(): return nn_model.cells_and_names() diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 7018e0ca1..583e54f13 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -219,7 +219,7 @@ class Service: # full_name = f"{prefix}.{name}" if prefix else name print(f"遍历子模块: {name}, 类型: {type(module)}") # if isinstance(sub_module, torch.nn.Module): - find_primitives(module, primitive_set, name) + find_primitives(module, primitive_set) -- Gitee