# Copyright 2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os

import mindspore as ms
from mindspore.common.tensor import Tensor
from mindspore import ops

from msprobe.core.common.utils import Const
from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \
    ModuleBackwardInputs, ModuleBackwardOutputs


class PrimitiveHookService:
    """Wrap primitive calls so their forward data and backward gradients are dumped.

    The service keeps one call counter per primitive name and forwards the
    captured data to ``service_instance.data_collector``.
    """

    def __init__(self, service_instance):
        # Maps primitive name -> zero-based index of its most recent call.
        self.primitive_counters = {}
        # Object exposing ``switch`` and ``data_collector``; both are read on every call.
        self.service_instance = service_instance

    def wrap_primitive(self, origin_func, primitive_name):
        """Return a replacement for ``origin_func`` that dumps forward and backward data.

        Args:
            origin_func (callable): The original primitive call.
            primitive_name (str): Name of the primitive, used to build dump names.

        Returns:
            callable: ``wrapped_primitive_call(instance_self, *args, **kwargs)``.
        """
        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
            """Build a gradient hook that collects once ``num_tensors`` gradients arrived.

            Args:
                captured_grads (list): Shared buffer the hook appends gradients to.
                num_tensors (int): Number of gradients that completes one collection.
                updated_primitive_name (str): Primitive name including its call index.
                hook_type (str): ``Const.INPUT`` or ``Const.OUTPUT``.

            Returns:
                callable: The backward hook taking a single gradient.
            """
            def backward_hook(grad):
                captured_grads.append(grad)
                # Wait until every hooked tensor has reported its gradient.
                if len(captured_grads) != num_tensors:
                    return

                backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}"
                try:
                    collector = self.service_instance.data_collector
                    if hook_type == Const.INPUT:
                        # Gradients seen on the forward inputs are dumped as grad_output.
                        collector.update_api_or_module_name(backward_primitive_name)
                        grads = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
                        collector.backward_output_data_collect(
                            backward_primitive_name, self, os.getpid(), grads
                        )
                    elif hook_type == Const.OUTPUT:
                        # Gradients seen on the forward outputs are dumped as grad_input.
                        collector.update_api_or_module_name(backward_primitive_name)
                        grads = ModuleBackwardInputs(grad_input=tuple(captured_grads))
                        collector.backward_input_data_collect(
                            backward_primitive_name, self, os.getpid(), grads
                        )
                    else:
                        # Unknown hook type: nothing is collected and the buffer is kept.
                        return
                    captured_grads.clear()
                except Exception as exception:
                    raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception},"
                                    f" updated_primitive_name: {updated_primitive_name}") from exception

            return backward_hook

        def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
            """Attach the input-side backward hook to every Tensor in ``args``.

            Args:
                args (tuple): Positional arguments of the primitive call.
                captured_grads_input (list): Buffer for the captured input gradients.
                updated_primitive_name (str): Primitive name including its call index.

            Returns:
                list: The arguments, with each Tensor replaced by its hooked version.
            """
            num_tensors = sum(isinstance(arg, Tensor) for arg in args)
            input_backward_hook = create_backward_hook(captured_grads_input, num_tensors,
                                                       updated_primitive_name, Const.INPUT)
            return [
                ops.HookBackward(input_backward_hook)(arg) if isinstance(arg, Tensor) else arg
                for arg in args
            ]

        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
            """Attach the output-side backward hook to the primitive's output.

            Args:
                out (Tensor/tuple): Output of the primitive call.
                captured_grads_output (list): Buffer for the captured output gradients.
                updated_primitive_name (str): Primitive name including its call index.

            Returns:
                Tensor/tuple: The hooked output; any value that is neither a Tensor
                nor a tuple is returned unchanged.
            """
            if isinstance(out, tuple):
                num_output_tensors = sum(isinstance(item, Tensor) for item in out)
            else:
                num_output_tensors = 1
            output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
                                                        updated_primitive_name, Const.OUTPUT)

            if isinstance(out, Tensor):
                return ops.HookBackward(output_backward_hook)(out)
            if isinstance(out, tuple):
                return tuple(
                    ops.HookBackward(output_backward_hook)(item) if isinstance(item, Tensor) else item
                    for item in out
                )
            return out

        def wrapped_primitive_call(instance_self, *args, **kwargs):
            """Call the primitive while dumping its forward data and hooking gradients.

            Args:
                instance_self (object): The primitive instance.
                *args: Positional arguments of the primitive.
                **kwargs: Keyword arguments of the primitive.

            Returns:
                Tensor/tuple: The primitive's output (hooked when dumping is on).

            Raises:
                Exception: If hooking, the primitive call or data collection fails;
                    the original error is chained as ``__cause__``.
            """
            # The counter advances even when dumping is off so call indices stay stable.
            self.update_primitive_counters(primitive_name)

            if not self.service_instance.switch:
                return origin_func(*args, **kwargs)

            current_count = self.primitive_counters.get(primitive_name, 0)
            updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}"

            captured_grads_input, captured_grads_output = [], []

            try:
                hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
            except Exception as exception:
                raise Exception(f"This is a primitive op dump error during input hooking: {exception},"
                                f" primitive_name: {primitive_name}") from exception

            try:
                out = origin_func(*hooked_inputs, **kwargs)
            except Exception as exception:
                raise Exception(f"This is a primitive op dump error during function call: {exception},"
                                f" primitive_name: {primitive_name}") from exception

            # Every collector access lives behind this guard, so a missing
            # collector skips dumping instead of raising AttributeError.
            data_collector = self.service_instance.data_collector
            if data_collector:
                forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}"
                data_collector.update_api_or_module_name(forward_primitive_name)
                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
                try:
                    data_collector.forward_data_collect(forward_primitive_name, instance_self,
                                                        os.getpid(), module_input_output)
                except Exception as exception:
                    raise Exception(f"This is a primitive op dump error during forward data collection: {exception},"
                                    f" primitive_name: {primitive_name}") from exception

                # The collector may supply a replacement for the forward output.
                if data_collector.if_return_forward_new_output():
                    out = data_collector.get_forward_new_output()

            try:
                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
            except Exception as exception:
                raise Exception(f"This is a primitive op dump error during output hooking: {exception},"
                                f" primitive_name: {primitive_name}") from exception

            return out

        return wrapped_primitive_call

    def update_primitive_counters(self, primitive_name):
        """Advance the call counter of ``primitive_name``; the first call yields 0."""
        self.primitive_counters[primitive_name] = self.primitive_counters.get(primitive_name, -1) + 1
""" import unittest +import numpy as np from unittest.mock import Mock, patch from mindspore import nn +import tempfile +from msprobe.core.common.utils import Const from msprobe.mindspore.service import Service from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from collections import defaultdict +from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService +from mindspore.common.tensor import Tensor class DummyModel(nn.Cell): @@ -63,34 +68,16 @@ class TestService(unittest.TestCase): model = DummyModel() self.assertEqual(self.service.check_model_valid(model), model) - def test_check_model_valid_invalid_model(self): - model = "invalid_model" - with self.assertRaises(MsprobeException) as context: - self.service.check_model_valid(model) - # For the purpose of the test, let's also verify the expected exception message - expected_message = f"{MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)}model 参数必须是 mindspore.nn.Cell 类型。" - self.assertEqual(str(context.exception), expected_message) - def test_update_primitive_counters(self): - primitive_name = "test_primitive" - self.service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_counters[primitive_name], 0) - self.service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_counters[primitive_name], 1) + def test_step_updates_iteration(self): initial_iter = self.service.current_iter self.service.step() self.assertEqual(self.service.current_iter, initial_iter + 1) - @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) - def test_step_resets_counters(self, _): - # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_counters["test_primitive"] = 5 - self.service.step() - 
class TestPrimitiveHookService(unittest.TestCase):
    """Unit tests for PrimitiveHookService with a mocked data collector.

    ``ops.HookBackward`` is patched where hooks are installed, so the backward
    hook functions can be pulled out of the mock's call arguments and invoked
    directly instead of running a real backward pass.
    """

    HOOK_BACKWARD_TARGET = 'msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward'

    def setUp(self):
        # A temporary directory serves as dump_path; addCleanup removes it even
        # when a later setUp step fails.
        self.temp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(self.temp_dir.cleanup)
        json_config = {
            "task": "statistics",
            "dump_path": self.temp_dir.name,
            "rank": [],
            "step": [0, 2],
            "level": "L1"
        }

        common_config = CommonConfig(json_config)
        task_config = BaseConfig(json_config)
        config = DebuggerConfig(common_config, task_config)

        # Service with dumping switched on and its data collector replaced by a Mock.
        self.mock_service_instance = Service(config)
        self.mock_service_instance.switch = True
        self.mock_service_instance.data_collector = Mock()
        self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"]

        self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance)

    @staticmethod
    def _random_tensor(shape=(2, 2)):
        """Return a float32 Tensor of the given shape filled with random values."""
        return Tensor(np.random.randn(*shape).astype(np.float32))

    @staticmethod
    def _primitive_name(name, index, direction):
        """Build the dump name the service is expected to report."""
        return f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{name}{Const.SEP}{index}{Const.SEP}{direction}"

    def test_update_primitive_counters_multiple(self):
        # Each name owns an independent zero-based counter.
        for name in ["MatMul", "Conv2D", "ReLU", "Softmax"]:
            for expected in range(3):
                with self.subTest(name=name, call=expected):
                    self.primitive_hook_service.update_primitive_counters(name)
                    self.assertEqual(self.primitive_hook_service.primitive_counters[name], expected)

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_forward_hook_various_inputs(self, mock_hook_backward):
        # Tensors of different shapes all go through input hooking and the original call.
        for shape in [(2, 2), (4, 4), (10, 10)]:
            with self.subTest(shape=shape):
                input_tensor = self._random_tensor(shape)
                mock_origin_func = Mock(return_value=input_tensor)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

                result = wrapped_func(Mock(), input_tensor)

                mock_origin_func.assert_called_once()
                mock_hook_backward.assert_called()
                # The mocked collector supplies the replacement forward output.
                self.assertIsInstance(result, Mock)

    def test_wrap_primitive_no_hook_with_invalid_input(self):
        # With the switch off the arguments are passed through untouched,
        # whatever their type.
        self.mock_service_instance.switch = False

        for invalid_input in [None, "invalid_tensor", 123]:
            with self.subTest(value=invalid_input):
                mock_origin_func = Mock(return_value=invalid_input)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

                result = wrapped_func(Mock(), invalid_input)

                mock_origin_func.assert_called_once()
                self.assertEqual(result, invalid_input)

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_with_multiple_hooks(self, mock_hook_backward):
        # Several primitives can be wrapped by the same service.
        input_tensor = self._random_tensor()

        for name in ["MatMul", "Add", "Sub"]:
            with self.subTest(name=name):
                mock_origin_func = Mock(return_value=input_tensor)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, name)

                result = wrapped_func(Mock(), input_tensor)

                mock_origin_func.assert_called_once()
                mock_hook_backward.assert_called()
                self.assertIsInstance(result, Mock)

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_with_exception_handling_multiple(self, mock_hook_backward):
        # Errors raised by the primitive are re-raised with the original message
        # embedded and the original error chained as the cause.
        input_tensor = self._random_tensor()

        for message in ["Invalid operation", "Null reference", "Type error"]:
            with self.subTest(message=message):
                original_error = Exception(message)
                mock_origin_func = Mock(side_effect=original_error)
                wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

                with self.assertRaises(Exception) as context:
                    wrapped_func(Mock(), input_tensor)

                self.assertIn(message, str(context.exception))
                self.assertIs(context.exception.__cause__, original_error)

    @patch(HOOK_BACKWARD_TARGET)
    def test_create_backward_hook_multiple(self, mock_hook_backward):
        # The input-side hook collects only after one gradient per input tensor arrived.
        collector = self.mock_service_instance.data_collector

        for tensor_count in (1, 2, 3):
            with self.subTest(tensor_count=tensor_count):
                collector.reset_mock()
                mock_hook_backward.reset_mock()
                tensors = [self._random_tensor() for _ in range(tensor_count)]
                wrapped_func = self.primitive_hook_service.wrap_primitive(Mock(return_value=tensors[0]), "MatMul")

                wrapped_func(Mock(), *tensors)

                # One HookBackward per input tensor; the mocked collector replaces
                # the output with a Mock, so no output hook is installed.
                self.assertEqual(mock_hook_backward.call_count, tensor_count)
                backward_hook = mock_hook_backward.call_args[0][0]

                for _ in range(tensor_count - 1):
                    backward_hook(Mock())
                collector.backward_output_data_collect.assert_not_called()

                backward_hook(Mock())
                collector.backward_output_data_collect.assert_called_once()

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward):
        # One call both dumps the forward data and installs the backward hook.
        input_tensor = self._random_tensor()
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "Conv2D")

        result = wrapped_func(Mock(), input_tensor)

        mock_origin_func.assert_called_once()
        mock_hook_backward.assert_called()
        self.mock_service_instance.data_collector.forward_data_collect.assert_called_once()
        self.assertIsInstance(result, Mock)

    def test_update_primitive_counters_different_names(self):
        # Counters of different primitive names do not interfere with each other.
        for name in ["MatMul", "Add", "Sub", "Mul", "Conv2D"]:
            for expected in range(5):
                with self.subTest(name=name, call=expected):
                    self.primitive_hook_service.update_primitive_counters(name)
                    self.assertEqual(self.primitive_hook_service.primitive_counters[name], expected)

    def test_update_primitive_counters(self):
        primitive_name = "MatMul"
        self.primitive_hook_service.update_primitive_counters(primitive_name)
        self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 0)
        self.primitive_hook_service.update_primitive_counters(primitive_name)
        self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 1)

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_forward_hook(self, mock_hook_backward):
        input_tensor = self._random_tensor()
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

        result = wrapped_func(Mock(), input_tensor)

        # The original primitive ran and its input was hooked.
        mock_origin_func.assert_called_once()
        mock_hook_backward.assert_called()

        # Forward data was collected under the indexed forward name.
        collector = self.mock_service_instance.data_collector
        collector.forward_data_collect.assert_called_once()
        self.assertEqual(collector.forward_data_collect.call_args[0][0],
                         self._primitive_name("MatMul", 0, Const.FORWARD))

        # The mocked collector supplies the replacement forward output.
        self.assertIsInstance(result, Mock)

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_backward_hook(self, mock_hook_backward):
        input_tensor = self._random_tensor()
        grad_tensor = self._random_tensor()
        collector = self.mock_service_instance.data_collector

        # Keep the primitive's real output so the output-side hook gets installed.
        collector.if_return_forward_new_output.return_value = False
        # HookBackward(hook) returns a callable that maps any tensor to grad_tensor.
        mock_hook_backward.return_value = lambda x: grad_tensor

        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

        result = wrapped_func(Mock(), input_tensor)

        # The output passed through the patched HookBackward callable.
        self.assertIs(result, grad_tensor)
        # First HookBackward call is the input hook, second is the output hook.
        self.assertEqual(mock_hook_backward.call_count, 2)
        output_backward_hook = mock_hook_backward.call_args_list[1][0][0]

        output_backward_hook(grad_tensor)

        collector.backward_input_data_collect.assert_called_once()
        collector.backward_output_data_collect.assert_not_called()
        self.assertEqual(collector.backward_input_data_collect.call_args[0][0],
                         self._primitive_name("MatMul", 0, Const.BACKWARD))

    def test_wrap_primitive_no_hook_when_switch_off(self):
        self.mock_service_instance.switch = False
        input_tensor = self._random_tensor()
        mock_origin_func = Mock(return_value=input_tensor)
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

        result = wrapped_func(Mock(), input_tensor)

        # With the switch off the very same tensor goes in and comes out, and
        # nothing is collected.
        mock_origin_func.assert_called_once()
        self.assertIs(mock_origin_func.call_args[0][0], input_tensor)
        self.assertIs(result, input_tensor)
        self.mock_service_instance.data_collector.forward_data_collect.assert_not_called()

    @patch(HOOK_BACKWARD_TARGET)
    def test_wrap_primitive_error_handling(self, mock_hook_backward):
        input_tensor = self._random_tensor()
        mock_origin_func = Mock(side_effect=Exception("Mocked exception"))
        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")

        with self.assertRaises(Exception) as context:
            wrapped_func(Mock(), input_tensor)

        self.assertIn("Mocked exception", str(context.exception))
        # The primitive name is part of the wrapped error message.
        self.assertIn("MatMul", str(context.exception))

    @patch(HOOK_BACKWARD_TARGET)
    def test_create_backward_hook(self, mock_hook_backward):
        input_tensor = self._random_tensor()
        collector = self.mock_service_instance.data_collector
        wrapped_func = self.primitive_hook_service.wrap_primitive(Mock(return_value=input_tensor), "MatMul")

        wrapped_func(Mock(), input_tensor)

        # The hook created for the inputs is the argument handed to HookBackward.
        backward_hook = mock_hook_backward.call_args_list[0][0][0]
        self.assertTrue(callable(backward_hook))

        backward_hook(Mock())

        # A single tensor input means a single gradient completes the collection.
        collector.backward_output_data_collect.assert_called_once()
        self.assertEqual(collector.backward_output_data_collect.call_args[0][0],
                         self._primitive_name("MatMul", 0, Const.BACKWARD))

        # The gradient buffer was cleared, so the next gradient triggers a new collection.
        backward_hook(Mock())
        self.assertEqual(collector.backward_output_data_collect.call_count, 2)