diff --git a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py index b1c510d9448c08e4b6865f5491337b15882c19be..5a384e300c3e4841a392399674085b07e4401eee 100644 --- a/debug/accuracy_tools/msprobe/mindspore/cell_processor.py +++ b/debug/accuracy_tools/msprobe/mindspore/cell_processor.py @@ -180,8 +180,9 @@ class CellProcessor: result = (result,) if len(result) != len(outputs): raise TypeError( - "The backward pre hook return value size is {} not equal to output size {}".format( - len(result), len(outputs))) + f"The backward pre hook return value size is {len(result)} " + f"not equal to output size {len(outputs)}" + ) return result return forward_pre_hook, forward_hook diff --git a/debug/accuracy_tools/msprobe/mindspore/common/const.py b/debug/accuracy_tools/msprobe/mindspore/common/const.py index a59451bbff2b1cacbd5a248a7aa7834dc07606bd..afc4a02f0569d687b444777d087139a1e7eb49ee 100644 --- a/debug/accuracy_tools/msprobe/mindspore/common/const.py +++ b/debug/accuracy_tools/msprobe/mindspore/common/const.py @@ -71,7 +71,7 @@ class Const: } NonDifferentiableType = ( - mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte, + mstype.bool_, mstype.int8, mstype.byte, mstype.uint8, mstype.ubyte, mstype.int16, mstype.short, mstype.uint16, mstype.ushort, mstype.int32, mstype.intc, mstype.uint32, mstype.uintc, mstype.int64, mstype.intp, mstype.uint64, mstype.uintp diff --git a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py index dc3f9c1ff2cdd31977b412f11ee0a15987d6ec31..ae68fa9c60260561b81be8a050d90e9e63f912b5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py +++ b/debug/accuracy_tools/msprobe/pytorch/dump/module_dump/module_processer.py @@ -144,11 +144,12 @@ class ModuleProcesser: def build_module_hook(self, module_name, build_data_hook): def forward_pre_hook(module, args, kwargs=None): - if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: - return - if kwargs is None: kwargs = {} + + if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: + return (args, kwargs) if torch_version_above_or_equal_2 else args + index = ModuleProcesser.set_and_get_calls_number(module_name) full_forward_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' full_backward_name = f'{module_name}{Const.BACKWARD}{Const.SEP}{index}' @@ -182,7 +183,7 @@ class ModuleProcesser: def forward_hook(module, args, kwargs_or_output, output_or_kwargs=None): if hasattr(module, 'msprobe_module_dump') and not self.enable_module_dump: - return + return output_or_kwargs if torch_version_above_or_equal_2 else kwargs_or_output index = ModuleProcesser.module_count.get(module_name) full_name = f'{module_name}{Const.FORWARD}{Const.SEP}{index}' diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index ee482fc54bb79a63d79b98f0d116258a42fccf78..cf9b9bc9303552bb9acecddd3a685aff8d4e6d88 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -65,9 +65,11 @@ class Service: self.init_for_debug_level() def build_hook(self, module_type, name): - def pre_hook(api_or_module_name, module, args, kwargs={}): + def pre_hook(api_or_module_name, module, args, kwargs=None): + kwargs = {} if kwargs is None else kwargs + if module_type == BaseScope.Module_Type_Module or \ - not self.should_execute_hook(module_type, module, True): + not self.should_execute_hook(module_type, module, True): return is_recompute = is_recomputation() diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py b/debug/accuracy_tools/msprobe/test/common_set_up/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py new file mode 100644 index 0000000000000000000000000000000000000000..cada41ba2ba5014696714136801d86823978dd6d --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/common_set_up/mindtorch.py @@ -0,0 +1,14 @@ +from mindspore import Tensor +import torch + + +def create_msa_tensor(data, dtype=None): + return Tensor(data, dtype) + + +tensor_tensor = torch.tensor +setattr(torch, 'tensor', create_msa_tensor) + + +def reset_torch_tensor(): + setattr(torch, 'tensor', tensor_tensor) diff --git a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py index 6442908bb0e6dd573c101c4314388aabef4ed5c4..96a0072792ad933b981086925b55145a49f355bc 100644 --- a/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py +++ b/debug/accuracy_tools/msprobe/test/common_set_up/test_set_up.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import importlib from unittest import TestCase from unittest.mock import MagicMock @@ -24,7 +25,20 @@ except ImportError: distributed = MagicMock() setattr(mint, 'distributed', distributed) +# ensure not to import torch_npu +from msprobe.mindspore import service + +from .mindtorch import reset_torch_tensor +from msprobe.mindspore.common import utils +from msprobe.mindspore.common.utils import is_mindtorch + +utils.mindtorch_check_result = None +importlib.reload(service) +reset_torch_tensor() + class SetUp(TestCase): def test_case(self): self.assertTrue(hasattr(mint, 'distributed')) + self.assertTrue(is_mindtorch()) + utils.mindtorch_check_result = None diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py index c4482a22f042723f12431b56c730a8d69957b63c..e07417aba8c745833a3f551a9e0489d848a58bb1 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/free_benchmark/test_ms_api_pynative_self_check.py @@ -99,11 +99,13 @@ class TestApiPyNativeSelfCheck(TestCase): _, forward_hook, backward_hook, _ = self.checker.build_hook("Functional.add.") cell = Cell() + cell.msprobe_input_kwargs = {} with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=False): self.assertIsNone(forward_hook(cell, "input", "output")) cell = Cell() + cell.msprobe_input_kwargs = {} self.checker.api_list = ["mindspore.ops.add"] self.checker.ori_func["mindspore.ops.add"] = "add" with patch("msprobe.mindspore.free_benchmark.api_pynative_self_check.need_wrapper_func", return_value=True), \ diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py index fbae6599d776c7d5a440df769ba89550372fd2a9..eef22ebe8f7c5c8b92a4a0077410f6b2c320fac3 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_cell_processor.py @@ -182,7 +182,7 @@ class TestCellProcessor(unittest.TestCase): target_args = (Tensor([0.0]),) full_forward_name = f'{cell_name}{Const.FORWARD}.0' full_backward_name = f'{cell_name}{Const.BACKWARD}.0' - + # call testing function - forward_pre_hook ret = forward_pre_hook(mock_cell, args) self.assertIsNone(CellProcessor.module_node[full_forward_name]) self.assertEqual(CellProcessor.cell_stack, [full_forward_name]) @@ -193,11 +193,13 @@ class TestCellProcessor(unittest.TestCase): mock_CellBackwardHook.assert_called_with(full_backward_name, mock_cell, CellProcessor.cell_backward_hook[-1]) mock_bw.register_backward_hook.assert_called_once() + mock_bw.assert_called_with(*args) self.assertTrue((ret[0] == target_args[0]).all()) backward_hook = CellProcessor.cell_backward_hook[-1][full_backward_name] grad_input = (Tensor([1.0]),) grad_output = (Tensor([2.0]),) + # call testing function - backward_hook ret = backward_hook(mock_cell, grad_input, grad_output) mock_backward_data_hook.assert_called_with(mock_cell, grad_input, grad_output) self.assertFalse(mock_cell.has_pre_hook_called) @@ -209,6 +211,7 @@ class TestCellProcessor(unittest.TestCase): mock_build_data_hook.reset_mock() args = (Tensor([1], dtype=ms.int32),) full_forward_name = f'{cell_name}{Const.FORWARD}.1' + # call testing function - forward_pre_hook ret = forward_pre_hook(mock_cell, args) self.assertIsNone(CellProcessor.module_node[full_forward_name]) self.assertEqual(CellProcessor.cell_stack, [full_forward_name]) @@ -222,18 +225,92 @@ class TestCellProcessor(unittest.TestCase): CellProcessor.cell_stack = [full_forward_name] CellProcessor.api_parent_node = full_forward_name CellProcessor.module_node = {full_forward_name: None} + self.scope.reset_mock() mock_CellBackwardHook.reset_mock() mock_bw.reset_mock() - mock_backward_data_hook.reset_mock() - mock_forward_data_hook_hook = MagicMock() target_output = Tensor([0.5]) - mock_forward_data_hook_hook.return_value = target_output - mock_build_data_hook.return_value = (None, mock_forward_data_hook_hook, mock_backward_data_hook, None) args = (Tensor([1.0]),) output = Tensor([2.0]) + mock_bw.return_value = target_output + mock_backward_data_hook.reset_mock() + mock_forward_data_hook_hook = MagicMock() + mock_forward_data_hook_hook.return_value = output + mock_build_data_hook.return_value = (None, mock_forward_data_hook_hook, mock_backward_data_hook, None) + # call testing function - forward_hook ret = forward_hook(mock_cell, args, output) + self.assertEqual(CellProcessor.cell_count.get(cell_name), 0) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.scope.end_module.assert_called_with(full_forward_name) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], output) + self.assertEqual(mock_bw.call_args_list[1][0][0], target_output) + self.assertEqual(mock_CellBackwardHook.call_count, 1) + self.assertEqual(len(CellProcessor.cell_backward_pre_hook), 1) self.assertTrue((ret == target_output).all()) + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + mock_backward_data_hook.reset_mock() + grad_output = (Tensor([2.0]),) + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertTrue(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + self.assertEqual(CellProcessor.cell_stack, [full_backward_name]) + self.assertEqual(CellProcessor.api_parent_node, full_backward_name) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_not_called() + self.assertIsNone(ret) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + mock_bw.reset_mock() + args = (Tensor([1.0]),) + output = (Tensor([2.0]),) + mock_forward_data_hook_hook.return_value = output + target_output = (Tensor([0.5]),) + # call testing function - forward_hook + ret = forward_hook(mock_cell, args, output) + self.assertEqual(mock_bw.call_count, 2) + self.assertEqual(mock_bw.call_args_list[0][0][0], *output) + self.assertEqual(mock_bw.call_args_list[1][0][0], mock_bw.return_value) + self.assertTrue((ret[0] == target_output[0]).all()) + + CellProcessor.cell_count = {cell_name: 0} + CellProcessor.cell_stack = [full_forward_name] + CellProcessor.api_parent_node = full_forward_name + CellProcessor.module_node = {full_forward_name: None} + CellProcessor.cell_bw_hook_kernels.clear() + CellProcessor.cell_backward_pre_hook.clear() + mock_bw.reset_mock() + mock_bw.return_value = (Tensor([0.5]),) + output = (Tensor([1.0]), Tensor([2.0])) + mock_forward_data_hook_hook.return_value = output + with self.assertRaises(TypeError) as context: + # call testing function - forward_hook + forward_hook(mock_cell, args, output) + self.assertEqual(str(context.exception), + 'The backward pre hook return value size is 1 not equal to output size 2') + mock_bw.assert_called_with(*output) + + self.scope.reset_mock() + backward_pre_hook = CellProcessor.cell_backward_pre_hook[-1][full_backward_name] + # call testing function - backward_pre_hook + ret = backward_pre_hook(mock_cell, grad_output) + self.assertFalse(mock_cell.has_pre_hook_called) + self.scope.begin_module.assert_called_with(full_backward_name) + mock_backward_data_hook.assert_called_with(mock_cell, (), grad_output) + self.assertEqual(CellProcessor.cell_stack, []) + self.assertIsNone(CellProcessor.api_parent_node) + self.assertEqual(CellProcessor.module_node, {full_forward_name: None, full_backward_name: None}) + self.scope.end_module.assert_called_with(full_backward_name) + self.assertIsNone(ret) + + CellProcessor.reset_cell_stats() + def test_set_construct_info_in_pre_hook(self): CellProcessor.reset_cell_stats() self.processor.set_construct_info_in_pre_hook('full_name') diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py index f12fb2b0fbc16b2be86a5e9b4dc9d967be68788b..569d92c0234737bb0bcff1685fe7a9fdb1db1519 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_ms_service.py @@ -19,15 +19,13 @@ from collections import defaultdict from unittest.mock import MagicMock, patch from mindspore import nn, ops +import torch from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common.utils import Const -from msprobe.core.data_dump.api_registry import ApiRegistry from msprobe.core.data_dump.scope import BaseScope from msprobe.mindspore.cell_processor import CellProcessor from msprobe.mindspore.common.log import logger -from msprobe.mindspore.common.utils import register_backward_hook_functions -from msprobe.mindspore.dump.hook_cell.api_register import get_api_register from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.dump.jit_dump import JitDump from msprobe.mindspore.service import Service @@ -41,36 +39,90 @@ class TestService(unittest.TestCase): self.config_mock.step = [] self.config_mock.rank = [] self.config_mock.task = Const.TENSOR - self.config_mock.framework = Const.MS_FRAMEWORK self.config_mock.list = [] self.config_mock.scope = [] - self.service = Service(self.config_mock) - self.service.model = MagicMock(spec=nn.Cell) - self.service.data_collector = MagicMock() - self.service.primitive_hook_service = MagicMock() - - def tearDown(self) -> None: - get_api_register().restore_all_api() + with patch('msprobe.mindspore.service.build_data_collector'), \ + patch('msprobe.mindspore.service.CellProcessor'), \ + patch('msprobe.mindspore.service.PrimitiveHookService'), \ + patch('msprobe.mindspore.service.get_api_register'): + self.service = Service(self.config_mock) def test_init(self): - self.assertEqual(self.service.config.level, "L0") - self.assertFalse(self.service.switch) - self.assertFalse(self.service.should_stop_service) - self.assertFalse(self.service.start_call) - self.assertTrue(self.service.first_start) - - def test_check_model_valid_with_valid_cell(self): - model = nn.Cell() - model_list = [model] - self.assertEqual(self.service.check_model_valid(model), model) - self.assertEqual(self.service.check_model_valid(model_list), model_list) - - def test_check_model_valid_with_invalid_type(self): - model = nn.Cell() - with self.assertRaises(MsprobeException): - self.service.check_model_valid("not a cell") - with self.assertRaises(MsprobeException): - self.service.check_model_valid(["not a cell", model]) + with patch('msprobe.mindspore.service.build_data_collector') as mock_build_data_collector, \ + patch('msprobe.mindspore.service.CellProcessor') as mock_CellProcessor, \ + patch('msprobe.mindspore.service.PrimitiveHookService') as mock_PrimitiveHookService, \ + patch('msprobe.mindspore.service.get_api_register') as mock_get_api_register, \ + patch.object(Service, 'register_api_hook') as mock_register_api_hook, \ + patch.object(Service, 'init_for_debug_level') as mock_init_for_debug_level: + self.service = Service(self.config_mock) + self.assertIsNone(self.service.model) + self.assertEqual(self.service.config.level_ori, Const.LEVEL_L0) + self.assertEqual(self.service.config.dump_path, '/tmp/dump') + self.assertEqual(self.service.config.step, []) + self.assertEqual(self.service.config.rank, []) + self.assertEqual(self.service.config.task, Const.TENSOR) + self.assertEqual(self.service.config.list, []) + self.assertEqual(self.service.config.scope, []) + self.assertEqual(self.service.config.level, Const.LEVEL_L0) + mock_build_data_collector.assert_called_with(self.service.config) + mock_CellProcessor.assert_called_with(mock_build_data_collector.return_value.scope) + mock_PrimitiveHookService.assert_called_with(self.service) + self.assertFalse(self.service.switch) + self.assertFalse(self.service.inner_switch) + self.assertFalse(self.service.primitive_switch) + self.assertEqual(self.service.current_iter, 0) + self.assertEqual(self.service.loop, 0) + self.assertEqual(self.service.init_step, 0) + self.assertTrue(self.service.first_start) + self.assertIsNone(self.service.current_rank) + self.assertIsNone(self.service.dump_iter_dir) + self.assertFalse(self.service.start_call) + self.assertFalse(self.service.should_stop_service) + self.assertEqual(self.service.params_grad_info, {}) + self.assertEqual(self.service.hook_handle_dict, {}) + mock_get_api_register.assert_called_with() + mock_register_api_hook.assert_called_with() + mock_init_for_debug_level.assert_called_with() + + def test_check_model_valid(self): + with patch('msprobe.mindspore.service.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = False + model = None + self.assertIsNone(self.service.check_model_valid(model)) + model = 'model' + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(model) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertIn("The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " + "currently there is a type.", str(context.exception)) + model = nn.Cell() + self.assertEqual(self.service.check_model_valid(model), model) + models = [model] + self.assertEqual(self.service.check_model_valid(models), models) + models = [model, 'model'] + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(models) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertIn("The 'model' parameter must be a mindspore.nn.Cell or list[mindspore.nn.Cell] type, " + "currently there is a type.", str(context.exception)) + + mock_is_mindtorch.return_value = True + model = 'model' + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(model) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " + "currently there is a type.", str(context.exception)) + model = torch.nn.Module() + self.assertEqual(self.service.check_model_valid(model), model) + models = [model] + self.assertEqual(self.service.check_model_valid(models), models) + models = [model, 'model'] + with self.assertRaises(MsprobeException) as context: + self.service.check_model_valid(models) + self.assertEqual(context.exception.code, MsprobeException.INVALID_PARAM_ERROR) + self.assertIn("The 'model' parameter must be a torch.nn.Module or list[torch.nn.Module] type, " + "currently there is a type.", str(context.exception)) def test_update_primitive_counters(self): self.service.primitive_counters = {} @@ -85,35 +137,59 @@ class TestService(unittest.TestCase): self.service.current_rank = 0 self.service.data_collector.tasks_need_tensor_data = [Const.TENSOR] self.service.data_collector.update_dump_paths = MagicMock() - self.service.create_dirs() expected_calls = [ ("/tmp/dump"), ("/tmp/dump/step1/rank0"), "/tmp/dump/step1/rank0/dump_tensor_data" ] - mock_create_directory.assert_has_calls( - [unittest.mock.call(path) for path in expected_calls], any_order=True) - - args, _ = self.service.data_collector.update_dump_paths.call_args - self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") - self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") - self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") - self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") - self.service.data_collector.initialize_json_file.assert_called_once_with( - framework=Const.MS_FRAMEWORK - ) - + with patch('msprobe.mindspore.service.is_mindtorch') as mock_is_mindtorch: + mock_is_mindtorch.return_value = False + self.service.create_dirs() + mock_create_directory.assert_has_calls( + [unittest.mock.call(path) for path in expected_calls], any_order=True) + + args, _ = self.service.data_collector.update_dump_paths.call_args + self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") + self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") + self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") + self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") + self.service.data_collector.initialize_json_file.assert_called_once_with( + framework=Const.MS_FRAMEWORK + ) + + mock_create_directory.reset_mock() + self.service.data_collector.update_dump_paths.reset_mock() + self.service.data_collector.initialize_json_file.reset_mock() + + mock_is_mindtorch.return_value = True + self.service.create_dirs() + mock_create_directory.assert_has_calls( + [unittest.mock.call(path) for path in expected_calls], any_order=True) + + args, _ = self.service.data_collector.update_dump_paths.call_args + self.assertEqual(args[0].dump_file_path, "/tmp/dump/step1/rank0/dump.json") + self.assertEqual(args[0].stack_file_path, "/tmp/dump/step1/rank0/stack.json") + self.assertEqual(args[0].construct_file_path, "/tmp/dump/step1/rank0/construct.json") + self.assertEqual(args[0].dump_tensor_data_dir, "/tmp/dump/step1/rank0/dump_tensor_data") + self.service.data_collector.initialize_json_file.assert_called_once_with( + framework=Const.MT_FRAMEWORK + ) + + @patch.object(Service, 'check_model_valid') @patch.object(Service, 'need_end_service', return_value=False) - def test_start_stop_cycle(self, mock_need_end_service): + def test_start_stop_cycle(self, mock_need_end_service, mock_check_model_valid): self.service.model = nn.Cell() - with patch.object(self.service, 'register_cell_hook') as mock_register_hook: - self.should_stop_service = False - self.service.start(self.service.model) - self.assertTrue(self.service.switch) - self.service.stop() - self.assertFalse(self.service.switch) - mock_register_hook.assert_called_once() - mock_need_end_service.assert_called_once() + mock_check_model_valid.return_value = self.service.model + self.should_stop_service = False + self.service.start(self.service.model) + mock_check_model_valid.assert_called_with(self.service.model) + self.assertTrue(self.service.switch) + self.service.stop() + self.assertFalse(self.service.switch) + self.service.cell_processor.register_cell_hook.assert_called_once() + mock_need_end_service.assert_called_once() + + self.service.cell_processor.register_cell_hook.reset_mock() def test_should_execute_hook_return_false(self): cell = MagicMock() @@ -174,17 +250,16 @@ class TestService(unittest.TestCase): @patch.object(Service, 'need_end_service', return_value=False) @patch.object(logger, 'info') - @patch.object(Service, 'register_cell_hook') @patch.object(Service, 'register_primitive_hook') @patch.object(Service, 'create_dirs') @patch('msprobe.mindspore.service.get_rank_if_initialized', return_value=0) def test_start_first_time(self, mock_get_rank, mock_create_dirs, mock_register_primitive_hook, - mock_register_cell_hook, mock_logger, mock_need_end_service): + mock_logger, mock_need_end_service): self.service.first_start = True self.service.should_stop_service = False self.service.start(self.service.model) mock_get_rank.assert_called_once() - mock_register_cell_hook.assert_called_once() + self.service.cell_processor.register_cell_hook.assert_called_once() mock_register_primitive_hook.assert_called_once() mock_need_end_service.assert_called_once() mock_create_dirs.assert_called_once() @@ -193,27 +268,29 @@ class TestService(unittest.TestCase): self.assertTrue(self.service.primitive_switch) mock_logger.assert_called_with(f"Dump data will be saved in {self.service.dump_iter_dir}.") + self.service.cell_processor.register_cell_hook.reset_mock() + @patch.object(Service, 'register_primitive_hook') - @patch.object(Service, 'register_cell_hook') @patch.object(Service, 'need_end_service', return_value=False) @patch.object(JitDump, 'set_config') @patch.object(JitDump, 'set_data_collector') - @patch.object(ApiRegistry, 'register_all_api') - def test_start_with_jit_dump_enabled(self, mock_api_set_hook_func, mock_set_data_collector, - mock_set_config, mock_need_end_service, mock_register_cell_hook, - mock_register_primitive_hook): + def test_start_with_jit_dump_enabled(self, mock_set_data_collector, mock_set_config, + mock_need_end_service, mock_register_primitive_hook): self.service.config.level = Const.LEVEL_MIX self.service.first_start = True self.service.should_stop_service = False self.service.start(self.service.model) mock_set_config.assert_called_with(self.service.config) mock_set_data_collector.assert_called_with(self.service.data_collector) - mock_api_set_hook_func.assert_called_once() + self.service.api_register.register_all_api.assert_called_once() mock_need_end_service.assert_called_once() - mock_register_cell_hook.assert_called_once() + self.service.cell_processor.register_cell_hook.assert_called_once() mock_register_primitive_hook.assert_called_once() self.assertTrue(JitDump.jit_dump_switch) + self.service.api_register.register_all_api.reset_mock() + self.service.cell_processor.register_cell_hook.reset_mock() + def test_step_updates(self): CellProcessor.cell_count = {"test_api": 1} HOOKCell.cell_count = {"test_api": 1} @@ -236,14 +313,13 @@ class TestService(unittest.TestCase): self.service.data_collector.backward_data_collect = MagicMock() mock_cell = MagicMock() - mock_cell.mindstudio_reserved_name = "TestCell" mock_input = (MagicMock(),) mock_output = MagicMock() - _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook") + _, forward_hook, backward_hook, _ = self.service.build_hook(BaseScope.Module_Type_Module, "TestHook.forward.0") forward_hook(mock_cell, mock_input, mock_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestCell') + self.service.data_collector.update_api_or_module_name.assert_called_with('TestHook.forward.0') self.service.data_collector.forward_data_collect.assert_called() self.service.data_collector.reset_mock() @@ -252,52 +328,33 @@ class TestService(unittest.TestCase): mock_grad_output = MagicMock() backward_hook(mock_cell, mock_grad_input, mock_grad_output) - self.service.data_collector.update_api_or_module_name.assert_called_with('TestHookbackward.0') + self.service.data_collector.update_api_or_module_name.assert_called_with('TestHook.backward.0') self.service.data_collector.backward_data_collect.assert_called() def test_register_primitive_hook(self): self.service.config.level = Const.LEVEL_MIX primitive_attr = ops.Add() primitive_name = "primitive_api" + mock_model = MagicMock() cell_mock = MagicMock() cell_mock.primitive_api = primitive_attr primitive_combined_name = primitive_name + Const.SEP + primitive_attr.__class__.__name__ - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - self.service.register_primitive_hook() + self.service.model = mock_model + with patch('msprobe.mindspore.service.get_cells_and_names') as mock_get_cells_and_names: + mock_get_cells_and_names.return_value = {'-1': [("cell_name", cell_mock)]} + self.service.register_primitive_hook() self.assertTrue(hasattr(primitive_attr.__class__, '__call__')) self.assertEqual(self.service.primitive_hook_service.wrap_primitive.call_args[0][1], primitive_combined_name) - @patch.object(ApiRegistry, 'initialize_hook') - @patch.object(ApiRegistry, 'register_all_api') @patch("msprobe.mindspore.service.logger.info") - def test_register_hook_new_with_level_mix(self, mock_logger, mock_api_set_hook_func, mock_initialize_hook): + def test_register_hook_new_with_level_mix(self, mock_logger): self.service.config.level = Const.LEVEL_MIX self.service.register_api_hook() - self.service.register_cell_hook() - mock_logger.assert_called_with(f"The cell {self.service.config.task} hook function " - "is successfully mounted to the model.") - mock_api_set_hook_func.assert_called() - mock_initialize_hook.assert_called() - - @patch.object(CellProcessor, 'node_hook') - def test_register_hook_new_with_level_l0(self, mock_node_hook): - global register_backward_hook_functions - self.service.config.level = Const.LEVEL_L0 - cell_mock = MagicMock() - setattr(MagicMock, 'construct', None) - self.service.model.cells_and_names.return_value = [("cell_name", cell_mock)] - register_backward_hook_functions["pre"] = cell_mock.register_backward_pre_hook - register_backward_hook_functions["full"] = cell_mock.register_backward_hook - self.service.register_cell_hook() - cell_mock.register_forward_hook.assert_called() - cell_mock.register_backward_hook.assert_called() - mock_node_hook.assert_called() - register_backward_hook_functions = {} - del MagicMock.construct - - def test_register_hook_new_without_model_raises_exception(self): - self.service.config.level = Const.LEVEL_L0 - self.service.model = None - with self.assertRaises(MsprobeException): - self.service.register_cell_hook() + mock_logger.assert_called_with(f'The api {self.service.config.task} hook function ' + 'is successfully mounted to the model.') + self.service.api_register.initialize_hook.assert_called_once() + self.service.api_register.register_all_api.assert_called_once() + + self.service.api_register.initialize_hook.reset_mock() + self.service.api_register.register_all_api.reset_mock() diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 79deeee08e13273f08f32be26a375d1d26f5d2f1..734be70a5865fcaeaf2c66a0aed8b07c84322cda 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -1,8 +1,7 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -13,96 +12,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -""" + +from collections import defaultdict +import tempfile import unittest -import mindspore as ms -import numpy as np -import os from unittest.mock import Mock, patch -from mindspore import nn +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops -import tempfile from msprobe.core.common.utils import Const from msprobe.mindspore.service import Service -from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from collections import defaultdict from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from mindspore.common.tensor import Tensor - - -class DummyModel(nn.Cell): - def __init__(self): - super(DummyModel, self).__init__() - self.dense = nn.Dense(2, 2) - - def construct(self, x): - return self.dense(x) - - -class TestService(unittest.TestCase): - @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def setUp(self, _): - json_config = { - "task": "statistics", - "dump_path": "/absolute_path", - "rank": [], - "step": [0, 2], - "level": "L1" - } - - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - self.service.primitive_switch = True # Make sure the switch is on for testing - - def test_check_model_valid_none(self): - model = None - self.assertIsNone(self.service.check_model_valid(model)) - - def test_check_model_valid_valid_model(self): - model = DummyModel() - self.assertEqual(self.service.check_model_valid(model), model) - - def test_check_model_valid_invalid_model(self): - model = "invalid_model" - with self.assertRaises(MsprobeException) as context: - self.service.check_model_valid(model) - - def test_update_primitive_counters(self): - primitive_name = "test_primitive" - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 0) - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1) - - def test_step_updates_iteration(self): - initial_iter = self.service.loop - self.service.step() - self.assertEqual(self.service.loop, initial_iter + 1) - - @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) - def test_step_resets_counters(self, _): - # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_hook_service.primitive_counters["test_primitive"] = 5 - self.service.step() - self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) - self.assertEqual(HOOKCell.cell_count, defaultdict(int)) - - def test_start_calls_update_iter(self): - # 检查是否在调用 start 时调用了 update_iter - with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: - initial_iter = self.service.loop - init_step = self.service.init_step - self.service.start() - mock_update_iter.assert_called_once_with(initial_iter + init_step) class TestPrimitiveHookService(unittest.TestCase): @@ -121,19 +46,14 @@ class TestPrimitiveHookService(unittest.TestCase): common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - - # 模拟一个 service_instance 和 data_collector - self.mock_service_instance = Service(config) - self.mock_service_instance.switch = True - self.mock_service_instance.data_collector = Mock() - self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"] - # 初始化 PrimitiveHookService - self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) + with patch('msprobe.mindspore.service.build_data_collector'), \ + patch('msprobe.mindspore.service.CellProcessor'), \ + patch('msprobe.mindspore.service.PrimitiveHookService'), \ + patch('msprobe.mindspore.service.get_api_register'): + self.mock_service_instance = Service(config) + self.mock_service_instance.switch = True + self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) def tearDown(self): # 测试结束时删除临时目录 @@ -148,7 +68,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -163,7 +82,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"1After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_output_data_collect.called) @@ -177,7 +95,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents @@ -214,14 +131,7 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrap_primitive 获取包装函数通过闭包显式调用backward_hook hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents - wrapped_primitive_call = self.primitive_hook_service.wrap_primitive(None, "example") - if wrapped_primitive_call.__closure__: - for i, closure in enumerate(wrapped_primitive_call.__closure__): - print(f"Closure[{i}]:", closure.cell_contents) - - if hook_primitive_inputs.__closure__: - for i, closure in enumerate(hook_primitive_inputs.__closure__): - print(f"2Closure[{i}]:", closure.cell_contents) + create_backward_hook = hook_primitive_inputs.__closure__[0].cell_contents backward_hook = create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type) @@ -235,7 +145,6 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_2) self.assertEqual(len(captured_grads), 6) # 捕获到两个梯度 - print(f"After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 调用到达阈值,验证数据收集 self.assertTrue(self.mock_service_instance.data_collector.backward_input_data_collect.called) @@ -282,18 +191,15 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_input" # 调用 hook_primitive_inputs - hooked_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents(args, - captured_grads_input, - updated_primitive_name) - - # 验证 hooked_inputs 是否正确添加了 hook - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - print(f"Captured hooked_arg after hook: {hooked_arg}") - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - - # 打印调试信息 - print(f"Captured gradients after hook: {captured_grads_input}") + hook_primitive_inputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[0].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) def test_hook_primitive_outputs(self): # 模拟前向输出 @@ -302,17 +208,16 @@ class TestPrimitiveHookService(unittest.TestCase): updated_primitive_name = "test_primitive_output" # 调用 hook_primitive_outputs - hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, "example").__closure__[ - 1].cell_contents - hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) - - # 验证 hooked_outputs 是否正确添加了 hook - for tensor, hooked_tensor in zip(out, hooked_outputs): - if isinstance(tensor, Tensor): - self.assertTrue(hasattr(hooked_tensor, 'grad_fn')) - - # 打印调试信息 - print(f"Captured gradients after output hook: {captured_grads_output}") + hook_primitive_outputs = self.primitive_hook_service.wrap_primitive(None, + "example").__closure__[1].cell_contents + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_outputs = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(out)) + for hooked_output in hooked_outputs: + self.assertTrue((hooked_output == target_value).all()) def test_wrapped_primitive_call_args(self): # 模拟前向输入 @@ -325,19 +230,18 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 wrapped_primitive_call 并检查 hooked_inputs 是否与原始 args 相同 try: - hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, - updated_primitive_name) - for arg, hooked_arg in zip(args, hooked_inputs): - if isinstance(arg, Tensor): - self.assertTrue(hasattr(hooked_arg, 'grad_fn')) - self.assertTrue(np.array_equal(arg.asnumpy(), hooked_arg.asnumpy())) - print(f"Arg type: {type(arg)}, Hooked input type: {type(hooked_arg)}") - else: - self.assertEqual(arg, hooked_arg) + with patch.object(ops, 'HookBackward') as mock_HookBackward: + target_value = Tensor([1.0]) + mock_hbw = mock_HookBackward.return_value + mock_hbw.return_value = target_value + hooked_inputs = wrapped_primitive_call.__closure__[0].cell_contents(args, captured_grads_input, + updated_primitive_name) + self.assertEqual(mock_HookBackward.call_count, len(args)) + for hooked_input in hooked_inputs: + self.assertTrue((hooked_input == target_value).all()) except Exception as e: self.fail(f"wrapped_primitive_call raised an exception: {e}") - def test_update_primitive_counters_multiple(self): # 测试更新 primitive 计数器的功能,增加多个不同名称的测试 primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"] @@ -416,13 +320,11 @@ class TestPrimitiveHookService(unittest.TestCase): for captured_grads in captured_grads_sets: updated_primitive_name = "MatMul.Backward" - num_tensors = len(captured_grads) hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") backward_hook = hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) self.assertIsNotNone(backward_hook) - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward): # 模拟前向和后向钩子在同一个 primitive 中的行为 @@ -447,9 +349,6 @@ class TestPrimitiveHookService(unittest.TestCase): self.primitive_hook_service.update_primitive_counters(name) self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) - - - def test_update_primitive_counters(self): primitive_name = "MatMul" self.primitive_hook_service.update_primitive_counters(primitive_name) @@ -496,7 +395,7 @@ class TestPrimitiveHookService(unittest.TestCase): wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") # 模拟反向传播过程,调用包装的 primitive - with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect') as mock_backward_collect: + with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect'): result = wrapped_func(Mock(), input_tensor) # 验证结果是 Tensor 实例 @@ -544,7 +443,6 @@ class TestPrimitiveHookService(unittest.TestCase): # 测试 create_backward_hook 的功能 captured_grads = [] updated_primitive_name = "MatMul.Backward" - num_tensors = 2 # 创建 backward hook backward_hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py index 5aaf0820a78339ff4f1cc5d28aff8762bae31a39..4ba3556c277f3326520547a6124170f32a9cc8e8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_dump.py @@ -16,47 +16,68 @@ import unittest from unittest.mock import patch, MagicMock -import torch -import torch.nn as nn +from torch import nn -from msprobe.core.data_dump.api_registry import ApiRegistry -from msprobe.pytorch import PrecisionDebugger -from msprobe.pytorch.hook_module.api_register import get_api_register -from msprobe.pytorch.service import torch_version_above_or_equal_2 +from msprobe.pytorch.common.log import logger +from msprobe.pytorch.dump.module_dump.module_dump import ModuleDumper +from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser class TestModuleDumper(unittest.TestCase): - @classmethod - def setUpClass(cls): - PrecisionDebugger._instance = None - get_api_register().restore_all_api() + def setUp(self): + self.service = MagicMock() + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register'): + self.module_dumper = ModuleDumper(self.service) - @classmethod - def tearDownClass(cls): - PrecisionDebugger._instance = None - get_api_register().restore_all_api() + def test__init__(self): + self.service = MagicMock() + with patch('msprobe.pytorch.dump.module_dump.module_dump.get_api_register') as mock_get_api_register: + self.module_dumper = ModuleDumper(self.service) + self.assertEqual(self.module_dumper.service, self.service) + mock_get_api_register.assert_called_once() - def setUp(self): - self.module = nn.Linear(8, 4) - debugger = PrecisionDebugger(dump_path="./") - self.module_dumper = debugger.module_dumper + def test_start_module_dump(self): + module = nn.Module() + with patch.object(logger, 'info_on_rank_0') as mock_info: + module.msprobe_hook = True + ModuleProcesser.enable_module_dump = False + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_called_with('The init dump is enabled, and the module dump function will not be available.') + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_not_called() + self.assertFalse(hasattr(module, 'msprobe_module_dump')) + + del module.msprobe_hook + mock_info.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_called_with( + module, + self.module_dumper.service.build_hook, + recursive=False, + module_names=['dump_name'] + ) + self.assertTrue(module.msprobe_module_dump) + ModuleProcesser.enable_module_dump = False + + self.module_dumper.api_register.restore_all_api.reset_mock() + self.module_dumper.service.module_processor.register_module_hook.reset_mock() + self.module_dumper.start_module_dump(module, 'dump_name') + mock_info.assert_not_called() + self.assertTrue(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.restore_all_api.assert_called_once() + self.module_dumper.service.module_processor.register_module_hook.assert_not_called() + + ModuleProcesser.enable_module_dump = False def test_stop_module_dump(self): - self.module_dumper.hook_handle_list.extend([1, 2, 3]) - with patch.object(ApiRegistry, 'register_all_api') as mock_api_register: - mock_handle1 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - mock_handle2 = MagicMock(spec=torch.utils.hooks.RemovableHandle) - self.module_dumper.hook_handle_list.extend([mock_handle1, mock_handle2]) - - self.module_dumper.stop_module_dump() - mock_handle1.remove.assert_called_once() - mock_handle2.remove.assert_called_once() - self.assertEqual(self.module_dumper.hook_handle_list, []) - mock_api_register.assert_called_once() - - def test_register_hook(self): - self.module_dumper.register_hook(self.module, "TestModule") - if torch_version_above_or_equal_2: - self.assertEqual(len(self.module_dumper.hook_handle_list), 6) - else: - self.assertEqual(len(self.module_dumper.hook_handle_list), 5) + ModuleProcesser.enable_module_dump = True + self.module_dumper.api_register.register_all_api.reset_mock() + self.module_dumper.stop_module_dump() + self.assertFalse(ModuleProcesser.enable_module_dump) + self.module_dumper.api_register.register_all_api.assert_called_once() + + self.module_dumper.api_register.register_all_api.reset_mock() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py index 20cfdfa6ba399d274ca67effb6f93a7c3762edce..832f63f8fd99b53d8d1909bee45e7a5634c6ca92 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/dump/test_module_processer.py @@ -1,10 +1,24 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock import torch from msprobe.core.data_dump.scope import ModuleRangeScope -from msprobe.pytorch.common.utils import Const from msprobe.pytorch.dump.module_dump.module_processer import ModuleProcesser @@ -25,58 +39,12 @@ class TestModuleProcesser(unittest.TestCase): processor = ModuleProcesser(scope) self.assertIsNone(processor.scope) - def test_module_count_func(self): + def test_set_and_get_calls_number(self): + ModuleProcesser.reset_module_stats() test = ModuleProcesser(None) self.assertEqual(test.module_count, {}) module_name = "nope" - test.module_count_func(module_name) + test.set_and_get_calls_number(module_name) self.assertEqual(test.module_count["nope"], 0) - def test_node_hook_forward_start(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - hook(module, input) - expected_name = f"forward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, expected_name) - - def test_node_hook_forward_stop(self): - name_prefix = "forward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.STOP) - ModuleProcesser.module_stack.append(f"forward_layer{Const.SEP}0") - - module = MagicMock() - input = (self.mock_tensor,) - reserved_name = f"forward_layer{Const.SEP}0" - module.mindstudio_reserved_name = [reserved_name] - hook(module, input) - self.assertNotIn([f"forward_layer{Const.SEP}0"], ModuleProcesser.module_stack) - self.assertEqual(ModuleProcesser.api_parent_node, reserved_name) - - def test_node_hook_backward(self): - name_prefix = "backward_layer" - hook = self.processor.node_hook(name_prefix, start_or_stop=Const.START) - - module = MagicMock() - input = (self.mock_tensor,) - module.mindstudio_reserved_name = None - ModuleProcesser.module_node[f"forward_layer{Const.SEP}0"] = None - hook(module, input) - expected_name = f"backward_layer{Const.SEP}0" - self.assertEqual(module.mindstudio_reserved_name, [expected_name]) - self.assertIn(expected_name, ModuleProcesser.module_node) - - def test_has_register_backward_hook(self): - module = MagicMock() - module._backward_hooks = {0: lambda: None} - module._is_full_backward_hook = False - result = self.processor.has_register_backward_hook(module) - self.assertTrue(result) - - module._is_full_backward_hook = True - result = self.processor.has_register_backward_hook(module) - self.assertFalse(result) + ModuleProcesser.reset_module_stats() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py index 1524a82ae1fc81eee245fa73bde4b4938cb89638..d907b81af97aeafdd5e35de2bec0fecd97399835 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py @@ -1,12 +1,29 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock, patch import threading + from msprobe.pytorch.hook_module.hook_module import HOOKModule + class TestHOOKModuleInit(unittest.TestCase): def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) + self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock())) def test_thread_handling(self): module = HOOKModule(self.mock_build_hook) @@ -16,7 +33,7 @@ class TestHOOKModuleInit(unittest.TestCase): class TestHOOKModuleCall(unittest.TestCase): def setUp(self): - self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock(), None)) + self.mock_build_hook = MagicMock(return_value=(MagicMock(), MagicMock(), MagicMock())) self.module = HOOKModule(self.mock_build_hook) @patch.object(HOOKModule, '_call_func') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py index af669cb5c73de85e51f36f62f9e7dc61bb599ca1..e565c1cc08d496bd96cc1e873f50e4c02e5c69a8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_wrap_aten.py @@ -1,15 +1,34 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import unittest from unittest.mock import MagicMock, patch import torch from msprobe.pytorch.function_factory import npu_custom_grad_functions -from msprobe.pytorch.hook_module.wrap_aten import AtenOPTemplate, white_aten_ops, \ +from msprobe.pytorch.hook_module.wrap_aten import ( + AtenOPTemplate, + white_aten_ops, AtenOPPacketTemplate +) def mock_build_hook(prefix): - return (MagicMock(), MagicMock(), MagicMock(), MagicMock()) + return (MagicMock(), MagicMock(), MagicMock()) + class TestAtenOPTemplate(unittest.TestCase): @@ -79,8 +98,8 @@ class TestAtenOPPacketTemplate(unittest.TestCase): del self.mock_op_packet.nonexistent_attr with self.assertRaises(AttributeError) as context: _ = self.template.nonexistent_attr - self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", \ - str(context.exception)) + self.assertIn("or OpOverloadPacket does not have attribute 'nonexistent_attr'.", + str(context.exception)) @patch('msprobe.pytorch.hook_module.wrap_aten.AtenOPTemplate', autospec=True) def test_getattr_op_overload(self, MockAtenOPTemplate):