From 07e489627876382ada7f8a7499847272780e0668 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 28 Apr 2025 19:57:05 +0800 Subject: [PATCH] fix kwargs acquisition bug --- .../msprobe/mindspore/cell_processor.py | 16 ------- .../test/mindspore_ut/test_cell_processor.py | 44 +------------------ 2 files changed, 1 insertion(+), 59 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index 3ea85de6db8..731fbfab0a7 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -27,14 +27,6 @@ from msprobe.mindspore.common.log import logger from msprobe.mindspore.common.utils import is_mindtorch, get_cells_and_names -def get_cell_construct(construct): - def _construct(self, *args, **kwargs): - if hasattr(self, 'msprobe_hook'): - setattr(self, 'msprobe_input_kwargs', kwargs) - return construct(self, *args, **kwargs) - return _construct - - class CellProcessor: cell_count = {} cell_stack = [] @@ -80,14 +72,6 @@ class CellProcessor: if cell == model: continue - if not hasattr(cell.__class__, 'msprobe_construct'): - setattr(cell.__class__, 'msprobe_construct', True) - if is_mindtorch(): - setattr(cell.__class__, 'forward', get_cell_construct(cell.__class__.forward)) - else: - setattr(cell.__class__, 'construct', get_cell_construct(cell.__class__.construct)) - setattr(cell, 'msprobe_hook', True) - cell_index = (index + Const.SEP) if index != "-1" else "" prefix = f'{model_type}{Const.SEP}{cell_index}{name}{Const.SEP}{cell.__class__.__name__}{Const.SEP}' diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py index e1d7e72ba2a..559b742b89a 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py @@ -23,7 +23,7 @@ from mindspore.ops.operations import _inner_ops from msprobe.core.common.const import Const from msprobe.core.common.exceptions import MsprobeException from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.mindspore.cell_processor import CellProcessor, get_cell_construct +from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger @@ -52,24 +52,6 @@ class TestCellProcessor(unittest.TestCase): processor = CellProcessor(None) self.assertIsNone(processor.scope) - def test_get_cell_construct(self): - def construct(self, *args, **kwargs): - return len(args) - - _constrct = get_cell_construct(construct) - ret = _constrct(self, 'argument') - self.assertFalse(hasattr(self, 'msprobe_input_kwargs')) - self.assertEqual(ret, 1) - - setattr(self, 'msprobe_hook', True) - _constrct = get_cell_construct(construct) - ret = _constrct(self, 'argument') - self.assertEqual(self.msprobe_input_kwargs, {}) - self.assertEqual(ret, 1) - - del self.msprobe_hook - del self.msprobe_input_kwargs - def test_set_and_get_calls_number(self): CellProcessor.cell_count = {} count = self.processor.set_and_get_calls_number("cell") @@ -108,60 +90,36 @@ class TestCellProcessor(unittest.TestCase): with patch('msprobe.mindspore.cell_processor.is_mindtorch') as mock_is_mindtorch, \ patch('msprobe.mindspore.cell_processor.get_cells_and_names') as mock_get_cells_and_names, \ patch('msprobe.mindspore.cell_processor.CellProcessor.build_cell_hook') as mock_build_cell_hook, \ - patch('msprobe.mindspore.cell_processor.get_cell_construct') as mock_get_cell_construct, \ patch.object(logger, 'info') as mock_logger_info: mock_cell = MagicMock() mock_sub_cell = MagicMock() mock_get_cells_and_names.return_value = {'-1': [('cell', mock_cell), ('sub_cell', mock_sub_cell)]} mock_build_cell_hook.return_value = 'forward_pre_hook' - mock_get_cell_construct.return_value = '_construct' mock_is_mindtorch.return_value = False - setattr(MagicMock, 'construct', 'construct') self.processor.register_cell_hook(mock_cell, None) - self.assertTrue(mock_sub_cell.__class__.msprobe_construct) - mock_get_cell_construct.assert_called_with('construct') - self.assertEqual(mock_sub_cell.__class__.construct, '_construct') - self.assertTrue(mock_sub_cell.msprobe_hook) mock_build_cell_hook.assert_called_with('Cell.sub_cell.MagicMock.', None) mock_cell.assert_not_called() mock_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') mock_sub_cell.register_forward_hook.assert_not_called() mock_logger_info.assert_called_with('The cell hook function is successfully mounted to the model.') - del MagicMock.construct - del mock_sub_cell.__class__.construct - del mock_sub_cell.__class__.msprobe_construct - - mock_get_cell_construct.reset_mock() mock_another_sub_cell = MagicMock() - setattr(mock_another_sub_cell.__class__, 'msprobe_construct', True) mock_get_cells_and_names.return_value = {'-1': [('cell', mock_cell), ('another_sub_cell', mock_another_sub_cell)]} self.processor.register_cell_hook(mock_cell, None) - mock_get_cell_construct.assert_not_called() mock_another_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') mock_another_sub_cell.register_forward_hook.assert_not_called() - del mock_another_sub_cell.__class__.msprobe_construct - mock_build_cell_hook.reset_mock() - mock_get_cell_construct.reset_mock() mock_another_sub_cell.reset_mock() - setattr(MagicMock, 'forward', 'forward') mock_is_mindtorch.return_value = True self.processor.register_cell_hook(mock_cell, None) - self.assertTrue(mock_another_sub_cell.__class__.msprobe_construct) - mock_get_cell_construct.assert_called_with('forward') mock_build_cell_hook.assert_called_with('Module.another_sub_cell.MagicMock.', None) mock_cell.assert_not_called() mock_another_sub_cell.register_forward_pre_hook.assert_called_with('forward_pre_hook') mock_another_sub_cell.register_forward_hook.assert_not_called() - del MagicMock.forward - del mock_another_sub_cell.__class__.forward - del mock_another_sub_cell.__class__.msprobe_construct - def test_build_cell_hook(self): CellProcessor.reset_cell_stats() -- Gitee