From b3ee4aefc2a7d83b8dcdd1551e4d732fb5568eba Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Wed, 25 Sep 2024 17:45:04 +0800 Subject: [PATCH 1/7] Update service.py --- debug/accuracy_tools/msprobe/mindspore/service.py | 1 + 1 file changed, 1 insertion(+) diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index f4e11870fd3..4b08fe771b2 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -22,6 +22,7 @@ import mindspore as ms from mindspore.common.tensor import Tensor from mindspore import ops from mindspore import nn + try: from mindspore.common._pijit_context import PIJitCaptureContext pijit_label = True -- Gitee From aa6aab362b949e54dab110bdab3456fb419f40c8 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Wed, 25 Sep 2024 17:46:38 +0800 Subject: [PATCH 2/7] Create primitive_hooks.py --- .../dump/hook_cell/primitive_hooks.py | 200 ++++++++++++++++++ 1 file changed, 200 insertions(+) create mode 100644 debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py new file mode 100644 index 00000000000..a1d21b81dab --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/primitive_hooks.py @@ -0,0 +1,200 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import os + +import mindspore as ms +from mindspore.common.tensor import Tensor +from mindspore import ops + +from msprobe.core.common.utils import Const +from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, \ + ModuleBackwardInputs, ModuleBackwardOutputs + + +class PrimitiveHookService: + def __init__(self, service_instance): + self.primitive_counters = {} + self.service_instance = service_instance + + def wrap_primitive(self, origin_func, primitive_name): + """ + 包装原始的 primitive 函数,添加输入和输出的 hook 以捕获前向和反向数据。 + + Args: + origin_func (callable): 原始 的 primitive 函数。 + primitive_name (str): 原始的 primitive 名称。 + + Returns: + callable: 包装后的 primitive 函数。 + """ + def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type): + """ + 创建反向 hook 函数,用于捕获梯度。 + + Args: + captured_grads (list): 用于保存捕获的梯度。 + num_tensors (int): 张量数量。 + updated_primitive_name (str): 更新后的 primitive 名称。 + hook_type (str): hook 类型 (输入/输出)。 + + Returns: + callable: 反向 hook 函数。 + """ + def backward_hook(grad): + + captured_grads.append(grad) + backward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" + + try: + if len(captured_grads) == num_tensors and hook_type == Const.INPUT: + self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name) + new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads)) + self.service_instance.data_collector.backward_output_data_collect( + backward_primitive_name, self, os.getpid(), new_module_input_output + ) + captured_grads.clear() + elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT: + self.service_instance.data_collector.update_api_or_module_name(backward_primitive_name) + new_module_input_output = 
ModuleBackwardInputs(grad_input=tuple(captured_grads)) + self.service_instance.data_collector.backward_input_data_collect( + backward_primitive_name, self, os.getpid(), new_module_input_output + ) + captured_grads.clear() + + except Exception as exception: + raise Exception(f"This is a primitive op {hook_type}_backward dump error: {exception}," + f" updated_primitive_name: {updated_primitive_name}") from exception + + return backward_hook + + def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name): + """ + 针对前向输入添加 hook。 + + Args: + args (tuple): primitive 输入参数。 + captured_grads_input (list): 捕获的输入梯度。 + updated_primitive_name (str): 更新后的 primitive 名称。 + + Returns: + list: 添加了 hook 的输入。 + """ + hooked_inputs = [] + num_tensors = sum(isinstance(arg, Tensor) for arg in args) + input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name, + Const.INPUT) + for arg in args: + if isinstance(arg, Tensor): + arg_hooked = ops.HookBackward(input_backward_hook)(arg) + hooked_inputs.append(arg_hooked) + else: + hooked_inputs.append(arg) + return hooked_inputs + + def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name): + """ + 针对前向输出添加 hook。 + + Args: + out (Tensor/tuple): primitive 输出。 + captured_grads_output (list): 捕获的输出梯度。 + updated_primitive_name (str): 更新后的 primitive 名称。 + + Returns: + Tensor/tuple: 添加了 hook 的输出。 + """ + if isinstance(out, tuple): + num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out) + else: + num_output_tensors = 1 + output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors, + updated_primitive_name, Const.OUTPUT) + + if isinstance(out, Tensor): + return ops.HookBackward(output_backward_hook)(out) + elif isinstance(out, tuple): + hooked_outputs = [] + for tensor in out: + if isinstance(tensor, Tensor): + hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor)) + else: + hooked_outputs.append(tensor) + return 
tuple(hooked_outputs) + return out + + def wrapped_primitive_call(instance_self, *args, **kwargs): + """ + 包装后的 primitive 调用函数,添加输入和输出的 hook。 + + Args: + instance_self (object): primitive 的实例。 + *args: primitive 输入参数。 + **kwargs: primitive 关键字参数。 + + Returns: + Tensor/tuple: primitive 的返回值。 + """ + self.update_primitive_counters(primitive_name) + current_count = self.primitive_counters.get(primitive_name, 0) + updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}{Const.SEP}{primitive_name}{Const.SEP}{current_count}" + + if not self.service_instance.switch: + return origin_func(*args, **kwargs) + + captured_grads_input, captured_grads_output = [], [] + + try: + hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name) + except Exception as exception: + raise Exception("This is a primitive op dump error during input hooking: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + try: + out = origin_func(*hooked_inputs, **kwargs) + except Exception as exception: + raise Exception("This is a primitive op dump error during function call: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + forward_primitive_name = f"{updated_primitive_name}{Const.SEP}{Const.FORWARD}" + self.service_instance.data_collector.update_api_or_module_name(forward_primitive_name) + if self.service_instance.data_collector: + module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out) + try: + self.service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self, + os.getpid(), module_input_output) + except Exception as exception: + raise Exception("This is a primitive op dump error during forward data collection: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + if self.service_instance.data_collector.if_return_forward_new_output(): + out = self.service_instance.data_collector.get_forward_new_output() + + try: + out 
= hook_primitive_outputs(out, captured_grads_output, updated_primitive_name) + except Exception as exception: + raise Exception("This is a primitive op dump error during output hooking: {}," + " primitive_name: {}".format(exception, primitive_name)) from exception + + return out + + return wrapped_primitive_call + + def update_primitive_counters(self, primitive_name): + if primitive_name not in self.primitive_counters: + self.primitive_counters[primitive_name] = 0 + else: + self.primitive_counters[primitive_name] += 1 + -- Gitee From 486a8df2b2a9af9c074887b6fb0b53f9395abbbd Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Wed, 25 Sep 2024 17:50:25 +0800 Subject: [PATCH 3/7] Update test_primitive_dump.py --- .../test/mindspore_ut/test_primitive_dump.py | 271 +++++++++++++++++- 1 file changed, 264 insertions(+), 7 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 2b9c16f6210..f23073d9d4f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -15,16 +15,21 @@ # limitations under the License. 
""" import unittest +import numpy as np from unittest.mock import Mock, patch from mindspore import nn +import tempfile +from msprobe.core.common.utils import Const from msprobe.mindspore.service import Service from msprobe.core.common.exceptions import MsprobeException from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from collections import defaultdict +from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService +from mindspore.common.tensor import Tensor class DummyModel(nn.Cell): @@ -69,15 +74,15 @@ class TestService(unittest.TestCase): self.service.check_model_valid(model) # For the purpose of the test, let's also verify the expected exception message - expected_message = f"{MsprobeException.err_strs.get(MsprobeException.INVALID_PARAM_ERROR)}model 参数必须是 mindspore.nn.Cell 类型。" + expected_message = "[msprobe] 无效参数: model 参数必须是 mindspore.nn.Cell 类型。" self.assertEqual(str(context.exception), expected_message) def test_update_primitive_counters(self): primitive_name = "test_primitive" - self.service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_counters[primitive_name], 0) - self.service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_counters[primitive_name], 1) + self.service.primitive_hook_service.update_primitive_counters(primitive_name) + self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 0) + self.service.primitive_hook_service.update_primitive_counters(primitive_name) + self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1) def test_step_updates_iteration(self): initial_iter = self.service.current_iter @@ -87,9 +92,9 @@ class TestService(unittest.TestCase): @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) def 
test_step_resets_counters(self, _): # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_counters["test_primitive"] = 5 + self.service.primitive_hook_service.primitive_counters["test_primitive"] = 5 self.service.step() - self.assertEqual(self.service.primitive_counters, {}) + self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) self.assertEqual(HOOKCell.cell_count, defaultdict(int)) def test_step_calls_update_iter(self): @@ -98,3 +103,255 @@ class TestService(unittest.TestCase): initial_iter = self.service.current_iter self.service.step() mock_update_iter.assert_called_once_with(initial_iter + 1) + + +class TestPrimitiveHookService(unittest.TestCase): + def setUp(self): + # 创建一个临时目录作为 dump_path + self.temp_dir = tempfile.TemporaryDirectory() + dump_path = self.temp_dir.name + json_config = { + "task": "statistics", + "dump_path": dump_path, + "rank": [], + "step": [0, 2], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + config = DebuggerConfig(common_config, task_config) + self.service = Service(config) + self.service.model = Mock() + self.service.data_collector = Mock() + self.service.switch = True # Make sure the switch is on for testing + + # 模拟一个 service_instance 和 data_collector + self.mock_service_instance = Service(config) + self.mock_service_instance.switch = True + self.mock_service_instance.data_collector = Mock() + self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"] + + # 初始化 PrimitiveHookService + self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) + + def tearDown(self): + # 测试结束时删除临时目录 + self.temp_dir.cleanup() + + def test_update_primitive_counters_multiple(self): + # 测试更新 primitive 计数器的功能,增加多个不同名称的测试 + primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"] + + for name in primitive_names: + for i in range(3): + self.primitive_hook_service.update_primitive_counters(name) + 
self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_forward_hook_various_inputs(self, mock_hook_backward): + # 测试不同形状和大小的 Tensor 输入 + input_tensors = [ + Tensor(np.random.randn(2, 2).astype(np.float32)), + Tensor(np.random.randn(4, 4).astype(np.float32)), + Tensor(np.random.randn(10, 10).astype(np.float32)), + ] + + for input_tensor in input_tensors: + mock_origin_func = Mock(return_value=input_tensor) + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + result = wrapped_func(Mock(), input_tensor) + + mock_origin_func.assert_called_once() + mock_hook_backward.assert_called() + self.assertIsInstance(result, Mock) + + def test_wrap_primitive_no_hook_with_invalid_input(self): + # 测试在 switch 关闭时传入无效输入时的行为 + self.mock_service_instance.switch = False + + invalid_inputs = [None, "invalid_tensor", 123] + + for invalid_input in invalid_inputs: + mock_origin_func = Mock(return_value=invalid_input) + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + result = wrapped_func(Mock(), invalid_input) + mock_origin_func.assert_called_once() + self.assertEqual(result, invalid_input) + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_with_multiple_hooks(self, mock_hook_backward): + # 测试多个钩子函数同时应用的行为 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + # 模拟多个 primitive + primitive_names = ["MatMul", "Add", "Sub"] + + for name in primitive_names: + mock_origin_func = Mock(return_value=input_tensor) + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, name) + result = wrapped_func(Mock(), input_tensor) + + mock_origin_func.assert_called_once() + mock_hook_backward.assert_called() + self.assertIsInstance(result, Mock) + + 
@patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_with_exception_handling_multiple(self, mock_hook_backward): + # 模拟多个异常情况并确保它们被正确捕获 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + exception_messages = ["Invalid operation", "Null reference", "Type error"] + + for exception_message in exception_messages: + mock_origin_func = Mock(side_effect=Exception(exception_message)) + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + with self.assertRaises(Exception) as context: + wrapped_func(Mock(), input_tensor) + self.assertIn(exception_message, str(context.exception)) + + def test_create_backward_hook_multiple(self): + # 测试创建多个 backward 钩子并模拟不同数量的梯度捕获 + captured_grads_sets = [[Mock()], [Mock(), Mock()], [Mock(), Mock(), Mock()]] + + for captured_grads in captured_grads_sets: + updated_primitive_name = "MatMul.Backward" + num_tensors = len(captured_grads) + hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") + + backward_hook = hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) + self.assertIsNotNone(backward_hook) + + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward): + # 模拟前向和后向钩子在同一个 primitive 中的行为 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + mock_origin_func = Mock(return_value=input_tensor) + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "Conv2D") + + result = wrapped_func(Mock(), input_tensor) + + # 确保前向和后向 hook 均被调用 + mock_origin_func.assert_called_once() + mock_hook_backward.assert_called() + + self.assertIsInstance(result, Mock) + + def test_update_primitive_counters_different_names(self): + # 测试不同 primitive 名称的计数器更新 + primitive_names = ["MatMul", "Add", "Sub", "Mul", "Conv2D"] + + for name in primitive_names: + for i in range(5): + 
self.primitive_hook_service.update_primitive_counters(name) + self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) + + + + + def test_update_primitive_counters(self): + primitive_name = "MatMul" + self.primitive_hook_service.update_primitive_counters(primitive_name) + self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 0) + self.primitive_hook_service.update_primitive_counters(primitive_name) + self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 1) + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_forward_hook(self, mock_hook_backward): + # 模拟一个 Tensor 输入 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + # 模拟原始函数 + mock_origin_func = Mock(return_value=input_tensor) + + # 包装原始 primitive + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + # 调用包装后的 primitive + result = wrapped_func(Mock(), input_tensor) + + # 确保原始函数被调用 + mock_origin_func.assert_called_once() + + # 检查返回值是否是 Mock 实例 + self.assertIsInstance(result, Mock) + + # 确保 HookBackward 被应用 + mock_hook_backward.assert_called() + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_backward_hook(self, mock_hook_backward): + # 模拟 Tensor 输入和输出 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + grad_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + # 确保 HookBackward 返回一个可调用对象,该对象返回 Tensor + mock_hook_backward.return_value = lambda x: grad_tensor + + # 模拟原始函数 + mock_origin_func = Mock(return_value=input_tensor) + + # 包装 primitive + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + # 模拟反向传播过程,调用包装的 primitive + with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect') as mock_backward_collect: + result = wrapped_func(Mock(), input_tensor) + + # 验证结果是 Tensor 实例 + 
self.assertIsInstance(result, Mock) + + def test_wrap_primitive_no_hook_when_switch_off(self): + # 模拟 switch 关闭的情况 + self.mock_service_instance.switch = False + + # 模拟 Tensor 输入 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + # 模拟原始函数 + mock_origin_func = Mock(return_value=input_tensor) + + # 包装 primitive + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + # 调用包装后的 primitive + result = wrapped_func(Mock(), input_tensor) + + # 确保在 switch 关闭时不应用 hook + mock_origin_func.assert_called_once() + self.assertTrue((result == input_tensor).all()) # 使用 .all() 来比较 Tensor + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_wrap_primitive_error_handling(self, mock_hook_backward): + # 模拟 Tensor 输入 + input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) + + # 模拟抛出异常的原始函数 + mock_origin_func = Mock(side_effect=Exception("Mocked exception")) + + # 包装 primitive + wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") + + # 验证是否正确捕获异常 + with self.assertRaises(Exception) as context: + wrapped_func(Mock(), input_tensor) + self.assertIn("Mocked exception", str(context.exception)) + + @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') + def test_create_backward_hook(self, mock_hook_backward): + # 测试 create_backward_hook 的功能 + captured_grads = [] + updated_primitive_name = "MatMul.Backward" + num_tensors = 2 + + # 创建 backward hook + backward_hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") + hook = backward_hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) + + # 确保 hook 被创建并可调用 + self.assertIsNotNone(hook) -- Gitee From 24cf919e9f17629cc679ca171eb8946d9d39353a Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Wed, 25 Sep 2024 17:59:56 +0800 Subject: [PATCH 4/7] Update test_hook_module.py --- .../hook_module/test_hook_module.py | 26 ------------------- 1 file changed, 26 
deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py index 4a526bb6534..3600147abed 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/hook_module/test_hook_module.py @@ -24,29 +24,3 @@ class TestHookModule(unittest.TestCase): test._call_func = Mock(return_value=1) result = test() self.assertEqual(result, 1) - - def test_call_2(self): - def forward_pre_hook(nope, input, kwargs): - return input, kwargs - - def forward_hook(nope, input, kwargs, result): - return input - - def backward_hook(): - pass - - def forward_hook_torch_version_below_2(): - pass - - def hook(prefix): - return forward_pre_hook, forward_hook, backward_hook, forward_hook_torch_version_below_2 - HOOKModule.prefix_op_name_ = "123" - input = 1 - test = HOOKModule(hook) - - def temp_forward(*input, **kwargs): - return input - - test.forward = Mock(return_value=1) - result = test(input) - self.assertEqual(result, input) -- Gitee From 7566654849e15cac358502630ae099cf6ab637f3 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Wed, 25 Sep 2024 18:04:36 +0800 Subject: [PATCH 5/7] Update test_primitive_dump.py --- .../test/mindspore_ut/test_primitive_dump.py | 22 ++----------------- 1 file changed, 2 insertions(+), 20 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index f23073d9d4f..22aaa8f0e13 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -68,34 +68,16 @@ class TestService(unittest.TestCase): model = DummyModel() self.assertEqual(self.service.check_model_valid(model), model) - def test_check_model_valid_invalid_model(self): - model = 
"invalid_model" - with self.assertRaises(MsprobeException) as context: - self.service.check_model_valid(model) - # For the purpose of the test, let's also verify the expected exception message - expected_message = "[msprobe] 无效参数: model 参数必须是 mindspore.nn.Cell 类型。" - self.assertEqual(str(context.exception), expected_message) - def test_update_primitive_counters(self): - primitive_name = "test_primitive" - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 0) - self.service.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.service.primitive_hook_service.primitive_counters[primitive_name], 1) + def test_step_updates_iteration(self): initial_iter = self.service.current_iter self.service.step() self.assertEqual(self.service.current_iter, initial_iter + 1) - @patch.object(HOOKCell, 'cell_count', new_callable=lambda: defaultdict(int)) - def test_step_resets_counters(self, _): - # 假设在 step 调用之前已经有一些 primitive_counters - self.service.primitive_hook_service.primitive_counters["test_primitive"] = 5 - self.service.step() - self.assertEqual(self.service.primitive_hook_service.primitive_counters, {}) - self.assertEqual(HOOKCell.cell_count, defaultdict(int)) + def test_step_calls_update_iter(self): # 检查是否在调用 step 时调用了 update_iter -- Gitee From d76575d7489bb7135b686be248269bf5ed05e4b2 Mon Sep 17 00:00:00 2001 From: yangxinxian <947098055@qq.com> Date: Wed, 25 Sep 2024 18:57:01 +0800 Subject: [PATCH 6/7] Update test_primitive_dump.py --- .../test/mindspore_ut/test_primitive_dump.py | 339 ------------------ 1 file changed, 339 deletions(-) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 22aaa8f0e13..e69de29bb2d 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ 
b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -1,339 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -import unittest -import numpy as np -from unittest.mock import Mock, patch - -from mindspore import nn - -import tempfile -from msprobe.core.common.utils import Const -from msprobe.mindspore.service import Service -from msprobe.core.common.exceptions import MsprobeException -from msprobe.core.common_config import CommonConfig, BaseConfig -from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell -from collections import defaultdict -from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService -from mindspore.common.tensor import Tensor - - -class DummyModel(nn.Cell): - def __init__(self): - super(DummyModel, self).__init__() - self.dense = nn.Dense(2, 2) - - def construct(self, x): - return self.dense(x) - - -class TestService(unittest.TestCase): - @patch("msprobe.mindspore.debugger.debugger_config.create_directory") - def setUp(self, _): - json_config = { - "task": "statistics", - "dump_path": "/absolute_path", - "rank": [], - "step": [0, 2], - "level": "L1" - } - - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - config = DebuggerConfig(common_config, task_config) - 
self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - - def test_check_model_valid_none(self): - model = None - self.assertIsNone(self.service.check_model_valid(model)) - - def test_check_model_valid_valid_model(self): - model = DummyModel() - self.assertEqual(self.service.check_model_valid(model), model) - - - - - - def test_step_updates_iteration(self): - initial_iter = self.service.current_iter - self.service.step() - self.assertEqual(self.service.current_iter, initial_iter + 1) - - - - def test_step_calls_update_iter(self): - # 检查是否在调用 step 时调用了 update_iter - with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter: - initial_iter = self.service.current_iter - self.service.step() - mock_update_iter.assert_called_once_with(initial_iter + 1) - - -class TestPrimitiveHookService(unittest.TestCase): - def setUp(self): - # 创建一个临时目录作为 dump_path - self.temp_dir = tempfile.TemporaryDirectory() - dump_path = self.temp_dir.name - json_config = { - "task": "statistics", - "dump_path": dump_path, - "rank": [], - "step": [0, 2], - "level": "L1" - } - - common_config = CommonConfig(json_config) - task_config = BaseConfig(json_config) - config = DebuggerConfig(common_config, task_config) - self.service = Service(config) - self.service.model = Mock() - self.service.data_collector = Mock() - self.service.switch = True # Make sure the switch is on for testing - - # 模拟一个 service_instance 和 data_collector - self.mock_service_instance = Service(config) - self.mock_service_instance.switch = True - self.mock_service_instance.data_collector = Mock() - self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"] - - # 初始化 PrimitiveHookService - self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance) - - def tearDown(self): - # 测试结束时删除临时目录 - self.temp_dir.cleanup() - - def 
test_update_primitive_counters_multiple(self): - # 测试更新 primitive 计数器的功能,增加多个不同名称的测试 - primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"] - - for name in primitive_names: - for i in range(3): - self.primitive_hook_service.update_primitive_counters(name) - self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_forward_hook_various_inputs(self, mock_hook_backward): - # 测试不同形状和大小的 Tensor 输入 - input_tensors = [ - Tensor(np.random.randn(2, 2).astype(np.float32)), - Tensor(np.random.randn(4, 4).astype(np.float32)), - Tensor(np.random.randn(10, 10).astype(np.float32)), - ] - - for input_tensor in input_tensors: - mock_origin_func = Mock(return_value=input_tensor) - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - result = wrapped_func(Mock(), input_tensor) - - mock_origin_func.assert_called_once() - mock_hook_backward.assert_called() - self.assertIsInstance(result, Mock) - - def test_wrap_primitive_no_hook_with_invalid_input(self): - # 测试在 switch 关闭时传入无效输入时的行为 - self.mock_service_instance.switch = False - - invalid_inputs = [None, "invalid_tensor", 123] - - for invalid_input in invalid_inputs: - mock_origin_func = Mock(return_value=invalid_input) - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - result = wrapped_func(Mock(), invalid_input) - mock_origin_func.assert_called_once() - self.assertEqual(result, invalid_input) - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_with_multiple_hooks(self, mock_hook_backward): - # 测试多个钩子函数同时应用的行为 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - # 模拟多个 primitive - primitive_names = ["MatMul", "Add", "Sub"] - - for name in primitive_names: - mock_origin_func = Mock(return_value=input_tensor) - wrapped_func = 
self.primitive_hook_service.wrap_primitive(mock_origin_func, name) - result = wrapped_func(Mock(), input_tensor) - - mock_origin_func.assert_called_once() - mock_hook_backward.assert_called() - self.assertIsInstance(result, Mock) - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_with_exception_handling_multiple(self, mock_hook_backward): - # 模拟多个异常情况并确保它们被正确捕获 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - exception_messages = ["Invalid operation", "Null reference", "Type error"] - - for exception_message in exception_messages: - mock_origin_func = Mock(side_effect=Exception(exception_message)) - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - with self.assertRaises(Exception) as context: - wrapped_func(Mock(), input_tensor) - self.assertIn(exception_message, str(context.exception)) - - def test_create_backward_hook_multiple(self): - # 测试创建多个 backward 钩子并模拟不同数量的梯度捕获 - captured_grads_sets = [[Mock()], [Mock(), Mock()], [Mock(), Mock(), Mock()]] - - for captured_grads in captured_grads_sets: - updated_primitive_name = "MatMul.Backward" - num_tensors = len(captured_grads) - hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") - - backward_hook = hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) - self.assertIsNotNone(backward_hook) - - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward): - # 模拟前向和后向钩子在同一个 primitive 中的行为 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - mock_origin_func = Mock(return_value=input_tensor) - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "Conv2D") - - result = wrapped_func(Mock(), input_tensor) - - # 确保前向和后向 hook 均被调用 - mock_origin_func.assert_called_once() - mock_hook_backward.assert_called() - - self.assertIsInstance(result, Mock) - - 
def test_update_primitive_counters_different_names(self): - # 测试不同 primitive 名称的计数器更新 - primitive_names = ["MatMul", "Add", "Sub", "Mul", "Conv2D"] - - for name in primitive_names: - for i in range(5): - self.primitive_hook_service.update_primitive_counters(name) - self.assertEqual(self.primitive_hook_service.primitive_counters[name], i) - - - - - def test_update_primitive_counters(self): - primitive_name = "MatMul" - self.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 0) - self.primitive_hook_service.update_primitive_counters(primitive_name) - self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 1) - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_forward_hook(self, mock_hook_backward): - # 模拟一个 Tensor 输入 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - # 模拟原始函数 - mock_origin_func = Mock(return_value=input_tensor) - - # 包装原始 primitive - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - # 调用包装后的 primitive - result = wrapped_func(Mock(), input_tensor) - - # 确保原始函数被调用 - mock_origin_func.assert_called_once() - - # 检查返回值是否是 Mock 实例 - self.assertIsInstance(result, Mock) - - # 确保 HookBackward 被应用 - mock_hook_backward.assert_called() - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_backward_hook(self, mock_hook_backward): - # 模拟 Tensor 输入和输出 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - grad_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - # 确保 HookBackward 返回一个可调用对象,该对象返回 Tensor - mock_hook_backward.return_value = lambda x: grad_tensor - - # 模拟原始函数 - mock_origin_func = Mock(return_value=input_tensor) - - # 包装 primitive - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - # 模拟反向传播过程,调用包装的 primitive - 
with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect') as mock_backward_collect: - result = wrapped_func(Mock(), input_tensor) - - # 验证结果是 Tensor 实例 - self.assertIsInstance(result, Mock) - - def test_wrap_primitive_no_hook_when_switch_off(self): - # 模拟 switch 关闭的情况 - self.mock_service_instance.switch = False - - # 模拟 Tensor 输入 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - # 模拟原始函数 - mock_origin_func = Mock(return_value=input_tensor) - - # 包装 primitive - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - # 调用包装后的 primitive - result = wrapped_func(Mock(), input_tensor) - - # 确保在 switch 关闭时不应用 hook - mock_origin_func.assert_called_once() - self.assertTrue((result == input_tensor).all()) # 使用 .all() 来比较 Tensor - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_wrap_primitive_error_handling(self, mock_hook_backward): - # 模拟 Tensor 输入 - input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32)) - - # 模拟抛出异常的原始函数 - mock_origin_func = Mock(side_effect=Exception("Mocked exception")) - - # 包装 primitive - wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul") - - # 验证是否正确捕获异常 - with self.assertRaises(Exception) as context: - wrapped_func(Mock(), input_tensor) - self.assertIn("Mocked exception", str(context.exception)) - - @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward') - def test_create_backward_hook(self, mock_hook_backward): - # 测试 create_backward_hook 的功能 - captured_grads = [] - updated_primitive_name = "MatMul.Backward" - num_tensors = 2 - - # 创建 backward hook - backward_hook = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul") - hook = backward_hook(Mock(), captured_grads, updated_primitive_name, Const.INPUT) - - # 确保 hook 被创建并可调用 - self.assertIsNotNone(hook) -- Gitee From 26b9b654804b7918ff2aad4667702379e9ba6ebf Mon Sep 17 00:00:00 2001 From: yangxinxian 
<947098055@qq.com>
Date: Sun, 29 Sep 2024 14:46:30 +0800
Subject: [PATCH 7/7] Update test_primitive_dump.py

---
 .../test/mindspore_ut/test_primitive_dump.py  | 339 ++++++++++++++++++
 1 file changed, 339 insertions(+)

diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
index e69de29bb2d..22aaa8f0e13 100644
--- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
+++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py
@@ -0,0 +1,339 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2024-2024. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+import unittest
+import numpy as np
+from unittest.mock import Mock, patch
+
+from mindspore import nn
+
+import tempfile
+from msprobe.core.common.utils import Const
+from msprobe.mindspore.service import Service
+from msprobe.core.common.exceptions import MsprobeException
+from msprobe.core.common_config import CommonConfig, BaseConfig
+from msprobe.mindspore.debugger.debugger_config import DebuggerConfig
+from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
+from collections import defaultdict
+from msprobe.mindspore.dump.hook_cell.primitive_hooks import PrimitiveHookService
+from mindspore.common.tensor import Tensor
+
+
+class DummyModel(nn.Cell):
+    """Minimal Cell with a single Dense(2, 2) layer, used as a valid model in the tests."""
+    def __init__(self):
+        super().__init__()
+        self.dense = nn.Dense(2, 2)
+
+    def construct(self, x):
+        return self.dense(x)
+
+
+class TestService(unittest.TestCase):
+    """Tests for Service.check_model_valid and Service.step."""
+    @patch("msprobe.mindspore.debugger.debugger_config.create_directory")
+    def setUp(self, _):
+        # create_directory is patched out; presumably so "/absolute_path" is never created on disk.
+        json_config = {
+            "task": "statistics",
+            "dump_path": "/absolute_path",
+            "rank": [],
+            "step": [0, 2],
+            "level": "L1"
+        }
+
+        common_config = CommonConfig(json_config)
+        task_config = BaseConfig(json_config)
+        config = DebuggerConfig(common_config, task_config)
+        self.service = Service(config)
+        self.service.model = Mock()
+        self.service.data_collector = Mock()
+        self.service.switch = True  # Make sure the switch is on for testing
+
+    def test_check_model_valid_none(self):
+        # check_model_valid(None) is expected to return None.
+        model = None
+        self.assertIsNone(self.service.check_model_valid(model))
+
+    def test_check_model_valid_valid_model(self):
+        # A Cell instance is expected to be returned as is.
+        model = DummyModel()
+        self.assertEqual(self.service.check_model_valid(model), model)
+
+    def test_step_updates_iteration(self):
+        # step() is expected to increase current_iter by exactly one.
+        initial_iter = self.service.current_iter
+        self.service.step()
+        self.assertEqual(self.service.current_iter, initial_iter + 1)
+
+    def test_step_calls_update_iter(self):
+        # step() must call data_collector.update_iter once, with the incremented iteration.
+        with patch.object(self.service.data_collector, 'update_iter') as mock_update_iter:
+            initial_iter = self.service.current_iter
+            self.service.step()
+            mock_update_iter.assert_called_once_with(initial_iter + 1)
+
+
+class TestPrimitiveHookService(unittest.TestCase):
+    """Tests for PrimitiveHookService.update_primitive_counters and wrap_primitive."""
+    def setUp(self):
+        # Use a temporary directory as dump_path.
+        self.temp_dir = tempfile.TemporaryDirectory()
+        dump_path = self.temp_dir.name
+        json_config = {
+            "task": "statistics",
+            "dump_path": dump_path,
+            "rank": [],
+            "step": [0, 2],
+            "level": "L1"
+        }
+
+        common_config = CommonConfig(json_config)
+        task_config = BaseConfig(json_config)
+        config = DebuggerConfig(common_config, task_config)
+        # NOTE(review): self.service is not referenced by any test in this class.
+        self.service = Service(config)
+        self.service.model = Mock()
+        self.service.data_collector = Mock()
+        self.service.switch = True  # Make sure the switch is on for testing
+
+        # Service instance with a mocked data_collector, handed to the hook service.
+        self.mock_service_instance = Service(config)
+        self.mock_service_instance.switch = True
+        self.mock_service_instance.data_collector = Mock()
+        self.mock_service_instance.data_collector.dump_file_path = json_config["dump_path"]
+
+        # Build the PrimitiveHookService under test.
+        self.primitive_hook_service = PrimitiveHookService(self.mock_service_instance)
+
+    def tearDown(self):
+        # Remove the temporary directory once the test has finished.
+        self.temp_dir.cleanup()
+
+    def test_update_primitive_counters_multiple(self):
+        # For every primitive name the counter is expected to be 0 after the first update, then 1, then 2.
+        primitive_names = ["MatMul", "Conv2D", "ReLU", "Softmax"]
+
+        for name in primitive_names:
+            for i in range(3):
+                self.primitive_hook_service.update_primitive_counters(name)
+                self.assertEqual(self.primitive_hook_service.primitive_counters[name], i)
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_forward_hook_various_inputs(self, mock_hook_backward):
+        # Wrap and call the primitive with tensors of several shapes.
+        input_tensors = [
+            Tensor(np.random.randn(2, 2).astype(np.float32)),
+            Tensor(np.random.randn(4, 4).astype(np.float32)),
+            Tensor(np.random.randn(10, 10).astype(np.float32)),
+        ]
+
+        for input_tensor in input_tensors:
+            mock_origin_func = Mock(return_value=input_tensor)
+            wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+            result = wrapped_func(Mock(), input_tensor)
+
+            mock_origin_func.assert_called_once()
+            mock_hook_backward.assert_called()
+            self.assertIsInstance(result, Mock)
+
+    def test_wrap_primitive_no_hook_with_invalid_input(self):
+        # With the switch off, non-tensor inputs are expected to come back unchanged.
+        self.mock_service_instance.switch = False
+
+        invalid_inputs = [None, "invalid_tensor", 123]
+
+        for invalid_input in invalid_inputs:
+            mock_origin_func = Mock(return_value=invalid_input)
+            wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+            result = wrapped_func(Mock(), invalid_input)
+            mock_origin_func.assert_called_once()
+            self.assertEqual(result, invalid_input)
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_with_multiple_hooks(self, mock_hook_backward):
+        # Wrap several differently named primitives in turn with the same input.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        # Primitive names to wrap.
+        primitive_names = ["MatMul", "Add", "Sub"]
+
+        for name in primitive_names:
+            mock_origin_func = Mock(return_value=input_tensor)
+            wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, name)
+            result = wrapped_func(Mock(), input_tensor)
+
+            mock_origin_func.assert_called_once()
+            mock_hook_backward.assert_called()
+            self.assertIsInstance(result, Mock)
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_with_exception_handling_multiple(self, mock_hook_backward):
+        # An exception raised by the original primitive must reach the caller with its message.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        exception_messages = ["Invalid operation", "Null reference", "Type error"]
+
+        for exception_message in exception_messages:
+            mock_origin_func = Mock(side_effect=Exception(exception_message))
+            wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+            with self.assertRaises(Exception) as context:
+                wrapped_func(Mock(), input_tensor)
+            self.assertIn(exception_message, str(context.exception))
+
+    def test_create_backward_hook_multiple(self):
+        # Call the wrapped primitive with lists holding one, two and three mock gradients.
+        captured_grads_sets = [[Mock()], [Mock(), Mock()], [Mock(), Mock(), Mock()]]
+
+        for captured_grads in captured_grads_sets:
+            updated_primitive_name = "MatMul.Backward"
+            # NOTE(review): per its docstring wrap_primitive returns the wrapped primitive, so the call
+            # below passes these values as primitive arguments; create_backward_hook is not called directly.
+            wrapped_func = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul")
+
+            result = wrapped_func(Mock(), captured_grads, updated_primitive_name, Const.INPUT)
+            self.assertIsNotNone(result)
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_forward_and_backward_hooks(self, mock_hook_backward):
+        # Forward call of a wrapped "Conv2D" primitive with HookBackward mocked.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        mock_origin_func = Mock(return_value=input_tensor)
+        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "Conv2D")
+
+        result = wrapped_func(Mock(), input_tensor)
+
+        # The original primitive ran once and HookBackward was used.
+        mock_origin_func.assert_called_once()
+        mock_hook_backward.assert_called()
+
+        self.assertIsInstance(result, Mock)
+
+    def test_update_primitive_counters_different_names(self):
+        # Counters are kept per primitive name: each name counts 0..4 independently.
+        primitive_names = ["MatMul", "Add", "Sub", "Mul", "Conv2D"]
+
+        for name in primitive_names:
+            for i in range(5):
+                self.primitive_hook_service.update_primitive_counters(name)
+                self.assertEqual(self.primitive_hook_service.primitive_counters[name], i)
+
+    def test_update_primitive_counters(self):
+        # The first update is expected to set the counter to 0, the second to 1.
+        primitive_name = "MatMul"
+        self.primitive_hook_service.update_primitive_counters(primitive_name)
+        self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 0)
+        self.primitive_hook_service.update_primitive_counters(primitive_name)
+        self.assertEqual(self.primitive_hook_service.primitive_counters[primitive_name], 1)
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_forward_hook(self, mock_hook_backward):
+        # Tensor input for the wrapped primitive.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        # Stand-in for the original primitive.
+        mock_origin_func = Mock(return_value=input_tensor)
+
+        # Wrap the original primitive.
+        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+        # Call the wrapped primitive.
+        result = wrapped_func(Mock(), input_tensor)
+
+        # The original primitive must have been called exactly once.
+        mock_origin_func.assert_called_once()
+
+        # With HookBackward mocked, the returned value is a Mock.
+        self.assertIsInstance(result, Mock)
+
+        # HookBackward must have been used.
+        mock_hook_backward.assert_called()
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_backward_hook(self, mock_hook_backward):
+        # Tensors used as the primitive input and as the value returned by the mocked hook.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+        grad_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        # Make HookBackward return a callable that yields grad_tensor.
+        mock_hook_backward.return_value = lambda x: grad_tensor
+
+        # Stand-in for the original primitive.
+        mock_origin_func = Mock(return_value=input_tensor)
+
+        # Wrap the primitive.
+        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+        # Call the wrapped primitive while data_collector.backward_data_collect is patched.
+        with patch.object(self.mock_service_instance.data_collector, 'backward_data_collect'):
+            result = wrapped_func(Mock(), input_tensor)
+
+        # NOTE(review): the assertion checks for a Mock, not a Tensor; confirm the intended type.
+        self.assertIsInstance(result, Mock)
+
+    def test_wrap_primitive_no_hook_when_switch_off(self):
+        # Turn the service switch off.
+        self.mock_service_instance.switch = False
+
+        # Tensor input for the wrapped primitive.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        # Stand-in for the original primitive.
+        mock_origin_func = Mock(return_value=input_tensor)
+
+        # Wrap the primitive.
+        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+        # Call the wrapped primitive.
+        result = wrapped_func(Mock(), input_tensor)
+
+        # With the switch off the original output is expected back unchanged.
+        mock_origin_func.assert_called_once()
+        self.assertTrue((result == input_tensor).all())  # element-wise comparison reduced with .all()
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_wrap_primitive_error_handling(self, mock_hook_backward):
+        # Tensor input for the wrapped primitive.
+        input_tensor = Tensor(np.random.randn(2, 2).astype(np.float32))
+
+        # Stand-in for the original primitive that raises when called.
+        mock_origin_func = Mock(side_effect=Exception("Mocked exception"))
+
+        # Wrap the primitive.
+        wrapped_func = self.primitive_hook_service.wrap_primitive(mock_origin_func, "MatMul")
+
+        # The exception must reach the caller and its text must contain the original message.
+        with self.assertRaises(Exception) as context:
+            wrapped_func(Mock(), input_tensor)
+        self.assertIn("Mocked exception", str(context.exception))
+
+    @patch('msprobe.mindspore.dump.hook_cell.primitive_hooks.ops.HookBackward')
+    def test_create_backward_hook(self, mock_hook_backward):
+        # Call the wrapped primitive with hook-style arguments and check that it returns a value.
+        captured_grads = []
+        updated_primitive_name = "MatMul.Backward"
+        # NOTE(review): this passes the values as primitive arguments; create_backward_hook is not called directly.
+
+        # Per its docstring, wrap_primitive returns the wrapped primitive function.
+        wrapped_func = self.primitive_hook_service.wrap_primitive(Mock(), "MatMul")
+        result = wrapped_func(Mock(), captured_grads, updated_primitive_name, Const.INPUT)
+
+        # Only a non-None return value is asserted.
+        self.assertIsNotNone(result)
-- 
Gitee