diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py index 2bf9234e6392c760665533e223c61ec2c1a70bc5..96c366446de2176e54982fa05bf1d088f31d859f 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_primitive_dump.py @@ -17,6 +17,7 @@ import unittest import mindspore as ms import numpy as np +import os from unittest.mock import Mock, patch from mindspore import nn @@ -161,17 +162,21 @@ class TestPrimitiveHookService(unittest.TestCase): # 调用 backward_hook backward_hook(grad_1) - backward_hook(grad_2) print(f"1After first backward_hook call, len(captured_grads): {len(captured_grads)}") + # 确保 captured_grads 列表在第一次调用后仍然未满 + self.assertEqual(len(captured_grads), 1) # 只捕获了一个梯度 + + backward_hook(grad_2) # 验证 data_collector 的调用 - self.mock_service_instance.data_collector.update_api_or_module_name.assert_called_once_with( - f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" + self.mock_service_instance.data_collector.backward_output_data_collect.assert_called_with( + f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}", + self.mock_service_instance, + os.getpid(), + any() # 这里用 any() 允许任意类型的输出 ) - self.mock_service_instance.data_collector.backward_output_data_collect.assert_called_once() - # 确保梯度列表在捕获后被清除 - self.assertEqual(len(captured_grads), 0) + self.assertEqual(len(captured_grads), 0) # 所有捕获的梯度都应该被清除 def test_four_input_backward_hook(self): # 模拟梯度输入 @@ -201,11 +206,12 @@ class TestPrimitiveHookService(unittest.TestCase): print(f"1After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 验证 data_collector 的调用 - self.mock_service_instance.data_collector.update_api_or_module_name.assert_called_once_with( - f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" + self.mock_service_instance.data_collector.backward_output_data_collect.assert_called_with( + f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}", + self.mock_service_instance, + os.getpid(), + any() # 这里用 any() 允许任意类型的输出 ) - self.mock_service_instance.data_collector.backward_output_data_collect.assert_called_once() - # 确保梯度列表在捕获后被清除 self.assertEqual(len(captured_grads), 0) @@ -239,11 +245,12 @@ class TestPrimitiveHookService(unittest.TestCase): print(f"After first backward_hook call, len(captured_grads): {len(captured_grads)}") # 验证 data_collector 的调用 - self.mock_service_instance.data_collector.update_api_or_module_name.assert_called_once_with( - f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" + self.mock_service_instance.data_collector.backward_output_data_collect.assert_called_with( + f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}", + self.mock_service_instance, + os.getpid(), + any() # 这里用 any() 允许任意类型的输出 ) - self.mock_service_instance.data_collector.backward_input_data_collect.assert_called_once() - # 确保梯度列表在捕获后被清除 self.assertEqual(len(captured_grads), 0) @@ -273,11 +280,12 @@ class TestPrimitiveHookService(unittest.TestCase): backward_hook(grad_4) # 验证 data_collector 的调用 - self.mock_service_instance.data_collector.update_api_or_module_name.assert_called_once_with( - f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}" + self.mock_service_instance.data_collector.backward_output_data_collect.assert_called_with( + f"{updated_primitive_name}{Const.SEP}{Const.BACKWARD}", + self.mock_service_instance, + os.getpid(), + any() # 这里用 any() 允许任意类型的输出 ) - self.mock_service_instance.data_collector.backward_input_data_collect.assert_called_once() - # 确保梯度列表在捕获后被清除 self.assertEqual(len(captured_grads), 0)