From 2903e4a0da1f6c3dd92f2cd912f2934966f9f149 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 09:47:15 +0800 Subject: [PATCH 01/28] V1.2 --- .../msprobe/mindspore/cell_processor.py | 5 +- .../test/mindspore_ut/test_cell_processor.py | 87 +++++++++++++++++-- 2 files changed, 85 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index b1c510d944..5a384e300c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -180,8 +180,9 @@ class CellProcessor: result = (result,) if len(result) != len(outputs): raise TypeError( - "The backward pre hook return value size is {} not equal to output size {}".format( - len(result), len(outputs))) + f"The backward pre hook return value size is {len(result)} " + f"not equal to output size {len(outputs)}" + ) return result return forward_pre_hook, forward_hook diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py index fbae6599d7..eef22ebe8f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py @@ -182,7 +182,7 @@ class TestCellProcessor(unittest.TestCase): target_args = (Tensor([0.0]),) full_forward_name = f'{cell_name}{Const.FORWARD}.0' full_backward_name = f'{cell_name}{Const.BACKWARD}.0' - + # call testing function - forward_pre_hook ret = forward_pre_hook(mock_cell, args) self.assertIsNone(CellProcessor.module_node[full_forward_name]) self.assertEqual(CellProcessor.cell_stack, [full_forward_name]) @@ -193,11 +193,13 @@ class TestCellProcessor(unittest.TestCase): mock_CellBackwardHook.assert_called_with(full_backward_name, mock_cell, CellProcessor.cell_backward_hook[-1]) mock_bw.register_backward_hook.assert_called_once() + mock_bw.assert_called_with(*args) self.assertTrue((ret[0] == target_args[0]).all()) backward_hook = CellProcessor.cell_backward_hook[-1][full_backward_name] grad_input = (Tensor([1.0]),) grad_output = (Tensor([2.0]),) + # call testing function - backward_hook ret = backward_hook(mock_cell, grad_input, grad_output) mock_backward_data_hook.assert_called_with(mock_cell, grad_input, grad_output) self.assertFalse(mock_cell.has_pre_hook_called) @@ -209,6 +211,7 @@ class TestCellProcessor(unittest.TestCase): mock_build_data_hook.reset_mock() args = (Tensor([1], dtype=ms.int32),) full_forward_name = f'{cell_name}{Const.FORWARD}.1' + # call testing function - forward_pre_hook ret = forward_pre_hook(mock_cell, args) self.assertIsNone(CellProcessor.module_node[full_forward_name]) self.assertEqual(CellProcessor.cell_stack, [full_forward_name]) @@ -222,18 +225,92 @@ class TestCellProcessor(unittest.TestCase): CellProcessor.cell_stack = [full_forward_name] CellProcessor.api_parent_node = full_forward_name CellProcessor.module_node = {full_forward_name: None} + self.scope.reset_mock() mock_CellBackwardHook.reset_mock() mock_bw.reset_mock() - mock_backward_data_hook.reset_mock() - mock_forward_data_hook_hook = MagicMock() target_output = Tensor([0.5]) - mock_forward_data_hook_hook.return_value = target_output - mock_build_data_hook.return_value = (None, mock_forward_data_hook_hook, mock_backward_data_hook, None) args = (Tensor([1.0]),) output = Tensor([2.0]) + mock_bw.return_value = target_output + mock_backward_data_hook.reset_mock() + mock_forward_data_hook_hook = MagicMock() + mock_forward_data_hook_hook.return_value = output + mock_build_data_hook.return_value = (None, mock_forward_data_hook_hook, mock_backward_data_hook, None) + # call testing function - forward_hook ret = forward_hook(mock_cell, args, output) + self.assertEqual(CellProcessor.cell_count.get(cell_name), 0) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.scope.end_module.assert_called_with(full_forward_name) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], output) + self.assertEqual(mock_bw.call_args_list[1][0][0], target_output) + self.assertEqual(mock_CellBackwardHook.call_count, 1) + self.assertEqual(len(CellProcessor.cell_backward_pre_hook), 1) self.assertTrue((ret == target_output).all()) + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + mock_backward_data_hook.reset_mock() + grad_output = (Tensor([2.0]),) + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertTrue(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + self.assertEqual(CellProcessor.cell_stack, [full_backward_name]) + self.assertEqual(CellProcessor.api_parent_node, full_backward_name) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_not_called() + self.assertIsNone(ret) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + mock_bw.reset_mock() + args = (Tensor([1.0]),) + output = (Tensor([2.0]),) + mock_forward_data_hook_hook.return_value = output + target_output = (Tensor([0.5]),) + # call testing function - forward_hook + ret = forward_hook(mock_cell, args, output) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], *output) + self.assertEqual(mock_bw.call_args_list[1][0][0], mock_bw.return_value) + self.assertTrue((ret[0] == target_output[0]).all()) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + CellProcessor.cell_bw_hook_kernels.clear() + CellProcessor.cell_backward_pre_hook.clear() + mock_bw.reset_mock() + mock_bw.return_value = (Tensor([0.5]),) + output = (Tensor([1.0]), Tensor([2.0])) + mock_forward_data_hook_hook.return_value = output + with self.assertRaises(TypeError) as context: + # call testing function - forward_hook + forward_hook(mock_cell, args, output) + self.assertEqual(str(context.exception), + 'The backward pre hook return value size is 1 not equal to output size 2') + mock_bw.assert_called_with(*output) + + self.scope.reset_mock() + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertFalse(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_called_with(mock_cell, (), grad_output) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.end_module.assert_called_with(full_backward_name) + self.assertIsNone(ret) + + CellProcessor.reset_cell_stats() + def test_set_construct_info_in_pre_hook(self): CellProcessor.reset_cell_stats() self.processor.set_construct_info_in_pre_hook('full_name') -- Gitee From f02cce16ca22c16d5649c72ee5a52d964a745a2e Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 11:41:00 +0800 Subject: [PATCH 02/28] V1.3 --- .../test/mindspore_ut/test_ms_service.py | 381 +++++++++--------- 1 file changed, 182 insertions(+), 199 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index f12fb2b0fb..030659472f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -26,7 +26,6 @@ from msprobe.core.data_dump.api_registry import ApiRegistry from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import register_backward_hook_functions from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.dump.jit_dump import JitDump @@ -103,201 +102,185 @@ class TestService(unittest.TestCase): framework=Const.MS_FRAMEWORK ) - @patch.object(Service, 'need_end_service', return_value=False) - def test_start_stop_cycle(self, mock_need_end_service): - self.service.model = nn.Cell() - with patch.object(self.service, 'register_cell_hook') as mock_register_hook: - self.should_stop_service = False - self.service.start(self.service.model) - self.assertTrue(self.service.switch) - self.service.stop() - self.assertFalse(self.service.switch) - mock_register_hook.assert_called_once() - mock_need_end_service.assert_called_once() - - def test_should_execute_hook_return_false(self): - cell = MagicMock() - self.service.switch = False - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - self.assertFalse(self.service.should_execute_hook("api", cell, True)) - - self.service.switch = True - cell.forward_data_collected = False - self.assertFalse(self.service.should_execute_hook("api", cell, False)) - - self.service.inner_switch = True - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - self.service.inner_switch = False - self.service.data_collector = None - self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - def test_should_execute_hook_return_true(self): - cell = MagicMock() - self.service.switch = True - self.service.inner_switch = False - self.service.data_collector = MagicMock() - self.service.data_collector.data_processor = MagicMock() - self.service.data_collector.data_processor.is_terminated = False - self.assertTrue(self.service.should_execute_hook("Module", cell, True)) - - cell.forward_data_collected = True - self.assertTrue(self.service.should_execute_hook("api", cell, False)) - - def test_need_end_service_with_high_step(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 4 - self.assertTrue(self.service.need_end_service()) - - def test_need_end_service_with_low_step(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 2 - self.service.data_collector.data_processor.is_terminated = False - self.assertFalse(self.service.need_end_service()) - - def test_start_with_termination_condition(self): - self.service.config.step = [1, 2, 3] - self.service.current_iter = 4 - self.service.start() - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - self.assertFalse(self.service.primitive_switch) - - @patch('msprobe.mindspore.service.print_tools_ends_info') - @patch.object(Service, 'need_end_service', return_value=True) - def test_start_with_end_service(self, mock_need_end_service, mock_print_tools_ends_info): - self.service.start(self.service.model) - mock_need_end_service.assert_called_once() - mock_print_tools_ends_info.assert_called_once() - self.assertFalse(self.service.switch) - self.assertTrue(self.service.should_stop_service) - - @patch.object(Service, 'need_end_service', return_value=False) - @patch.object(logger, 'info') - @patch.object(Service, 'register_cell_hook') - @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'create_dirs') - @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) - def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, - mock_register_cell_hook, mock_logger, mock_need_end_service): - self.service.first_start = True - self.service.should_stop_service = False - self.service.start(self.service.model) - mock_get_rank.assert_called_once() - mock_register_cell_hook.assert_called_once() - mock_register_primitive_hook.assert_called_once() - mock_need_end_service.assert_called_once() - mock_create_dirs.assert_called_once() - self.assertFalse(self.service.first_start) - self.assertTrue(self.service.switch) - self.assertTrue(self.service.primitive_switch) - mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") - - @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'register_cell_hook') - @patch.object(Service, 'need_end_service', return_value=False) - @patch.object(JitDump, 'set_config') - @patch.object(JitDump, 'set_data_collector') - @patch.object(ApiRegistry, 'register_all_api') - def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, - mock_set_config, mock_need_end_service, mock_register_cell_hook, - mock_register_primitive_hook): - self.service.config.level = Const.LEVEL_MIX - self.service.first_start = True - self.service.should_stop_service = False - self.service.start(self.service.model) - mock_set_config.assert_called_with(self.service.config) - mock_set_data_collector.assert_called_with(self.service.data_collector) - mock_api_set_hook_func.assert_called_once() - mock_need_end_service.assert_called_once() - mock_register_cell_hook.assert_called_once() - mock_register_primitive_hook.assert_called_once() - self.assertTrue(JitDump.jit_dump_switch) - - def test_step_updates(self): - CellProcessor.cell_count = {"test_api": 1} - HOOKCell.cell_count = {"test_api": 1} - JitDump.jit_count = {"test_api": 1} - self.service.primitive_hook_service.primitive_counters = {"test_api": 1} - self.service.loop = 0 - self.service.step() - self.assertEqual(self.service.loop, 1) - self.service.data_collector.reset_status.assert_called_once() - self.assertEqual(JitDump.jit_count, defaultdict(int)) - self.assertEqual((self.service.primitive_hook_service.primitive_counters), {}) - - @patch.object(Service, 'should_execute_hook') - def test_build_forward_and_backward_hooks(self, mock_should_execute_hook): - mock_should_execute_hook.return_value = True - self.service.data_collector = MagicMock() - self.service.data_collector.update_api_or_module_name = MagicMock() - self.service.data_collector.forward_data_collect = MagicMock() - self.service.data_collector.if_return_forward_new_output = MagicMock(return_value=False) - self.service.data_collector.backward_data_collect = MagicMock() - - mock_cell = MagicMock() - mock_cell.mindstudio_reserved_name = "TestCell" - mock_input = (MagicMock(),) - mock_output = MagicMock() - - _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook") - - forward_hook(mock_cell, mock_input, mock_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell') - self.service.data_collector.forward_data_collect.assert_called() - - self.service.data_collector.reset_mock() - - mock_grad_input = (MagicMock(),) - mock_grad_output = MagicMock() - - backward_hook(mock_cell, mock_grad_input, mock_grad_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestHookbackward.0') - self.service.data_collector.backward_data_collect.assert_called() - - def test_register_primitive_hook(self): - self.service.config.level = Const.LEVEL_MIX - primitive_attr = ops.Add() - primitive_name = "primitive_api" - cell_mock = MagicMock() - cell_mock.primitive_api = primitive_attr - primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - self.service.register_primitive_hook() - self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) - self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], - primitive_combined_name) - - @patch.object(ApiRegistry, 'initialize_hook') - @patch.object(ApiRegistry, 'register_all_api') - @patch("msprobe.mindspore.service.logger.info") - def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook): - self.service.config.level = Const.LEVEL_MIX - self.service.register_api_hook() - self.service.register_cell_hook() - mock_logger.assert_called_with(f"The cell {self.service.config.task} hook function " - "is successfully mounted to the model.") - mock_api_set_hook_func.assert_called() - mock_initialize_hook.assert_called() - - @patch.object(CellProcessor, 'node_hook') - def test_register_hook_new_with_level_l0(self, mock_node_hook): - global register_backward_hook_functions - self.service.config.level = Const.LEVEL_L0 - cell_mock = MagicMock() - setattr(MagicMock, 'construct', None) - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - register_backward_hook_functions["pre"] = cell_mock.register_backward_pre_hook - register_backward_hook_functions["full"] = cell_mock.register_backward_hook - self.service.register_cell_hook() - cell_mock.register_forward_hook.assert_called() - cell_mock.register_backward_hook.assert_called() - mock_node_hook.assert_called() - register_backward_hook_functions = {} - del MagicMock.construct - - def test_register_hook_new_without_model_raises_exception(self): - self.service.config.level = Const.LEVEL_L0 - self.service.model = None - with self.assertRaises(MsprobeException): - self.service.register_cell_hook() + # @patch.object(Service, 'need_end_service', return_value=False) + # def test_start_stop_cycle(self, mock_need_end_service): + # self.service.model = nn.Cell() + # with patch.object(self.service, 'register_cell_hook') as mock_register_hook: + # self.should_stop_service = False + # self.service.start(self.service.model) + # self.assertTrue(self.service.switch) + # self.service.stop() + # self.assertFalse(self.service.switch) + # mock_register_hook.assert_called_once() + # mock_need_end_service.assert_called_once() + + # def test_should_execute_hook_return_false(self): + # cell = MagicMock() + # self.service.switch = False + # self.assertFalse(self.service.should_execute_hook("Module", cell, True)) + # self.assertFalse(self.service.should_execute_hook("api", cell, True)) + + # self.service.switch = True + # cell.forward_data_collected = False + # self.assertFalse(self.service.should_execute_hook("api", cell, False)) + + # self.service.inner_switch = True + # self.assertFalse(self.service.should_execute_hook("Module", cell, True)) + + # self.service.inner_switch = False + # self.service.data_collector = None + # self.assertFalse(self.service.should_execute_hook("Module", cell, True)) + + # def test_should_execute_hook_return_true(self): + # cell = MagicMock() + # self.service.switch = True + # self.service.inner_switch = False + # self.service.data_collector = MagicMock() + # self.service.data_collector.data_processor = MagicMock() + # self.service.data_collector.data_processor.is_terminated = False + # self.assertTrue(self.service.should_execute_hook("Module", cell, True)) + + # cell.forward_data_collected = True + # self.assertTrue(self.service.should_execute_hook("api", cell, False)) + + # def test_need_end_service_with_high_step(self): + # self.service.config.step = [1, 2, 3] + # self.service.current_iter = 4 + # self.assertTrue(self.service.need_end_service()) + + # def test_need_end_service_with_low_step(self): + # self.service.config.step = [1, 2, 3] + # self.service.current_iter = 2 + # self.service.data_collector.data_processor.is_terminated = False + # self.assertFalse(self.service.need_end_service()) + + # def test_start_with_termination_condition(self): + # self.service.config.step = [1, 2, 3] + # self.service.current_iter = 4 + # self.service.start() + # self.assertFalse(self.service.switch) + # self.assertTrue(self.service.should_stop_service) + # self.assertFalse(self.service.primitive_switch) + + # @patch('msprobe.mindspore.service.print_tools_ends_info') + # @patch.object(Service, 'need_end_service', return_value=True) + # def test_start_with_end_service(self, mock_need_end_service, mock_print_tools_ends_info): + # self.service.start(self.service.model) + # mock_need_end_service.assert_called_once() + # mock_print_tools_ends_info.assert_called_once() + # self.assertFalse(self.service.switch) + # self.assertTrue(self.service.should_stop_service) + + # @patch.object(Service, 'need_end_service', return_value=False) + # @patch.object(logger, 'info') + # @patch.object(Service, 'register_cell_hook') + # @patch.object(Service, 'register_primitive_hook') + # @patch.object(Service, 'create_dirs') + # @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) + # def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, + # mock_register_cell_hook, mock_logger, mock_need_end_service): + # self.service.first_start = True + # self.service.should_stop_service = False + # self.service.start(self.service.model) + # mock_get_rank.assert_called_once() + # mock_register_cell_hook.assert_called_once() + # mock_register_primitive_hook.assert_called_once() + # mock_need_end_service.assert_called_once() + # mock_create_dirs.assert_called_once() + # self.assertFalse(self.service.first_start) + # self.assertTrue(self.service.switch) + # self.assertTrue(self.service.primitive_switch) + # mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") + + # @patch.object(Service, 'register_primitive_hook') + # @patch.object(Service, 'register_cell_hook') + # @patch.object(Service, 'need_end_service', return_value=False) + # @patch.object(JitDump, 'set_config') + # @patch.object(JitDump, 'set_data_collector') + # @patch.object(ApiRegistry, 'register_all_api') + # def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, + # mock_set_config, mock_need_end_service, mock_register_cell_hook, + # mock_register_primitive_hook): + # self.service.config.level = Const.LEVEL_MIX + # self.service.first_start = True + # self.service.should_stop_service = False + # self.service.start(self.service.model) + # mock_set_config.assert_called_with(self.service.config) + # mock_set_data_collector.assert_called_with(self.service.data_collector) + # mock_api_set_hook_func.assert_called_once() + # mock_need_end_service.assert_called_once() + # mock_register_cell_hook.assert_called_once() + # mock_register_primitive_hook.assert_called_once() + # self.assertTrue(JitDump.jit_dump_switch) + + # def test_step_updates(self): + # CellProcessor.cell_count = {"test_api": 1} + # HOOKCell.cell_count = {"test_api": 1} + # JitDump.jit_count = {"test_api": 1} + # self.service.primitive_hook_service.primitive_counters = {"test_api": 1} + # self.service.loop = 0 + # self.service.step() + # self.assertEqual(self.service.loop, 1) + # self.service.data_collector.reset_status.assert_called_once() + # self.assertEqual(JitDump.jit_count, defaultdict(int)) + # self.assertEqual((self.service.primitive_hook_service.primitive_counters), {}) + + # @patch.object(Service, 'should_execute_hook') + # def test_build_forward_and_backward_hooks(self, mock_should_execute_hook): + # mock_should_execute_hook.return_value = True + # self.service.data_collector = MagicMock() + # self.service.data_collector.update_api_or_module_name = MagicMock() + # self.service.data_collector.forward_data_collect = MagicMock() + # self.service.data_collector.if_return_forward_new_output = MagicMock(return_value=False) + # self.service.data_collector.backward_data_collect = MagicMock() + + # mock_cell = MagicMock() + # mock_cell.mindstudio_reserved_name = "TestCell" + # mock_input = (MagicMock(),) + # mock_output = MagicMock() + + # _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook") + + # forward_hook(mock_cell, mock_input, mock_output) + # self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell') + # self.service.data_collector.forward_data_collect.assert_called() + + # self.service.data_collector.reset_mock() + + # mock_grad_input = (MagicMock(),) + # mock_grad_output = MagicMock() + + # backward_hook(mock_cell, mock_grad_input, mock_grad_output) + # self.service.data_collector.update_api_or_module_name.assert_called_with('TestHookbackward.0') + # self.service.data_collector.backward_data_collect.assert_called() + + # def test_register_primitive_hook(self): + # self.service.config.level = Const.LEVEL_MIX + # primitive_attr = ops.Add() + # primitive_name = "primitive_api" + # cell_mock = MagicMock() + # cell_mock.primitive_api = primitive_attr + # primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ + # self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] + # self.service.register_primitive_hook() + # self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) + # self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], + # primitive_combined_name) + + # @patch.object(ApiRegistry, 'initialize_hook') + # @patch.object(ApiRegistry, 'register_all_api') + # @patch("msprobe.mindspore.service.logger.info") + # def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook): + # self.service.config.level = Const.LEVEL_MIX + # self.service.register_api_hook() + # self.service.register_cell_hook() + # mock_logger.assert_called_with(f"The cell {self.service.config.task} hook function " + # "is successfully mounted to the model.") + # mock_api_set_hook_func.assert_called() + # mock_initialize_hook.assert_called() + + # def test_register_hook_new_without_model_raises_exception(self): + # self.service.config.level = Const.LEVEL_L0 + # self.service.model = None + # with self.assertRaises(MsprobeException): + # self.service.register_cell_hook() -- Gitee From 7ae0fad600392b185df308161c1317523242401b Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 14:55:14 +0800 Subject: [PATCH 03/28] V1.3 --- .../test/mindspore_ut/test_ms_service.py | 139 ++++++++++-------- 1 file changed, 81 insertions(+), 58 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 030659472f..cccf8b1244 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -24,7 +24,6 @@ from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import Const from msprobe.core.data_dump.api_registry import ApiRegistry from msprobe.core.data_dump.scope import BaseScope -from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell @@ -40,67 +39,91 @@ class TestService(unittest.TestCase): self.config_mock.step = [] self.config_mock.rank = [] self.config_mock.task = Const.TENSOR - self.config_mock.framework = Const.MS_FRAMEWORK self.config_mock.list = [] self.config_mock.scope = [] - self.service = Service(self.config_mock) - self.service.model = MagicMock(spec=nn.Cell) - self.service.data_collector = MagicMock() - self.service.primitive_hook_service = MagicMock() - - def tearDown(self) -> None: - get_api_register().restore_all_api() def test_init(self): - self.assertEqual(self.service.config.level, "L0") - self.assertFalse(self.service.switch) - self.assertFalse(self.service.should_stop_service) - self.assertFalse(self.service.start_call) - self.assertTrue(self.service.first_start) - - def test_check_model_valid_with_valid_cell(self): - model = nn.Cell() - model_list = [model] - self.assertEqual(self.service.check_model_valid(model), model) - self.assertEqual(self.service.check_model_valid(model_list), model_list) - - def test_check_model_valid_with_invalid_type(self): - model = nn.Cell() - with self.assertRaises(MsprobeException): - self.service.check_model_valid("not a cell") - with self.assertRaises(MsprobeException): - self.service.check_model_valid(["not a cell", model]) - - def test_update_primitive_counters(self): - self.service.primitive_counters = {} - self.service.update_primitive_counters("conv2d") - self.assertEqual(self.service.primitive_counters["conv2d"], 0) - self.service.update_primitive_counters("conv2d") - self.assertEqual(self.service.primitive_counters["conv2d"], 1) - - @patch('msprobe.mindspore.service.create_directory') - def test_create_dirs(self, mock_create_directory): - self.service.current_iter = 1 - self.service.current_rank = 0 - self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] - self.service.data_collector.update_dump_paths = MagicMock() - self.service.create_dirs() - expected_calls = [ - ("/tmp/dump"), - ("/tmp/dump/step1/rank0"), - "/tmp/dump/step1/rank0/dump_tensor_data" - ] - mock_create_directory.assert_has_calls( - [unittest.mock.call(path) for path in expected_calls], any_order=True) - - args, _ = self.service.data_collector.update_dump_paths.call_args - self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") - self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") - self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") - self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") - self.service.data_collector.initialize_json_file.assert_called_once_with( - framework=Const.MS_FRAMEWORK - ) + with patch('msprobe.mindspore.service.build_data_collector') as mock_build_data_collector, \ + patch('msprobe.mindspore.service.CellProcessor') as mock_CellProcessor, \ + patch('msprobe.mindspore.service.PrimitiveHookService') as mock_PrimitiveHookService, \ + patch('msprobe.mindspore.service.get_api_register') as mock_get_api_register, \ + patch.object(Service, 'register_api_hook') as mock_register_api_hook, \ + patch.object(Service, 'init_for_debug_level') as mock_init_for_debug_level: + self.service = Service(self.config_mock) + self.assertIsNone(self.service.model) + self.assertEqual(self.service.config.level_ori, Const.LEVEL_L0) + self.assertEqual(self.service.config.dump_path, '/tmp/dump') + self.assertEqual(self.service.config.step, []) + self.assertEqual(self.service.config.rank, []) + self.assertEqual(self.service.config.task, Const.TENSOR) + self.assertEqual(self.service.config.list, []) + self.assertEqual(self.service.config.scope, []) + self.assertEqual(self.service.config.level, Const.LEVEL_L0) + mock_build_data_collector.assert_called_with(self.service.config) + mock_CellProcessor.assert_called_with(mock_build_data_collector.scope) + mock_PrimitiveHookService.assert_called_with(self.service) + self.assertFalse(self.service.switch) + self.assertFalse(self.service.inner_switch) + self.assertFalse(self.service.primitive_switch) + self.assertEqual(self.service.current_iter, 0) + self.assertEqual(self.service.loop, 0) + self.assertEqual(self.service.init_step, 0) + self.assertTrue(self.service.first_start) + self.assertIsNone(self.service.current_rank) + self.assertIsNone(self.service.dump_iter_dir) + self.assertFalse(self.service.start_call) + self.assertFalse(self.service.should_stop_service) + self.assertEqual(self.service.params_grad_info, {}) + self.assertEqual(self.service.hook_handle_dict, {}) + mock_get_api_register.assert_called_with() + mock_register_api_hook.assert_called_with() + mock_init_for_debug_level.assert_called_with() + + # self.service.model = MagicMock(spec=nn.Cell) + + # def test_check_model_valid_with_valid_cell(self): + # model = nn.Cell() + # model_list = [model] + # self.assertEqual(self.service.check_model_valid(model), model) + # self.assertEqual(self.service.check_model_valid(model_list), model_list) + + # def test_check_model_valid_with_invalid_type(self): + # model = nn.Cell() + # with self.assertRaises(MsprobeException): + # self.service.check_model_valid("not a cell") + # with self.assertRaises(MsprobeException): + # self.service.check_model_valid(["not a cell", model]) + + # def test_update_primitive_counters(self): + # self.service.primitive_counters = {} + # self.service.update_primitive_counters("conv2d") + # self.assertEqual(self.service.primitive_counters["conv2d"], 0) + # self.service.update_primitive_counters("conv2d") + # self.assertEqual(self.service.primitive_counters["conv2d"], 1) + + # @patch('msprobe.mindspore.service.create_directory') + # def test_create_dirs(self, mock_create_directory): + # self.service.current_iter = 1 + # self.service.current_rank = 0 + # self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] + # self.service.data_collector.update_dump_paths = MagicMock() + # self.service.create_dirs() + # expected_calls = [ + # ("/tmp/dump"), + # ("/tmp/dump/step1/rank0"), + # "/tmp/dump/step1/rank0/dump_tensor_data" + # ] + # mock_create_directory.assert_has_calls( + # [unittest.mock.call(path) for path in expected_calls], any_order=True) + + # args, _ = self.service.data_collector.update_dump_paths.call_args + # self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") + # self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") + # self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") + # self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") + # self.service.data_collector.initialize_json_file.assert_called_once_with( + # framework=Const.MS_FRAMEWORK + # ) # @patch.object(Service, 'need_end_service', return_value=False) # def test_start_stop_cycle(self, mock_need_end_service): -- Gitee From 6e6dd140f82539e4854c8af40ae109fd103f3bd7 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 15:25:20 +0800 Subject: [PATCH 04/28] V1.3 --- .../test/mindspore_ut/test_ms_service.py | 65 ++++++++++++++----- 1 file changed, 50 insertions(+), 15 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index cccf8b1244..2058e00fb8 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -19,6 +19,7 @@ from collections import defaultdict from unittest.mock import MagicMock, patch from mindspore import nn, ops +import torch from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import Const @@ -41,6 +42,11 @@ class TestService(unittest.TestCase): self.config_mock.task = Const.TENSOR self.config_mock.list = [] self.config_mock.scope = [] + with patch('msprobe.mindspore.service.build_data_collector'), \ + patch('msprobe.mindspore.service.CellProcessor'), \ + patch('msprobe.mindspore.service.PrimitiveHookService'), \ + patch('msprobe.mindspore.service.get_api_register'): + self.service = Service(self.config_mock) def test_init(self): with patch('msprobe.mindspore.service.build_data_collector') as mock_build_data_collector, \ @@ -60,7 +66,7 @@ class TestService(unittest.TestCase): self.assertEqual(self.service.config.scope, []) self.assertEqual(self.service.config.level, Const.LEVEL_L0) mock_build_data_collector.assert_called_with(self.service.config) - mock_CellProcessor.assert_called_with(mock_build_data_collector.scope) + mock_CellProcessor.assert_called_with(mock_build_data_collector.return_value.scope) mock_PrimitiveHookService.assert_called_with(self.service) self.assertFalse(self.service.switch) self.assertFalse(self.service.inner_switch) @@ -79,20 +85,49 @@ class TestService(unittest.TestCase): mock_register_api_hook.assert_called_with() mock_init_for_debug_level.assert_called_with() - # self.service.model = MagicMock(spec=nn.Cell) - - # def test_check_model_valid_with_valid_cell(self): - # model = nn.Cell() - # model_list = [model] - # self.assertEqual(self.service.check_model_valid(model), model) - # self.assertEqual(self.service.check_model_valid(model_list), model_list) - - # def test_check_model_valid_with_invalid_type(self): - # model = nn.Cell() - # with self.assertRaises(MsprobeException): - # self.service.check_model_valid("not a cell") - # with self.assertRaises(MsprobeException): - # self.service.check_model_valid(["not a cell", model]) + def test_check_model_valid(self): + with patch('msprobe.mindspore.service.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = False + model = None + self.assertIsNone(self.service.check_model_valid(model)) + model = nn.Cell() + self.assertEqual(self.service.check_model_valid(model), model) + model = 'model' + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(model) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertEqual(str(context.exception), + "The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " + "currently there is a type.") + models = [model] + self.assertEqual(self.service.check_model_valid(models), models) + models = [model, 'model'] + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(models) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertEqual(str(context.exception), + "The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " + "currently there is a type.") + + mock_is_mindtorch.return_value = True + model = torch.nn.Module() + self.assertEqual(self.service.check_model_valid(model), model) + model = 'model' + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(model) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertEqual(str(context.exception), + "The 'model' parameter must be a torch.nn.Module or list[mindspore.nn.Cell] type, " + "currently there is a type.") + models = [model] + self.assertEqual(self.service.check_model_valid(models), models) + models = [model, 'model'] + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(models) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertEqual(str(context.exception), + "The 'model' parameter must be a torch.nn.Module or list[mindspore.nn.Cell] type, " + "currently there is a type.") # def test_update_primitive_counters(self): # self.service.primitive_counters = {} -- Gitee From a6caa875fdef3813cad4369e70d3d921bbb7c8ae Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 15:29:57 +0800 Subject: [PATCH 05/28] V1.3 --- .../test/mindspore_ut/test_ms_service.py | 20 ++++++++----------- 1 file changed, 8 insertions(+), 12 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 2058e00fb8..0ffce9d483 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -96,18 +96,16 @@ class TestService(unittest.TestCase): with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - self.assertEqual(str(context.exception), - "The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " - "currently there is a type.") + self.assertIn("The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " + "currently there is a type.", str(context.exception)) models = [model] self.assertEqual(self.service.check_model_valid(models), models) models = [model, 'model'] with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(models) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - self.assertEqual(str(context.exception), - "The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " - "currently there is a type.") + self.assertIn("The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " + "currently there is a type.", str(context.exception)) mock_is_mindtorch.return_value = True model = torch.nn.Module() @@ -116,18 +114,16 @@ class TestService(unittest.TestCase): with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - self.assertEqual(str(context.exception), - "The 'model' parameter must be a torch.nn.Module or list[mindspore.nn.Cell] type, " - "currently there is a type.") + self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " + "currently there is a type.", str(context.exception)) models = [model] self.assertEqual(self.service.check_model_valid(models), models) models = [model, 'model'] with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(models) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) - self.assertEqual(str(context.exception), - "The 'model' parameter must be a torch.nn.Module or list[mindspore.nn.Cell] type, " - "currently there is a type.") + self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " + "currently there is a type.", str(context.exception)) # def test_update_primitive_counters(self): # self.service.primitive_counters = {} -- Gitee From c380553ec8128c0ff5ebc3d24a974b1e95efd165 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 15:31:08 +0800 Subject: [PATCH 06/28] V1.3 --- .../msprobe/test/mindspore_ut/test_ms_service.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 0ffce9d483..710ee332d4 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -90,14 +90,14 @@ class TestService(unittest.TestCase): mock_is_mindtorch.return_value = False model = None self.assertIsNone(self.service.check_model_valid(model)) - model = nn.Cell() - self.assertEqual(self.service.check_model_valid(model), model) model = 'model' with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) self.assertIn("The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " "currently there is a type.", str(context.exception)) + model = nn.Cell() + self.assertEqual(self.service.check_model_valid(model), model) models = [model] self.assertEqual(self.service.check_model_valid(models), models) models = [model, 'model'] @@ -108,14 +108,14 @@ class TestService(unittest.TestCase): "currently there is a type.", str(context.exception)) mock_is_mindtorch.return_value = True - model = torch.nn.Module() - self.assertEqual(self.service.check_model_valid(model), model) model = 'model' with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " "currently there is a type.", str(context.exception)) + model = torch.nn.Module() + self.assertEqual(self.service.check_model_valid(model), model) models = [model] self.assertEqual(self.service.check_model_valid(models), models) models = [model, 'model'] -- Gitee From 448cbb779c7933d48a134137b99910daca023566 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 15:35:56 +0800 Subject: [PATCH 07/28] V1.3 --- .../accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 710ee332d4..a04b7192c0 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -108,6 +108,7 @@ class TestService(unittest.TestCase): "currently there is a type.", str(context.exception)) mock_is_mindtorch.return_value = True + setattr(Service, 'torch', torch) model = 'model' with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) -- Gitee From 866c778e90cf7cf69f962e971fed399e9ba1a2b9 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 16:57:58 +0800 Subject: [PATCH 08/28] V1.3 --- .../msprobe/test/mindspore_ut/test_ms_service.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index a04b7192c0..b724f82fdd 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -108,7 +108,14 @@ class TestService(unittest.TestCase): "currently there is a type.", str(context.exception)) mock_is_mindtorch.return_value = True - setattr(Service, 'torch', torch) + ori_check_model_valid = Service.check_model_valid + + def mock_check_model_valid(*args): + import torch + return ori_check_model_valid(*args) + + setattr(Service, 'check_model_valid', mock_check_model_valid) + model = 'model' with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) @@ -126,6 +133,8 @@ class TestService(unittest.TestCase): self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " "currently there is a type.", str(context.exception)) + setattr(Service, 'check_model_valid', ori_check_model_valid) + # def test_update_primitive_counters(self): # self.service.primitive_counters = {} # self.service.update_primitive_counters("conv2d") -- Gitee From 400444dcc468217a8719053b8de3a2084e0e9d5d Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Sat, 19 Apr 2025 20:42:13 +0800 Subject: [PATCH 09/28] V1.5 --- .../msprobe/test/common_set_up/mindtorch.py | 14 ++++++++++++++ .../msprobe/test/common_set_up/test_set_up.py | 6 ++++++ .../msprobe/test/mindspore_ut/test_ms_service.py | 10 ---------- 3 files changed, 20 insertions(+), 10 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py new file mode 100644 index 0000000000..cada41ba2b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py @@ -0,0 +1,14 @@ +from mindspore import Tensor +import torch + + +def create_msa_tensor(data, dtype=None): + return Tensor(data, dtype) + + +tensor_tensor = torch.tensor +setattr(torch, 'tensor', create_msa_tensor) + + +def reset_torch_tensor(): + setattr(torch, 'tensor', tensor_tensor) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 6442908bb0..47421278a1 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -17,6 +17,8 @@ from unittest import TestCase from unittest.mock import MagicMock from mindspore import mint +from mindtorch import reset_torch_tensor +from msprobe.mindspore.common.utils import is_mindtorch try: from mint import distributed @@ -24,7 +26,11 @@ except ImportError: distributed = MagicMock() setattr(mint, 'distributed', distributed) +mindtorch_check_result = is_mindtorch() +reset_torch_tensor() + class SetUp(TestCase): def test_case(self): self.assertTrue(hasattr(mint, 'distributed')) + self.assertTrue(mindtorch_check_result) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index b724f82fdd..710ee332d4 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -108,14 +108,6 @@ class TestService(unittest.TestCase): "currently there is a type.", str(context.exception)) mock_is_mindtorch.return_value = True - ori_check_model_valid = Service.check_model_valid - - def mock_check_model_valid(*args): - import torch - return ori_check_model_valid(*args) - - setattr(Service, 'check_model_valid', mock_check_model_valid) - model = 'model' with self.assertRaises(MsprobeException) as context: self.service.check_model_valid(model) @@ -133,8 +125,6 @@ class TestService(unittest.TestCase): self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " "currently there is a type.", str(context.exception)) - setattr(Service, 'check_model_valid', ori_check_model_valid) - # def test_update_primitive_counters(self): # self.service.primitive_counters = {} # self.service.update_primitive_counters("conv2d") -- Gitee From 851f83b8d2a623417dffe651ad23edf79eecd5a2 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 14:53:06 +0800 Subject: [PATCH 10/28] V1.4 --- .../test/mindspore_ut/test_ms_service.py | 426 +++++++++--------- 1 file changed, 211 insertions(+), 215 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 710ee332d4..b2de4bfccd 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -23,10 +23,9 @@ import torch from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import Const -from msprobe.core.data_dump.api_registry import ApiRegistry from msprobe.core.data_dump.scope import BaseScope +from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger -from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.dump.jit_dump import JitDump from msprobe.mindspore.service import Service @@ -125,216 +124,213 @@ class TestService(unittest.TestCase): self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " "currently there is a type.", str(context.exception)) - # def test_update_primitive_counters(self): - # self.service.primitive_counters = {} - # self.service.update_primitive_counters("conv2d") - # self.assertEqual(self.service.primitive_counters["conv2d"], 0) - # self.service.update_primitive_counters("conv2d") - # self.assertEqual(self.service.primitive_counters["conv2d"], 1) - - # @patch('msprobe.mindspore.service.create_directory') - # def test_create_dirs(self, mock_create_directory): - # self.service.current_iter = 1 - # self.service.current_rank = 0 - # self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] - # self.service.data_collector.update_dump_paths = MagicMock() - # self.service.create_dirs() - # expected_calls = [ - # ("/tmp/dump"), - # ("/tmp/dump/step1/rank0"), - # "/tmp/dump/step1/rank0/dump_tensor_data" - # ] - # mock_create_directory.assert_has_calls( - # [unittest.mock.call(path) for path in expected_calls], any_order=True) - - # args, _ = self.service.data_collector.update_dump_paths.call_args - # self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") - # self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") - # self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") - # self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") - # self.service.data_collector.initialize_json_file.assert_called_once_with( - # framework=Const.MS_FRAMEWORK - # ) - - # @patch.object(Service, 'need_end_service', return_value=False) - # def test_start_stop_cycle(self, mock_need_end_service): - # self.service.model = nn.Cell() - # with patch.object(self.service, 'register_cell_hook') as mock_register_hook: - # self.should_stop_service = False - # self.service.start(self.service.model) - # self.assertTrue(self.service.switch) - # self.service.stop() - # self.assertFalse(self.service.switch) - # mock_register_hook.assert_called_once() - # mock_need_end_service.assert_called_once() - - # def test_should_execute_hook_return_false(self): - # cell = MagicMock() - # self.service.switch = False - # self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - # self.assertFalse(self.service.should_execute_hook("api", cell, True)) - - # self.service.switch = True - # cell.forward_data_collected = False - # self.assertFalse(self.service.should_execute_hook("api", cell, False)) - - # self.service.inner_switch = True - # self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - # self.service.inner_switch = False - # self.service.data_collector = None - # self.assertFalse(self.service.should_execute_hook("Module", cell, True)) - - # def test_should_execute_hook_return_true(self): - # cell = MagicMock() - # self.service.switch = True - # self.service.inner_switch = False - # self.service.data_collector = MagicMock() - # self.service.data_collector.data_processor = MagicMock() - # self.service.data_collector.data_processor.is_terminated = False - # self.assertTrue(self.service.should_execute_hook("Module", cell, True)) - - # cell.forward_data_collected = True - # self.assertTrue(self.service.should_execute_hook("api", cell, False)) - - # def test_need_end_service_with_high_step(self): - # self.service.config.step = [1, 2, 3] - # self.service.current_iter = 4 - # self.assertTrue(self.service.need_end_service()) - - # def test_need_end_service_with_low_step(self): - # self.service.config.step = [1, 2, 3] - # self.service.current_iter = 2 - # self.service.data_collector.data_processor.is_terminated = False - # self.assertFalse(self.service.need_end_service()) - - # def test_start_with_termination_condition(self): - # self.service.config.step = [1, 2, 3] - # self.service.current_iter = 4 - # self.service.start() - # self.assertFalse(self.service.switch) - # self.assertTrue(self.service.should_stop_service) - # self.assertFalse(self.service.primitive_switch) - - # @patch('msprobe.mindspore.service.print_tools_ends_info') - # @patch.object(Service, 'need_end_service', return_value=True) - # def test_start_with_end_service(self, mock_need_end_service, mock_print_tools_ends_info): - # self.service.start(self.service.model) - # mock_need_end_service.assert_called_once() - # mock_print_tools_ends_info.assert_called_once() - # self.assertFalse(self.service.switch) - # self.assertTrue(self.service.should_stop_service) - - # @patch.object(Service, 'need_end_service', return_value=False) - # @patch.object(logger, 'info') - # @patch.object(Service, 'register_cell_hook') - # @patch.object(Service, 'register_primitive_hook') - # @patch.object(Service, 'create_dirs') - # @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) - # def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, - # mock_register_cell_hook, mock_logger, mock_need_end_service): - # self.service.first_start = True - # self.service.should_stop_service = False - # self.service.start(self.service.model) - # mock_get_rank.assert_called_once() - # mock_register_cell_hook.assert_called_once() - # mock_register_primitive_hook.assert_called_once() - # mock_need_end_service.assert_called_once() - # mock_create_dirs.assert_called_once() - # self.assertFalse(self.service.first_start) - # self.assertTrue(self.service.switch) - # self.assertTrue(self.service.primitive_switch) - # mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") - - # @patch.object(Service, 'register_primitive_hook') - # @patch.object(Service, 'register_cell_hook') - # @patch.object(Service, 'need_end_service', return_value=False) - # @patch.object(JitDump, 'set_config') - # @patch.object(JitDump, 'set_data_collector') - # @patch.object(ApiRegistry, 'register_all_api') - # def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, - # mock_set_config, mock_need_end_service, mock_register_cell_hook, - # mock_register_primitive_hook): - # self.service.config.level = Const.LEVEL_MIX - # self.service.first_start = True - # self.service.should_stop_service = False - # self.service.start(self.service.model) - # mock_set_config.assert_called_with(self.service.config) - # mock_set_data_collector.assert_called_with(self.service.data_collector) - # mock_api_set_hook_func.assert_called_once() - # mock_need_end_service.assert_called_once() - # mock_register_cell_hook.assert_called_once() - # mock_register_primitive_hook.assert_called_once() - # self.assertTrue(JitDump.jit_dump_switch) - - # def test_step_updates(self): - # CellProcessor.cell_count = {"test_api": 1} - # HOOKCell.cell_count = {"test_api": 1} - # JitDump.jit_count = {"test_api": 1} - # self.service.primitive_hook_service.primitive_counters = {"test_api": 1} - # self.service.loop = 0 - # self.service.step() - # self.assertEqual(self.service.loop, 1) - # self.service.data_collector.reset_status.assert_called_once() - # self.assertEqual(JitDump.jit_count, defaultdict(int)) - # self.assertEqual((self.service.primitive_hook_service.primitive_counters), {}) - - # @patch.object(Service, 'should_execute_hook') - # def test_build_forward_and_backward_hooks(self, mock_should_execute_hook): - # mock_should_execute_hook.return_value = True - # self.service.data_collector = MagicMock() - # self.service.data_collector.update_api_or_module_name = MagicMock() - # self.service.data_collector.forward_data_collect = MagicMock() - # self.service.data_collector.if_return_forward_new_output = MagicMock(return_value=False) - # self.service.data_collector.backward_data_collect = MagicMock() - - # mock_cell = MagicMock() - # mock_cell.mindstudio_reserved_name = "TestCell" - # mock_input = (MagicMock(),) - # mock_output = MagicMock() - - # _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook") - - # forward_hook(mock_cell, mock_input, mock_output) - # self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell') - # self.service.data_collector.forward_data_collect.assert_called() - - # self.service.data_collector.reset_mock() - - # mock_grad_input = (MagicMock(),) - # mock_grad_output = MagicMock() - - # backward_hook(mock_cell, mock_grad_input, mock_grad_output) - # self.service.data_collector.update_api_or_module_name.assert_called_with('TestHookbackward.0') - # self.service.data_collector.backward_data_collect.assert_called() - - # def test_register_primitive_hook(self): - # self.service.config.level = Const.LEVEL_MIX - # primitive_attr = ops.Add() - # primitive_name = "primitive_api" - # cell_mock = MagicMock() - # cell_mock.primitive_api = primitive_attr - # primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ - # self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - # self.service.register_primitive_hook() - # self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) - # self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], - # primitive_combined_name) - - # @patch.object(ApiRegistry, 'initialize_hook') - # @patch.object(ApiRegistry, 'register_all_api') - # @patch("msprobe.mindspore.service.logger.info") - # def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook): - # self.service.config.level = Const.LEVEL_MIX - # self.service.register_api_hook() - # self.service.register_cell_hook() - # mock_logger.assert_called_with(f"The cell {self.service.config.task} hook function " - # "is successfully mounted to the model.") - # mock_api_set_hook_func.assert_called() - # mock_initialize_hook.assert_called() - - # def test_register_hook_new_without_model_raises_exception(self): - # self.service.config.level = Const.LEVEL_L0 - # self.service.model = None - # with self.assertRaises(MsprobeException): - # self.service.register_cell_hook() + def test_update_primitive_counters(self): + self.service.primitive_counters = {} + self.service.update_primitive_counters("conv2d") + self.assertEqual(self.service.primitive_counters["conv2d"], 0) + self.service.update_primitive_counters("conv2d") + self.assertEqual(self.service.primitive_counters["conv2d"], 1) + + @patch('msprobe.mindspore.service.create_directory') + def test_create_dirs(self, mock_create_directory): + self.service.current_iter = 1 + self.service.current_rank = 0 + self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] + self.service.data_collector.update_dump_paths = MagicMock() + self.service.create_dirs() + expected_calls = [ + ("/tmp/dump"), + ("/tmp/dump/step1/rank0"), + "/tmp/dump/step1/rank0/dump_tensor_data" + ] + mock_create_directory.assert_has_calls( + [unittest.mock.call(path) for path in expected_calls], any_order=True) + + args, _ = self.service.data_collector.update_dump_paths.call_args + self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") + self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") + self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") + self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") + self.service.data_collector.initialize_json_file.assert_called_once_with( + framework=Const.MS_FRAMEWORK + ) + + @patch.object(Service, 'need_end_service', return_value=False) + def test_start_stop_cycle(self, mock_need_end_service): + self.service.model = nn.Cell() + self.should_stop_service = False + self.service.start(self.service.model) + self.assertTrue(self.service.switch) + self.service.stop() + self.assertFalse(self.service.switch) + self.service.cell_processor.register_cell_hook.assert_called_once() + mock_need_end_service.assert_called_once() + + self.service.cell_processor.register_cell_hook.reset_mock() + + def test_should_execute_hook_return_false(self): + cell = MagicMock() + self.service.switch = False + self.assertFalse(self.service.should_execute_hook("Module", cell, True)) + self.assertFalse(self.service.should_execute_hook("api", cell, True)) + + self.service.switch = True + cell.forward_data_collected = False + self.assertFalse(self.service.should_execute_hook("api", cell, False)) + + self.service.inner_switch = True + self.assertFalse(self.service.should_execute_hook("Module", cell, True)) + + self.service.inner_switch = False + self.service.data_collector = None + self.assertFalse(self.service.should_execute_hook("Module", cell, True)) + + def test_should_execute_hook_return_true(self): + cell = MagicMock() + self.service.switch = True + self.service.inner_switch = False + self.service.data_collector = MagicMock() + self.service.data_collector.data_processor = MagicMock() + self.service.data_collector.data_processor.is_terminated = False + self.assertTrue(self.service.should_execute_hook("Module", cell, True)) + + cell.forward_data_collected = True + self.assertTrue(self.service.should_execute_hook("api", cell, False)) + + def test_need_end_service_with_high_step(self): + self.service.config.step = [1, 2, 3] + self.service.current_iter = 4 + self.assertTrue(self.service.need_end_service()) + + def test_need_end_service_with_low_step(self): + self.service.config.step = [1, 2, 3] + self.service.current_iter = 2 + self.service.data_collector.data_processor.is_terminated = False + self.assertFalse(self.service.need_end_service()) + + def test_start_with_termination_condition(self): + self.service.config.step = [1, 2, 3] + self.service.current_iter = 4 + self.service.start() + self.assertFalse(self.service.switch) + self.assertTrue(self.service.should_stop_service) + self.assertFalse(self.service.primitive_switch) + + @patch('msprobe.mindspore.service.print_tools_ends_info') + @patch.object(Service, 'need_end_service', return_value=True) + def test_start_with_end_service(self, mock_need_end_service, mock_print_tools_ends_info): + self.service.start(self.service.model) + mock_need_end_service.assert_called_once() + mock_print_tools_ends_info.assert_called_once() + self.assertFalse(self.service.switch) + self.assertTrue(self.service.should_stop_service) + + @patch.object(Service, 'need_end_service', return_value=False) + @patch.object(logger, 'info') + @patch.object(Service, 'register_primitive_hook') + @patch.object(Service, 'create_dirs') + @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) + def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, + mock_logger, mock_need_end_service): + self.service.first_start = True + self.service.should_stop_service = False + self.service.start(self.service.model) + mock_get_rank.assert_called_once() + self.service.cell_processor.register_cell_hook.assert_called_once() + mock_register_primitive_hook.assert_called_once() + mock_need_end_service.assert_called_once() + mock_create_dirs.assert_called_once() + self.assertFalse(self.service.first_start) + self.assertTrue(self.service.switch) + self.assertTrue(self.service.primitive_switch) + mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") + + self.service.cell_processor.register_cell_hook.reset_mock() + + @patch.object(Service, 'register_primitive_hook') + @patch.object(Service, 'need_end_service', return_value=False) + @patch.object(JitDump, 'set_config') + @patch.object(JitDump, 'set_data_collector') + def test_start_with_jit_dump_enabled(self, mock_set_data_collector, mock_set_config, + mock_need_end_service, mock_register_primitive_hook): + self.service.config.level = Const.LEVEL_MIX + self.service.first_start = True + self.service.should_stop_service = False + self.service.start(self.service.model) + mock_set_config.assert_called_with(self.service.config) + mock_set_data_collector.assert_called_with(self.service.data_collector) + self.service.api_register.register_all_api.assert_called_once() + mock_need_end_service.assert_called_once() + self.service.cell_processor.register_cell_hook.assert_called_once() + mock_register_primitive_hook.assert_called_once() + self.assertTrue(JitDump.jit_dump_switch) + + self.service.api_register.register_all_api.reset_mock() + self.service.cell_processor.register_cell_hook.reset_mock() + + def test_step_updates(self): + CellProcessor.cell_count = {"test_api": 1} + HOOKCell.cell_count = {"test_api": 1} + JitDump.jit_count = {"test_api": 1} + self.service.primitive_hook_service.primitive_counters = {"test_api": 1} + self.service.loop = 0 + self.service.step() + self.assertEqual(self.service.loop, 1) + self.service.data_collector.reset_status.assert_called_once() + self.assertEqual(JitDump.jit_count, defaultdict(int)) + self.assertEqual((self.service.primitive_hook_service.primitive_counters), {}) + + @patch.object(Service, 'should_execute_hook') + def test_build_forward_and_backward_hooks(self, mock_should_execute_hook): + mock_should_execute_hook.return_value = True + self.service.data_collector = MagicMock() + self.service.data_collector.update_api_or_module_name = MagicMock() + self.service.data_collector.forward_data_collect = MagicMock() + self.service.data_collector.if_return_forward_new_output = MagicMock(return_value=False) + self.service.data_collector.backward_data_collect = MagicMock() + + mock_cell = MagicMock() + mock_input = (MagicMock(),) + mock_output = MagicMock() + + _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook.forward.0") + + forward_hook(mock_cell, mock_input, mock_output) + self.service.data_collector.update_api_or_module_name.assert_called_with('TestHook.forward.0') + self.service.data_collector.forward_data_collect.assert_called() + + self.service.data_collector.reset_mock() + + mock_grad_input = (MagicMock(),) + mock_grad_output = MagicMock() + + backward_hook(mock_cell, mock_grad_input, mock_grad_output) + self.service.data_collector.update_api_or_module_name.assert_called_with('TestHook.backward.0') + self.service.data_collector.backward_data_collect.assert_called() + + def test_register_primitive_hook(self): + self.service.config.level = Const.LEVEL_MIX + primitive_attr = ops.Add() + primitive_name = "primitive_api" + mock_model = MagicMock() + cell_mock = MagicMock() + cell_mock.primitive_api = primitive_attr + primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ + self.service.model = mock_model + self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] + self.service.register_primitive_hook() + self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) + self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], + primitive_combined_name) + + @patch("msprobe.mindspore.service.logger.info") + def test_register_hook_new_with_level_mix(self, mock_logger): + self.service.config.level = Const.LEVEL_MIX + self.service.register_api_hook() + mock_logger.assert_called_with(f'The api {self.service.config.task} hook function ' + 'is successfully mounted to the model.') + self.service.api_register.initialize_hook.assert_called_once() + self.service.api_register.register_all_api.assert_called_once() + + self.service.api_register.initialize_hook.reset_mock() + self.service.api_register.register_all_api.reset_mock() -- Gitee From 238cf34aa1962ff84cb1531aaa9fe8c42a8b1a16 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 15:29:36 +0800 Subject: [PATCH 11/28] V1.4 --- .../test/mindspore_ut/test_ms_service.py | 55 +++++++++++++------ 1 file changed, 39 insertions(+), 16 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index b2de4bfccd..69293ff1ea 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -137,29 +137,51 @@ class TestService(unittest.TestCase): self.service.current_rank = 0 self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] self.service.data_collector.update_dump_paths = MagicMock() - self.service.create_dirs() expected_calls = [ ("/tmp/dump"), ("/tmp/dump/step1/rank0"), "/tmp/dump/step1/rank0/dump_tensor_data" ] - mock_create_directory.assert_has_calls( - [unittest.mock.call(path) for path in expected_calls], any_order=True) - - args, _ = self.service.data_collector.update_dump_paths.call_args - self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") - self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") - self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") - self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") - self.service.data_collector.initialize_json_file.assert_called_once_with( - framework=Const.MS_FRAMEWORK - ) - + with patch('msprobe.mindspore.service.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = False + self.service.create_dirs() + mock_create_directory.assert_has_calls( + [unittest.mock.call(path) for path in expected_calls], any_order=True) + + args, _ = self.service.data_collector.update_dump_paths.call_args + self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") + self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") + self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") + self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") + self.service.data_collector.initialize_json_file.assert_called_once_with( + framework=Const.MS_FRAMEWORK + ) + + mock_create_directory.reset_mock() + self.service.data_collector.update_dump_paths.reset_mock() + self.service.data_collector.initialize_json_file.reset_mock() + + self.service.create_dirs() + mock_create_directory.assert_has_calls( + [unittest.mock.call(path) for path in expected_calls], any_order=True) + + args, _ = self.service.data_collector.update_dump_paths.call_args + self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") + self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") + self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") + self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") + self.service.data_collector.initialize_json_file.assert_called_once_with( + framework=Const.MT_FRAMEWORK + ) + + @patch.object(Service, 'check_model_valid') @patch.object(Service, 'need_end_service', return_value=False) - def test_start_stop_cycle(self, mock_need_end_service): + def test_start_stop_cycle(self, mock_need_end_service, mock_check_model_valid): self.service.model = nn.Cell() + mock_check_model_valid.return_value = self.service.model self.should_stop_service = False self.service.start(self.service.model) + mock_check_model_valid.assert_called_with(self.service.model) self.assertTrue(self.service.switch) self.service.stop() self.assertFalse(self.service.switch) @@ -317,8 +339,9 @@ class TestService(unittest.TestCase): cell_mock.primitive_api = primitive_attr primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ self.service.model = mock_model - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - self.service.register_primitive_hook() + with patch('msprobe.mindspore.service.get_cells_and_names') as mock_get_cells_and_names: + mock_get_cells_and_names.return_value = {'-1': [("cell_name", cell_mock)]} + self.service.register_primitive_hook() self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], primitive_combined_name) -- Gitee From f3e4b1bbd3b97b97a6e7946648529ae3d748d7ef Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 15:37:22 +0800 Subject: [PATCH 12/28] V1.5 --- .../msprobe/test/common_set_up/test_set_up.py | 6 ++++-- .../msprobe/test/mindspore_ut/test_ms_service.py | 1 + 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 47421278a1..0aa7941048 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -18,7 +18,7 @@ from unittest.mock import MagicMock from mindspore import mint from mindtorch import reset_torch_tensor -from msprobe.mindspore.common.utils import is_mindtorch +from msprobe.mindspore.common.utils import is_mindtorch, mindtorch_check_result try: from mint import distributed @@ -26,7 +26,7 @@ except ImportError: distributed = MagicMock() setattr(mint, 'distributed', distributed) -mindtorch_check_result = is_mindtorch() +is_mindtorch() reset_torch_tensor() @@ -34,3 +34,5 @@ class SetUp(TestCase): def test_case(self): self.assertTrue(hasattr(mint, 'distributed')) self.assertTrue(mindtorch_check_result) + global mindtorch_check_result + mindtorch_check_result = None diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index 69293ff1ea..569d92c023 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -161,6 +161,7 @@ class TestService(unittest.TestCase): self.service.data_collector.update_dump_paths.reset_mock() self.service.data_collector.initialize_json_file.reset_mock() + mock_is_mindtorch.return_value = True self.service.create_dirs() mock_create_directory.assert_has_calls( [unittest.mock.call(path) for path in expected_calls], any_order=True) -- Gitee From d85125a479b9fd0381d7185caf491a7585e233b0 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 15:52:17 +0800 Subject: [PATCH 13/28] V1.5 --- .../msprobe/test/common_set_up/test_set_up.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 0aa7941048..685803cb33 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -17,8 +17,10 @@ from unittest import TestCase from unittest.mock import MagicMock from mindspore import mint -from mindtorch import reset_torch_tensor -from msprobe.mindspore.common.utils import is_mindtorch, mindtorch_check_result + +from .mindtorch import reset_torch_tensor +from msprobe.mindspore.common import utils +from msprobe.mindspore.common.utils import is_mindtorch try: from mint import distributed @@ -33,6 +35,5 @@ reset_torch_tensor() class SetUp(TestCase): def test_case(self): self.assertTrue(hasattr(mint, 'distributed')) - self.assertTrue(mindtorch_check_result) - global mindtorch_check_result - mindtorch_check_result = None + self.assertTrue(is_mindtorch()) + utils.mindtorch_check_result = None -- Gitee From 806883beff1ae1c0434ddab2894de63506c267b6 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 16:11:46 +0800 Subject: [PATCH 14/28] V1.6 --- debug/accuracy_tools/msprobe/test/common_set_up/__init__.py | 0 .../free_benchmark/test_ms_api_pynative_self_check.py | 2 ++ 2 files changed, 2 insertions(+) create mode 100644 debug/accuracy_tools/msprobe/test/common_set_up/__init__.py diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py b/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py index c4482a22f0..e07417aba8 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py @@ -99,11 +99,13 @@ class TestApiPyNativeSelfCheck(TestCase): _, forward_hook, backward_hook, _ = self.checker.build_hook("Functional.add.") cell = Cell() + cell.msprobe_input_kwargs = {} with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=False): self.assertIsNone(forward_hook(cell, "input", "output")) cell = Cell() + cell.msprobe_input_kwargs = {} self.checker.api_list = ["mindspore.ops.add"] self.checker.ori_func["mindspore.ops.add"] = "add" with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=True), \ -- Gitee From 4f6b9c91dac9af2c1bd7df4c204a11dcc1ee7e74 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 16:40:37 +0800 Subject: [PATCH 15/28] V1.6 --- .../test/mindspore_ut/test_primitive_dump.py | 57 +++++-------------- 1 file changed, 14 insertions(+), 43 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 79deeee08e..0bba6232f4 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,16 +12,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + +import tempfile import unittest -import mindspore as ms -import numpy as np -import os from unittest.mock import Mock, patch +import mindspore as ms +import numpy as np from mindspore import nn +from mindspore.common.tensor import Tensor -import tempfile from msprobe.core.common.utils import Const from msprobe.mindspore.service import Service from msprobe.core.common.exceptions import MsprobeException @@ -31,7 +30,6 @@ from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from collections import defaultdict from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from mindspore.common.tensor import Tensor class DummyModel(nn.Cell): @@ -73,7 +71,7 @@ class TestService(unittest.TestCase): def test_check_model_valid_invalid_model(self): model = "invalid_model" - with self.assertRaises(MsprobeException) as context: + with self.assertRaises(MsprobeException): self.service.check_model_valid(model) def test_update_primitive_counters(self): @@ -148,7 +146,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -163,7 +160,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"1After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_output_data_collect.called) @@ -177,7 +173,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -214,14 +209,7 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") - if wrapped_primitive_call.__closure__: - for i, closure in enumerate(wrapped_primitive_call.__closure__): - print(f"Closure[{i}]:", closure.cell_contents) - - if hook_primitive_inputs.__closure__: - for i, closure in enumerate(hook_primitive_inputs.__closure__): - print(f"2Closure[{i}]:", closure.cell_contents) + create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents backward_hook = create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type) @@ -235,7 +223,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_input_data_collect.called) @@ -282,19 +269,14 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_input" # 调用 hook_primitive_inputs - hooked_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents(args, - captured_grads_input, - updated_primitive_name) + hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) # 验证 hooked_inputs 是否正确添加了 hook for arg, hooked_arg in zip(args, hooked_inputs): if isinstance(arg, Tensor): - print(f"Captured hooked_arg after hook: {hooked_arg}") self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - # 打印调试信息 - print(f"Captured gradients after hook: {captured_grads_input}") - def test_hook_primitive_outputs(self): # 模拟前向输出 out = (Tensor(np.array([1.0, 2.0]), ms.float32), Tensor(np.array([3.0, 4.0]), ms.float32)) @@ -311,9 +293,6 @@ class TestPrimitiveHookService(unittest.TestCase): if isinstance(tensor, Tensor): self.assertTrue(hasattr(hooked_tensor, 'grad_fn')) - # 打印调试信息 - print(f"Captured gradients after output hook: {captured_grads_output}") - def test_wrapped_primitive_call_args(self): # 模拟前向输入 args = (Tensor(np.array([1.0, 2.0]), ms.float32), Tensor(np.array([3.0, 4.0]), ms.float32)) @@ -331,13 +310,11 @@ class TestPrimitiveHookService(unittest.TestCase): if isinstance(arg, Tensor): self.assertTrue(hasattr(hooked_arg, 'grad_fn')) self.assertTrue(np.array_equal(arg.asnumpy(), hooked_arg.asnumpy())) - print(f"Arg type: {type(arg)}, Hooked input type: {type(hooked_arg)}") else: self.assertEqual(arg, hooked_arg) except Exception as e: self.fail(f"wrapped_primitive_call raised an exception: {e}") - def test_update_primitive_counters_multiple(self): # 测试更新 primitive 计数器的功能,增加多个不同名称的测试 primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"] @@ -416,13 +393,11 @@ class TestPrimitiveHookService(unittest.TestCase): for captured_grads in captured_grads_sets: updated_primitive_name = "MatMul.Backward" - num_tensors = len(captured_grads) hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") backward_hook = hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) self.assertIsNotNone(backward_hook) - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward): # 模拟前向和后向钩子在同一个 primitive 中的行为 @@ -447,9 +422,6 @@ class TestPrimitiveHookService(unittest.TestCase): self.primitive_hook_service.update_primitive_counters(name) self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) - - - def test_update_primitive_counters(self): primitive_name = "MatMul" self.primitive_hook_service.update_primitive_counters(primitive_name) @@ -496,7 +468,7 @@ class TestPrimitiveHookService(unittest.TestCase): wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") # 模拟反向传播过程,调用包装的 primitive - with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect') as mock_backward_collect: + with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect'): result = wrapped_func(Mock(), input_tensor) # 验证结果是 Tensor 实例 @@ -544,7 +516,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 测试 create_backward_hook 的功能 captured_grads = [] updated_primitive_name = "MatMul.Backward" - num_tensors = 2 # 创建 backward hook backward_hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") -- Gitee From a3a50a84518ac771820a8593f82a932eca76799a Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 16:53:18 +0800 Subject: [PATCH 16/28] V1.7 --- .../test/mindspore_ut/test_primitive_dump.py | 96 ++----------------- 1 file changed, 9 insertions(+), 87 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 0bba6232f4..e9f4e5ee46 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -13,96 +13,23 @@ # See the License for the specific language governing permissions and # limitations under the License. +from collections import defaultdict import tempfile import unittest from unittest.mock import Mock, patch import mindspore as ms import numpy as np -from mindspore import nn from mindspore.common.tensor import Tensor from msprobe.core.common.utils import Const from msprobe.mindspore.service import Service -from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from collections import defaultdict from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -class DummyModel(nn.Cell): - def __init__(self): - super(DummyModel, self).__init__() - self.dense = nn.Dense(2, 2) - - def construct(self, x): - return self.dense(x) - - -class TestService(unittest.TestCase): - @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def setUp(self, _): - json_config = { - "task": "statistics", - "dump_path": "/absolute_path", - "rank": [], - "step": [0, 2], - "level": "L1" - } - - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - self.service.primitive_switch = True # Make sure the switch is on for testing - - def test_check_model_valid_none(self): - model = None - self.assertIsNone(self.service.check_model_valid(model)) - - def test_check_model_valid_valid_model(self): - model = DummyModel() - self.assertEqual(self.service.check_model_valid(model), model) - - def test_check_model_valid_invalid_model(self): - model = "invalid_model" - with self.assertRaises(MsprobeException): - self.service.check_model_valid(model) - - def test_update_primitive_counters(self): - primitive_name = "test_primitive" - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 0) - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1) - - def test_step_updates_iteration(self): - initial_iter = self.service.loop - self.service.step() - self.assertEqual(self.service.loop, initial_iter + 1) - - @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) - def test_step_resets_counters(self, _): - # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_hook_service.primitive_counters["test_primitive"] = 5 - self.service.step() - self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) - self.assertEqual(HOOKCell.cell_count, defaultdict(int)) - - def test_start_calls_update_iter(self): - # 检查是否在调用 start 时调用了 update_iter - with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: - initial_iter = self.service.loop - init_step = self.service.init_step - self.service.start() - mock_update_iter.assert_called_once_with(initial_iter + init_step) - - class TestPrimitiveHookService(unittest.TestCase): def setUp(self): # 创建一个临时目录作为 dump_path @@ -119,19 +46,14 @@ class TestPrimitiveHookService(unittest.TestCase): common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - - # 模拟一个 service_instance 和 data_collector - self.mock_service_instance = Service(config) - self.mock_service_instance.switch = True - self.mock_service_instance.data_collector = Mock() - self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"] - - # 初始化 PrimitiveHookService - self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) + + with patch('msprobe.mindspore.service.build_data_collector'), \ + patch('msprobe.mindspore.service.CellProcessor'), \ + patch('msprobe.mindspore.service.PrimitiveHookService'), \ + patch('msprobe.mindspore.service.get_api_register'): + self.mock_service_instance = Service(config) + self.mock_service_instance.switch = True + self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) def tearDown(self): # 测试结束时删除临时目录 -- Gitee From 382160ddb0aa96cbb95546b008dddacd1dfa17ef Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 17:39:12 +0800 Subject: [PATCH 17/28] V1.8 --- .../test/mindspore_ut/test_primitive_dump.py | 58 +++++++++++-------- 1 file changed, 33 insertions(+), 25 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index e9f4e5ee46..ee5fdd95c8 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -16,11 +16,11 @@ from collections import defaultdict import tempfile import unittest -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock -import mindspore as ms import numpy as np -from mindspore.common.tensor import Tensor +import mindspore as ms +from mindspore import Tensor, ops from msprobe.core.common.utils import Const from msprobe.mindspore.service import Service @@ -192,12 +192,15 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 hook_primitive_inputs hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) - - # 验证 hooked_inputs 是否正确添加了 hook - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) + with patch.object(ops, 'HookBackward') as mock_HookBackward: + mock_hbw = MagicMock() + target_value = Tensor([0], dytpe=ms.float32) + mock_hbw.return_value = target_value + mock_HookBackward.return_value = mock_hbw + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) def test_hook_primitive_outputs(self): # 模拟前向输出 @@ -206,14 +209,17 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_output" # 调用 hook_primitive_outputs - hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[ - 1].cell_contents - hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) - - # 验证 hooked_outputs 是否正确添加了 hook - for tensor, hooked_tensor in zip(out, hooked_outputs): - if isinstance(tensor, Tensor): - self.assertTrue(hasattr(hooked_tensor, 'grad_fn')) + hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, + "example").__closure__[1].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + mock_hbw = MagicMock() + target_value = Tensor([0], dytpe=ms.float32) + mock_hbw.return_value = target_value + mock_HookBackward.return_value = mock_hbw + hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(out)) + for hooked_output in hooked_outputs: + self.assertTrue((hooked_output == target_value).all()) def test_wrapped_primitive_call_args(self): # 模拟前向输入 @@ -226,14 +232,16 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrapped_primitive_call 并检查 hooked_inputs 是否与原始 args 相同 try: - hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, - updated_primitive_name) - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - self.assertTrue(np.array_equal(arg.asnumpy(), hooked_arg.asnumpy())) - else: - self.assertEqual(arg, hooked_arg) + with patch.object(ops, 'HookBackward') as mock_HookBackward: + mock_hbw = MagicMock() + target_value = Tensor([0], dytpe=ms.float32) + mock_hbw.return_value = target_value + mock_HookBackward.return_value = mock_hbw + hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, + updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) except Exception as e: self.fail(f"wrapped_primitive_call raised an exception: {e}") -- Gitee From f3b2bb480a1f47f64ed48ff05020c1f9cbc72752 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 18:53:28 +0800 Subject: [PATCH 18/28] V1.8 --- .../test/mindspore_ut/test_primitive_dump.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index ee5fdd95c8..734be70a58 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -16,7 +16,7 @@ from collections import defaultdict import tempfile import unittest -from unittest.mock import Mock, patch, MagicMock +from unittest.mock import Mock, patch import numpy as np import mindspore as ms @@ -193,10 +193,9 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 hook_primitive_inputs hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents with patch.object(ops, 'HookBackward') as mock_HookBackward: - mock_hbw = MagicMock() - target_value = Tensor([0], dytpe=ms.float32) + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value mock_hbw.return_value = target_value - mock_HookBackward.return_value = mock_hbw hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) self.assertEqual(mock_HookBackward.call_count, len(args)) for hooked_input in hooked_inputs: @@ -212,10 +211,9 @@ class TestPrimitiveHookService(unittest.TestCase): hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[1].cell_contents with patch.object(ops, 'HookBackward') as mock_HookBackward: - mock_hbw = MagicMock() - target_value = Tensor([0], dytpe=ms.float32) + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value mock_hbw.return_value = target_value - mock_HookBackward.return_value = mock_hbw hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) self.assertEqual(mock_HookBackward.call_count, len(out)) for hooked_output in hooked_outputs: @@ -233,10 +231,9 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrapped_primitive_call 并检查 hooked_inputs 是否与原始 args 相同 try: with patch.object(ops, 'HookBackward') as mock_HookBackward: - mock_hbw = MagicMock() - target_value = Tensor([0], dytpe=ms.float32) + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value mock_hbw.return_value = target_value - mock_HookBackward.return_value = mock_hbw hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, updated_primitive_name) self.assertEqual(mock_HookBackward.call_count, len(args)) -- Gitee From 8f84280bba77abf1945e5c6bf6f9980293d3db0a Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 19:36:15 +0800 Subject: [PATCH 19/28] V1.9 --- .../test/pytorch_ut/dump/test_module_dump.py | 93 +++++++++++-------- 1 file changed, 56 insertions(+), 37 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py index 5aaf0820a7..71af09fbee 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py @@ -16,47 +16,66 @@ import unittest from unittest.mock import patch, MagicMock -import torch -import torch.nn as nn - -from msprobe.core.data_dump.api_registry import ApiRegistry -from msprobe.pytorch import PrecisionDebugger -from msprobe.pytorch.hook_module.api_register import get_api_register -from msprobe.pytorch.service import torch_version_above_or_equal_2 +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser class TestModuleDumper(unittest.TestCase): - @classmethod - def setUpClass(cls): - PrecisionDebugger._instance = None - get_api_register().restore_all_api() + def setUp(self): + self.service = MagicMock() + with patch('msprobe.pytorch.hook_module.api_register.get_api_register'): + self.module_dumper = ModuleDumper(self.service) - @classmethod - def tearDownClass(cls): - PrecisionDebugger._instance = None - get_api_register().restore_all_api() + def test__init__(self): + self.service = MagicMock() + with patch('msprobe.pytorch.hook_module.api_register.get_api_register') as mock_get_api_register: + self.module_dumper = ModuleDumper(self.service) + self.assertEqual(self.module_dumper.service, self.service) + mock_get_api_register.assert_called_once() - def setUp(self): - self.module = nn.Linear(8, 4) - debugger = PrecisionDebugger(dump_path="./") - self.module_dumper = debugger.module_dumper + def test_start_module_dump(self): + module = MagicMock() + with patch.object(logger, 'info_on_rank_0') as mock_info: + module.msprobe_hook = True + ModuleProcesser.enable_module_dump = False + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_called_with('The init dump is enabled, and the module dump function will not be available.') + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_not_called() + self.assertFalse(hasattr(module, 'msprobe_module_dump')) + + del module.msprobe_hook + mock_info.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_called_with( + module, + self.module_dumper.service.build_hook, + recursive=False, + module_names=['dump_name'] + ) + self.assertTrue(module.msprobe_module_dump) + ModuleProcesser.enable_module_dump = False + + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.service.module_processor.register_module_hook.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_not_called() + + ModuleProcesser.enable_module_dump = False def test_stop_module_dump(self): - self.module_dumper.hook_handle_list.extend([1, 2, 3]) - with patch.object(ApiRegistry, 'register_all_api') as mock_api_register: - mock_handle1 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - mock_handle2 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - self.module_dumper.hook_handle_list.extend([mock_handle1, mock_handle2]) - - self.module_dumper.stop_module_dump() - mock_handle1.remove.assert_called_once() - mock_handle2.remove.assert_called_once() - self.assertEqual(self.module_dumper.hook_handle_list, []) - mock_api_register.assert_called_once() - - def test_register_hook(self): - self.module_dumper.register_hook(self.module, "TestModule") - if torch_version_above_or_equal_2: - self.assertEqual(len(self.module_dumper.hook_handle_list), 6) - else: - self.assertEqual(len(self.module_dumper.hook_handle_list), 5) + ModuleProcesser.enable_module_dump = True + self.module_dumper.api_register.register_all_api.reset_mock() + self.module_dumper.stop_module_dump() + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.register_all_api.assert_called_once() + + self.module_dumper.api_register.register_all_api.reset_mock() -- Gitee From aedea8059df1fe94d8e21ada29f1a20b0af6395a Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 19:39:26 +0800 Subject: [PATCH 20/28] V1.9 --- .../msprobe/test/pytorch_ut/dump/test_module_dump.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py index 71af09fbee..2004eeca17 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py @@ -24,12 +24,12 @@ from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser class TestModuleDumper(unittest.TestCase): def setUp(self): self.service = MagicMock() - with patch('msprobe.pytorch.hook_module.api_register.get_api_register'): + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register'): self.module_dumper = ModuleDumper(self.service) def test__init__(self): self.service = MagicMock() - with patch('msprobe.pytorch.hook_module.api_register.get_api_register') as mock_get_api_register: + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register') as mock_get_api_register: self.module_dumper = ModuleDumper(self.service) self.assertEqual(self.module_dumper.service, self.service) mock_get_api_register.assert_called_once() -- Gitee From 65183c3c0d62b21738f8b05ebe666b60e09fda3c Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 19:45:06 +0800 Subject: [PATCH 21/28] V1.9 --- .../msprobe/test/pytorch_ut/dump/test_module_dump.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py index 2004eeca17..4ba3556c27 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py @@ -16,6 +16,8 @@ import unittest from unittest.mock import patch, MagicMock +from torch import nn + from msprobe.pytorch.common.log import logger from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser @@ -35,7 +37,7 @@ class TestModuleDumper(unittest.TestCase): mock_get_api_register.assert_called_once() def test_start_module_dump(self): - module = MagicMock() + module = nn.Module() with patch.object(logger, 'info_on_rank_0') as mock_info: module.msprobe_hook = True ModuleProcesser.enable_module_dump = False -- Gitee From acf53862d5a01ff8a71290bc3966078ad691784f Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 19:55:43 +0800 Subject: [PATCH 22/28] V1.10 --- .../pytorch_ut/dump/test_module_processer.py | 55 ++----------------- 1 file changed, 4 insertions(+), 51 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py index 20cfdfa6ba..a779fa2a3f 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py @@ -4,7 +4,6 @@ from unittest.mock import MagicMock import torch from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.pytorch.common.utils import Const from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser @@ -25,58 +24,12 @@ class TestModuleProcesser(unittest.TestCase): processor = ModuleProcesser(scope) self.assertIsNone(processor.scope) - def test_module_count_func(self): + def test_set_and_get_calls_number(self): + ModuleProcesser.reset_module_stats() test = ModuleProcesser(None) self.assertEqual(test.module_count, {}) module_name = "nope" - test.module_count_func(module_name) + test.set_and_get_calls_number(module_name) self.assertEqual(test.module_count["nope"], 0) - def test_node_hook_forward_start(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - hook(module, input) - expected_name = f"forward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, expected_name) - - def test_node_hook_forward_stop(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.STOP) - ModuleProcesser.module_stack.append(f"forward_layer{Const.SEP}0") - - module = MagicMock() - input = (self.mock_tensor,) - reserved_name = f"forward_layer{Const.SEP}0" - module.mindstudio_reserved_name = [reserved_name] - hook(module, input) - self.assertNotIn([f"forward_layer{Const.SEP}0"], ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, reserved_name) - - def test_node_hook_backward(self): - name_prefix = "backward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - ModuleProcesser.module_node[f"forward_layer{Const.SEP}0"] = None - hook(module, input) - expected_name = f"backward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_node) - - def test_has_register_backward_hook(self): - module = MagicMock() - module._backward_hooks = {0: lambda: None} - module._is_full_backward_hook = False - result = self.processor.has_register_backward_hook(module) - self.assertTrue(result) - - module._is_full_backward_hook = True - result = self.processor.has_register_backward_hook(module) - self.assertFalse(result) + ModuleProcesser.reset_module_stats() -- Gitee From 548a2897c787e80b65fe14ba5352ff844b692db5 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 20:00:29 +0800 Subject: [PATCH 23/28] V1.10 --- .../pytorch_ut/dump/test_module_processer.py | 15 +++++++++++ .../hook_module/test_hook_module.py | 21 +++++++++++++-- .../pytorch_ut/hook_module/test_wrap_aten.py | 27 ++++++++++++++++--- 3 files changed, 57 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py index a779fa2a3f..832f63f8fd 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py index 1524a82ae1..d907b81af9 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py @@ -1,12 +1,29 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock, patch import threading + from msprobe.pytorch.hook_module.hook_module import HOOKModule + class TestHOOKModuleInit(unittest.TestCase): def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) + self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock())) def test_thread_handling(self): module = HOOKModule(self.mock_build_hook) @@ -16,7 +33,7 @@ class TestHOOKModuleInit(unittest.TestCase): class TestHOOKModuleCall(unittest.TestCase): def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) + self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock())) self.module = HOOKModule(self.mock_build_hook) @patch.object(HOOKModule, '_call_func') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py index af669cb5c7..e565c1cc08 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py @@ -1,15 +1,34 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock, patch import torch from msprobe.pytorch.function_factory import npu_custom_grad_functions -from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, white_aten_ops, \ +from msprobe.pytorch.hook_module.wrap_aten import ( + AtenOPTemplate, + white_aten_ops, AtenOPPacketTemplate +) def mock_build_hook(prefix): - return (MagicMock(), MagicMock(), MagicMock(), MagicMock()) + return (MagicMock(), MagicMock(), MagicMock()) + class TestAtenOPTemplate(unittest.TestCase): @@ -79,8 +98,8 @@ class TestAtenOPPacketTemplate(unittest.TestCase): del self.mock_op_packet.nonexistent_attr with self.assertRaises(AttributeError) as context: _ = self.template.nonexistent_attr - self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", \ - str(context.exception)) + self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", + str(context.exception)) @patch('msprobe.pytorch.hook_module.wrap_aten.AtenOPTemplate', autospec=True) def test_getattr_op_overload(self, MockAtenOPTemplate): -- Gitee From 146b08326cf037484dce3a3007a440315f2ba505 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 20:22:37 +0800 Subject: [PATCH 24/28] V1.10 --- .../msprobe/pytorch/dump/module_dump/module_processer.py | 9 +++++---- debug/accuracy_tools/msprobe/pytorch/service.py | 6 ++++-- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index dc3f9c1ff2..ae68fa9c60 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -144,11 +144,12 @@ class ModuleProcesser: def build_module_hook(self, module_name, build_data_hook): def forward_pre_hook(module, args, kwargs=None): - if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: - return - if kwargs is None: kwargs = {} + + if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: + return (args, kwargs) if torch_version_above_or_equal_2 else args + index = ModuleProcesser.set_and_get_calls_number(module_name) full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}' @@ -182,7 +183,7 @@ class ModuleProcesser: def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: - return + return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output index = ModuleProcesser.module_count.get(module_name) full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index ee482fc54b..cf9b9bc930 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -65,9 +65,11 @@ class Service: self.init_for_debug_level() def build_hook(self, module_type, name): - def pre_hook(api_or_module_name, module, args, kwargs={}): + def pre_hook(api_or_module_name, module, args, kwargs=None): + kwargs = {} if kwargs is None else kwargs + if module_type == BaseScope.Module_Type_Module or \ - not self.should_execute_hook(module_type, module, True): + not self.should_execute_hook(module_type, module, True): return is_recompute = is_recomputation() -- Gitee From e88cceff89cb7d8c3352ac1ddc6a90f138b90a36 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 20:36:43 +0800 Subject: [PATCH 25/28] V1.11 --- .../msprobe/test/common_set_up/test_set_up.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 685803cb33..150ce2dfce 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -18,16 +18,16 @@ from unittest.mock import MagicMock from mindspore import mint -from .mindtorch import reset_torch_tensor -from msprobe.mindspore.common import utils -from msprobe.mindspore.common.utils import is_mindtorch - try: from mint import distributed except ImportError: distributed = MagicMock() setattr(mint, 'distributed', distributed) +from .mindtorch import reset_torch_tensor +from msprobe.mindspore.common import utils +from msprobe.mindspore.common.utils import is_mindtorch + is_mindtorch() reset_torch_tensor() -- Gitee From 66c23ad221e8f7dbe69627af85240fa8361aff36 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 20:51:51 +0800 Subject: [PATCH 26/28] V1.11 --- debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 150ce2dfce..a629123064 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -24,6 +24,9 @@ except ImportError: distributed = MagicMock() setattr(mint, 'distributed', distributed) +# ensure not to import torch_npu +from msprobe.mindspore.dump.hook_cell.api_register import get_api_register + from .mindtorch import reset_torch_tensor from msprobe.mindspore.common import utils from msprobe.mindspore.common.utils import is_mindtorch -- Gitee From 4a81404a4324bc621027b1f5a2d4ca004d9459d2 Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Mon, 21 Apr 2025 22:51:17 +0800 Subject: [PATCH 27/28] V1.11 --- .../msprobe/test/common_set_up/test_set_up.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index a629123064..96a0072792 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib from unittest import TestCase from unittest.mock import MagicMock @@ -25,13 +26,14 @@ except ImportError: setattr(mint, 'distributed', distributed) # ensure not to import torch_npu -from msprobe.mindspore.dump.hook_cell.api_register import get_api_register +from msprobe.mindspore import service from .mindtorch import reset_torch_tensor from msprobe.mindspore.common import utils from msprobe.mindspore.common.utils import is_mindtorch -is_mindtorch() +utils.mindtorch_check_result = None +importlib.reload(service) reset_torch_tensor() -- Gitee From e1644660add0a99fe1c04cc2dfa62f3d5deec7af Mon Sep 17 00:00:00 2001 From: jiangchao_j Date: Tue, 22 Apr 2025 16:21:52 +0800 Subject: [PATCH 28/28] V1.12 --- debug/accuracy_tools/msprobe/mindspore/common/const.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index a59451bbff..afc4a02f05 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -71,7 +71,7 @@ class Const: } NonDifferentiableType = ( - mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte, + mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte, mstype.int16, mstype.short, mstype.uint16, mstype.ushort, mstype.int32, mstype.intc, mstype.uint32, mstype.uintc, mstype.int64, mstype.intp, mstype.uint64, mstype.uintp -- Gitee