"""Unit tests for ``msprobe.mindspore.monitor.distributed.wrap_distributed``."""
import unittest
import os
import inspect
import re
from unittest.mock import patch, MagicMock
import numpy as np

import mindspore as ms
from mindspore import nn, Tensor, ops
from mindspore.communication import comm_func

# Code under test
from msprobe.mindspore.monitor.distributed.wrap_distributed import (
    DistributedOPTemplate,
    ApiRegistry,
    get_distributed_ops,
    get_process_group,
    stack_filter,
    get_callstack,
    op_aggregate,
    update_data,
    is_target_line,
    catch_data,
    create_async_callback_func,
    create_hooks,
    api_register
)

# Dotted path of the module under test; used to build ``patch`` targets.
WRAP_DISTRIBUTED = 'msprobe.mindspore.monitor.distributed.wrap_distributed'


class TestDistributedWrapper(unittest.TestCase):
    """Tests for the helpers and API registry exposed by ``wrap_distributed``."""

    def setUp(self):
        """Build shared fixtures and temporarily stub two ``comm_func`` attributes.

        The stubs are installed with ``patch.object`` and undone through
        ``addCleanup``, so the real ``comm_func`` attributes are restored (or
        removed again, if they did not exist) after every test, even when the
        test fails.
        """
        self.test_ops = ['allreduce', 'broadcast']
        self.test_rank = 0
        self.test_tensor = Tensor(np.array([1.0, 2.0, 3.0]), dtype=ms.float32)
        self.test_tensor_list = [self.test_tensor, self.test_tensor]

        # Stand-in communication functions that simply echo their input.
        self.mock_comm_func = MagicMock()
        self.mock_comm_func.allreduce = lambda x: x
        self.mock_comm_func.broadcast = lambda x, root: x

        # Stand-in for a communication handle.
        self.mock_comm_handle = MagicMock()

        # Keeping a reference to the ``comm_func`` module is NOT a snapshot of
        # its attributes, so restoring "from the module" would be a no-op.
        # ``patch.object`` records the real previous value instead.
        # ``create=True`` tolerates attributes that are absent on the module.
        for name in ('allreduce', 'broadcast'):
            patcher = patch.object(
                comm_func, name, getattr(self.mock_comm_func, name), create=True
            )
            patcher.start()
            self.addCleanup(patcher.stop)

    def test_DistributedOPTemplate_construct(self):
        """``DistributedOPTemplate.construct`` returns a tensor result."""
        def pre_hook(cell, inputs):
            return

        def post_hook(cell, inputs, outputs):
            return

        op_template = DistributedOPTemplate('all_reduce', [pre_hook], [post_hook])
        result = op_template.construct(self.test_tensor)

        # NOTE(review): depending on the MindSpore version, all_reduce returns
        # either a Tensor or a (Tensor, CommHandle) tuple - accept both forms.
        output = result[0] if isinstance(result, tuple) else result
        self.assertIsInstance(output, Tensor)

    def test_get_distributed_ops(self):
        """``get_distributed_ops`` reports the ops listed in ``WrapDistributedOps``."""
        with patch(f'{WRAP_DISTRIBUTED}.WrapDistributedOps', ['allreduce', 'broadcast']):
            # Named ``distributed_ops`` to avoid shadowing the imported ``ops`` module.
            distributed_ops = get_distributed_ops()
            self.assertIn('allreduce', distributed_ops)
            self.assertIn('broadcast', distributed_ops)

    def test_get_process_group(self):
        """A custom group passed to ``get_process_group`` is returned unchanged."""
        custom_group = "custom_group"
        result = get_process_group(custom_group)
        self.assertEqual(result, custom_group)

    def test_stack_filter(self):
        """``stack_filter`` accepts ordinary frames and rejects blacklisted ones."""
        # A stack entry that should be accepted.
        valid_stack = "path/to/file.py[123] function_name"
        self.assertTrue(stack_filter(valid_stack))

        # A stack entry matching the (patched) blacklist.
        with patch(f'{WRAP_DISTRIBUTED}.StackBlackList', ['blacklisted']):
            invalid_stack = "blacklisted/path.py[456] bad_function"
            self.assertFalse(stack_filter(invalid_stack))

    def test_get_callstack(self):
        """``get_callstack`` formats entries as ``<file>[<lineno>]``."""
        # Fake record in the tuple layout returned by ``inspect.stack()``.
        test_frame = (
            None,             # frame
            "test_file.py",   # filename
            123,              # lineno
            "test_function",  # function
            None,             # code_context
            None              # index
        )

        with patch('inspect.stack', return_value=[test_frame]):
            callstack = get_callstack()
            self.assertEqual(len(callstack), 1)
            self.assertIn("test_file.py[123]", callstack[0])

    def test_update_data(self):
        """``update_data`` creates per-op collections and appends on repeat calls."""
        old_data = {}
        new_data = {
            'tag1': {'op1': Tensor(1.0)},
            'tag2': {'op2': Tensor(2.0)}
        }

        updated = update_data(old_data, new_data)
        self.assertEqual(len(updated), 2)
        self.assertEqual(len(updated['tag1']['op1']), 1)
        self.assertEqual(len(updated['tag2']['op2']), 1)

        # Feeding the same data again must append rather than overwrite.
        updated = update_data(updated, new_data)
        self.assertEqual(len(updated['tag1']['op1']), 2)

    def test_is_target_line(self):
        """``is_target_line`` matches patterns against the current call stack."""
        # Simulated call stack.
        mock_stack = [
            "file1.py[123] func1",
            "target_file.py[456] target_func",
            "file2.py[789] func2"
        ]

        with patch(f'{WRAP_DISTRIBUTED}.get_callstack', return_value=mock_stack):
            # A pattern present in the stack matches.
            self.assertTrue(is_target_line(['target_file']))

            # A pattern absent from the stack does not match.
            self.assertFalse(is_target_line(['nonexistent']))

            # An empty pattern list matches everything.
            self.assertTrue(is_target_line([]))

    def test_catch_data(self):
        """``catch_data`` stores tensor data under ``<op>/<prefix>_<index>``."""
        mock_context = MagicMock()
        mock_context.data = {}

        ops_list = []
        args = [self.test_tensor]

        catch_data(mock_context, 'test_op', ops_list, args, 'prefix')

        self.assertGreater(len(mock_context.data), 0)
        self.assertIn('test_op/prefix_0', mock_context.data)

    def test_create_async_callback_func(self):
        """The callback from ``create_async_callback_func`` records data when called."""
        mock_context = MagicMock()
        mock_context.data = {}

        callback = create_async_callback_func(
            mock_context, 'test_op', [], [self.test_tensor], 'prefix'
        )
        callback()

        self.assertGreater(len(mock_context.data), 0)

    def test_create_hooks(self):
        """``create_hooks`` yields one pre-hook and one post-hook for this config."""
        mock_monitor = MagicMock()
        mock_monitor.cc_log_only = False
        mock_monitor.cc_pre_hook = True
        mock_monitor.cc_codeline = []
        mock_monitor.ops = ['mean']
        mock_monitor.module_rank_list = [0]
        mock_monitor.cc_logged_stack = MagicMock()

        with patch(f'{WRAP_DISTRIBUTED}.get_rank', return_value=0):
            pre_hooks, hooks = create_hooks({}, mock_monitor)

        self.assertEqual(len(pre_hooks), 1)
        self.assertEqual(len(hooks), 1)

    def test_ApiRegistry_redirect(self):
        """``api_register.redirect_api`` swaps ``comm_func.all_reduce`` for the hook."""
        # Remember the real API and guarantee it is put back after the test,
        # so the redirected MagicMock cannot leak into other tests.
        original_allreduce = comm_func.all_reduce
        self.addCleanup(setattr, comm_func, 'all_reduce', original_allreduce)

        # ``patch.dict`` reverts the shared registry dictionaries on exit.
        origin_patch = patch.dict(
            api_register.distributed_attr_origin, {'all_reduce': original_allreduce}
        )
        hooked_patch = patch.dict(
            api_register.distributed_attr_hooked, {'all_reduce': MagicMock()}
        )
        with origin_patch, hooked_patch:
            api_register.redirect_api()

            self.assertNotEqual(comm_func.all_reduce, original_allreduce)


if __name__ == '__main__':
    unittest.main()