diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index 6dc5d510ef51ab2a135a8bdf9f15ac670fba9e56..3ca9fb4951e0257b968cead790ba96c9f3ccd0f3 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2024, Huawei Technologies Co., Ltd. +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -17,6 +17,14 @@ from msprobe.core.data_dump.scope import ModuleRangeScope, MixRangeScope from msprobe.core.common.const import Const +def get_cell_construct(construct): + def _construct(self, *args, **kwargs): + if hasattr(self, 'msprobe_hook'): + setattr(self, 'msprobe_input_kwargs', kwargs) + return construct(self, *args, **kwargs) + return _construct + + class CellProcessor: cell_count = {} cell_stack = [] diff --git a/debug/accuracy_tools/msprobe/mindspore/common/utils.py b/debug/accuracy_tools/msprobe/mindspore/common/utils.py index d73fe9c3e8f4c3d426d336623af3f1ed3537cb14..453af63641e8deb49b9d91b7b4330a504ad137f0 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/utils.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/utils.py @@ -82,8 +82,8 @@ def convert_to_int(value): def clean_input_kwargs(cell): - if hasattr(cell, 'input_kwargs'): - del cell.input_kwargs + if hasattr(cell, 'msprobe_input_kwargs'): + del cell.msprobe_input_kwargs def list_lowest_level_directories(root_dir): diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py index 7007992ca4540a06b1ebc85a068179e88ec589cc..868d71bfc20d5a058a7dcb42ff05b44be3ba3862 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/hook_cell.py @@ -49,7 +49,7 @@ def __init__(self, hook_build_func) -> None: # 重载call,加全局标志。 def __call__(self, *args, **kwargs): try: - self.input_kwargs = kwargs + setattr(self, 'msprobe_input_kwargs', kwargs) out = super(HOOKCell, self).__call__(*args, **kwargs) except Exception as e: raise e diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 397e62884def9351b6a1c337ed28ac34e15ba259..9bf68827f506c2143269f71239851167eb981632 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -37,7 +37,7 @@ from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, ModuleBackwardInputs) from msprobe.core.data_dump.scope import BaseScope -from msprobe.mindspore.cell_processor import CellProcessor +from msprobe.mindspore.cell_processor import CellProcessor, get_cell_construct from msprobe.mindspore.common.log import logger from msprobe.mindspore.common.utils import (get_rank_if_initialized, clean_input_kwargs, is_mindtorch, register_backward_hook_functions) @@ -98,18 +98,9 @@ class Service: MsprobeException.INVALID_PARAM_ERROR, error_info) return models - @staticmethod - def prepare_module_input_output(target_type, cell, input_data, output): - if target_type == BaseScope.Module_Type_Module: - module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs={}, output=output) - else: - module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.input_kwargs, output=output) - return module_input_output - def build_hook(self, target_type, name): def pre_hook(api_or_cell_name, cell, input_data): if not self.should_execute_hook(target_type, cell, True): - clean_input_kwargs(cell) return None with _no_grad(): @@ -119,7 +110,8 @@ class Service: else: cell.forward_data_collected = True HOOKCell.add_cell_count(name) - module_input_output = self.prepare_module_input_output(target_type, cell, input_data, None) + module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.msprobe_input_kwargs, + output=None) self.data_collector.update_api_or_module_name(api_or_cell_name) self.data_collector.forward_input_data_collect(api_or_cell_name, cell, pid, module_input_output) self.inner_switch = False @@ -176,7 +168,8 @@ class Service: return None with _no_grad(): self.inner_switch = True - module_input_output = self.prepare_module_input_output(target_type, cell, input_data, output) + module_input_output = ModuleForwardInputsOutputs(args=input_data, kwargs=cell.msprobe_input_kwargs, + output=output) if target_type == BaseScope.Module_Type_Module: api_or_cell_name = self.cell_processor.set_and_get_reserved_name(cell, api_or_cell_name) params_dict = {} @@ -200,11 +193,11 @@ class Service: self.data_collector.update_api_or_module_name(api_or_cell_name) self.data_collector.forward_output_data_collect(api_or_cell_name, cell, pid, module_input_output) + clean_input_kwargs(cell) if self.data_collector.if_return_forward_new_output(): forward_new_output = self.data_collector.get_forward_new_output() self.inner_switch = False return forward_new_output - clean_input_kwargs(cell) self.inner_switch = False return output @@ -468,6 +461,10 @@ class Service: for name, cell in cells_and_names: if cell == model: continue + if not hasattr(cell.__class__, 'msprobe_construct'): + setattr(cell.__class__, 'msprobe_construct', True) + setattr(cell.__class__, 'construct', get_cell_construct(cell.__class__.construct)) + setattr(cell, 'msprobe_hook', True) cell_index = (index + Const.SEP) if index != "-1" else "" prefix = (model_type + Const.SEP + cell_index + name + Const.SEP + cell.__class__.__name__ + Const.SEP) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 96d76c17d3ce50cba87feb5a5f392f179758aad9..f12fb2b0fbc16b2be86a5e9b4dc9d967be68788b 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -285,6 +285,7 @@ class TestService(unittest.TestCase): global register_backward_hook_functions self.service.config.level = Const.LEVEL_L0 cell_mock = MagicMock() + setattr(MagicMock, 'construct', None) self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] register_backward_hook_functions["pre"] = cell_mock.register_backward_pre_hook register_backward_hook_functions["full"] = cell_mock.register_backward_hook @@ -293,6 +294,7 @@ class TestService(unittest.TestCase): cell_mock.register_backward_hook.assert_called() mock_node_hook.assert_called() register_backward_hook_functions = {} + del MagicMock.construct def test_register_hook_new_without_model_raises_exception(self): self.service.config.level = Const.LEVEL_L0