diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/common_func/test_time_range_calculator.py b/profiler/msprof_analyze/test/ut/cluster_analyse/common_func/test_time_range_calculator.py new file mode 100644 index 0000000000000000000000000000000000000000..d5d34e95e88f02a42ea34a4d1996acfae4ae1c36 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/common_func/test_time_range_calculator.py @@ -0,0 +1,100 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from msprof_analyze.cluster_analyse.common_func.time_range_calculator import RangeCaculator, TimeRange, \ + CommunicationTimeRange + + +class TestTimeRangeCalculator(unittest.TestCase): + def test_time_range_initialization(self): + time_range = TimeRange() + self.assertEqual(time_range.start_ts, -1) + self.assertEqual(time_range.end_ts, -1) + + custom_range = TimeRange(10, 20) + self.assertEqual(custom_range.start_ts, 10) + self.assertEqual(custom_range.end_ts, 20) + + def test_communication_time_range_initialization(self): + comm_range = CommunicationTimeRange() + self.assertEqual(comm_range.start_ts, -1) + self.assertEqual(comm_range.end_ts, -1) + + def test_generate_time_range(self): + time_range = RangeCaculator.generate_time_range(10, 20) + self.assertEqual(time_range.start_ts, 10) + self.assertEqual(time_range.end_ts, 20) + self.assertIsInstance(time_range, TimeRange) + + comm_range = RangeCaculator.generate_time_range(10, 20, CommunicationTimeRange) + self.assertEqual(comm_range.start_ts, 10) + self.assertEqual(comm_range.end_ts, 20) + self.assertIsInstance(comm_range, CommunicationTimeRange) + + def test_merge_continuous_intervals(self): + # 测试空列表 + self.assertEqual(RangeCaculator.merge_continuous_intervals([]), []) + + # 测试无重叠区间 + range1 = TimeRange(1, 2) + range2 = TimeRange(3, 4) + self.assertEqual(RangeCaculator.merge_continuous_intervals([range1, range2]), [range1, range2]) + + # 测试有重叠区间 + range3 = TimeRange(1, 3) + range4 = TimeRange(2, 4) + merged = RangeCaculator.generate_time_range(1, 4) + self.assertEqual(RangeCaculator.merge_continuous_intervals([range3, range4]), [merged]) + + # 测试有包含关系的区间 + range5 = TimeRange(1, 5) + range6 = TimeRange(2, 3) + self.assertEqual(RangeCaculator.merge_continuous_intervals([range5, range6]), [range5]) + + def test_compute_pipeline_overlap(self): + # 测试空列表 + pure_comm, free_time = RangeCaculator.compute_pipeline_overlap([], []) + self.assertEqual(pure_comm, []) + self.assertEqual(free_time, []) + + # 测试无重叠区间 + comm_range1 = CommunicationTimeRange() + comm_range1.start_ts, comm_range1.end_ts = 1, 2 + compute_range1 = TimeRange() + compute_range1.start_ts, compute_range1.end_ts = 3, 4 + pure_comm, free_time = RangeCaculator.compute_pipeline_overlap([comm_range1], [compute_range1]) + self.assertEqual(len(pure_comm), 1) + self.assertEqual(pure_comm[0].start_ts, 1) + self.assertEqual(pure_comm[0].end_ts, 2) + self.assertEqual(len(free_time), 1) + self.assertEqual(free_time[0].start_ts, 2) + self.assertEqual(free_time[0].end_ts, 3) + + # 测试有重叠区间 + comm_range2 = CommunicationTimeRange() + comm_range2.start_ts, comm_range2.end_ts = 1, 3 + compute_range2 = TimeRange() + compute_range2.start_ts, compute_range2.end_ts = 2, 4 + pure_comm, free_time = RangeCaculator.compute_pipeline_overlap([comm_range2], [compute_range2]) + self.assertEqual(len(pure_comm), 1) + self.assertEqual(pure_comm[0].start_ts, 1) + self.assertEqual(pure_comm[0].end_ts, 2) + self.assertEqual(len(free_time), 0) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_communication_db_group.py b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_communication_db_group.py new file mode 100644 index 0000000000000000000000000000000000000000..27540643e530ea79907b648e53d45b625b159148 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_communication_db_group.py @@ -0,0 +1,268 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch, MagicMock + +from msprof_analyze.cluster_analyse.communication_group.communication_db_group import get_communication_data, \ + dump_group_db, CommunicationDBGroupOptimized +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.logger import get_logger + +logger = get_logger() + + +class TestGetCommunicationData(unittest.TestCase): + @patch('os.path.exists') + def test_get_communication_data_db_not_exist(self, mock_exists): + # 模拟数据库路径不存在的情况 + mock_exists.return_value = False + rank_id = '0' + db_path = '/path/to/db' + analysis_mode = Constant.ALL + + result = get_communication_data(rank_id, db_path, analysis_mode) + self.assertEqual(result, ([], [], [])) + + @patch('os.path.exists') + @patch.object(DBManager, 'create_connect_db') + @patch.object(DBManager, 'check_tables_in_db') + @patch.object(DBManager, 'fetch_all_data') + @patch.object(DBManager, 'destroy_db_connect') + def test_get_communication_data_with_all_mode(self, mock_destroy, mock_fetch, mock_check, mock_create, mock_exists): + # 模拟数据库路径存在的情况 + mock_exists.return_value = True + rank_id = '0' + db_path = '/path/to/db' + analysis_mode = Constant.ALL + conn, cursor = MagicMock(), MagicMock() + mock_create.return_value = (conn, cursor) + mock_check.side_effect = [True, True] + mock_fetch.side_effect = [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + + time_data, bandwidth_data, matrix_data = get_communication_data(rank_id, db_path, analysis_mode) + self.assertEqual(time_data, [1, 2, 3]) + self.assertEqual(bandwidth_data, [4, 5, 6]) + self.assertEqual(matrix_data, [7, 8, 9]) + mock_destroy.assert_called_once_with(conn, cursor) + + @patch('os.path.exists') + @patch.object(DBManager, 'create_connect_db') + @patch.object(DBManager, 'check_tables_in_db') + @patch.object(DBManager, 'fetch_all_data') + @patch.object(DBManager, 'destroy_db_connect') + def test_get_communication_data_with_communication_time_mode(self, mock_destroy, mock_fetch, mock_check, + mock_create, mock_exists): + # 模拟只获取通信时间数据的情况 + mock_exists.return_value = True + rank_id = '0' + db_path = '/path/to/db' + analysis_mode = Constant.COMMUNICATION_TIME + conn, cursor = MagicMock(), MagicMock() + mock_create.return_value = (conn, cursor) + mock_check.side_effect = [True, False] + mock_fetch.side_effect = [[1, 2, 3], [4, 5, 6]] + + time_data, bandwidth_data, matrix_data = get_communication_data(rank_id, db_path, analysis_mode) + self.assertEqual(time_data, [1, 2, 3]) + self.assertEqual(bandwidth_data, [4, 5, 6]) + self.assertEqual(matrix_data, []) + mock_destroy.assert_called_once_with(conn, cursor) + + @patch('os.path.exists') + @patch.object(DBManager, 'create_connect_db') + @patch.object(DBManager, 'check_tables_in_db') + @patch.object(DBManager, 'fetch_all_data') + @patch.object(DBManager, 'destroy_db_connect') + def test_get_communication_data_with_communication_matrix_mode(self, mock_destroy, mock_fetch, mock_check, + mock_create, mock_exists): + # 模拟只获取通信矩阵数据的情况 + mock_exists.return_value = True + rank_id = '0' + db_path = '/path/to/db' + analysis_mode = Constant.COMMUNICATION_MATRIX + conn, cursor = MagicMock(), MagicMock() + mock_create.return_value = (conn, cursor) + mock_check.side_effect = [False, True] + mock_fetch.return_value = [7, 8, 9] + + time_data, bandwidth_data, matrix_data = get_communication_data(rank_id, db_path, analysis_mode) + self.assertEqual(time_data, []) + self.assertEqual(bandwidth_data, []) + self.assertEqual(matrix_data, [7, 8, 9]) + mock_destroy.assert_called_once_with(conn, cursor) + + @patch('os.path.join') + @patch.object(DBManager, 'create_tables') + @patch.object(DBManager, 'create_connect_db') + @patch.object(DBManager, 'executemany_sql') + @patch.object(DBManager, 'destroy_db_connect') + def test_dump_group_db_with_data(self, mock_destroy, mock_executemany, mock_create, mock_create_tables, mock_join): + # 准备测试数据 + dump_data = [[1, 2, 3], [4, 5, 6]] + group_table = 'test_table' + cluster_analysis_output_path = '/path/to/output' + + # 模拟返回值 + output_path = '/path/to/output/CLUSTER_ANALYSIS_OUTPUT' + result_db = '/path/to/output/CLUSTER_ANALYSIS_OUTPUT/DB_CLUSTER_COMMUNICATION_ANALYZER' + mock_join.side_effect = [output_path, result_db] + conn, cursor = MagicMock(), MagicMock() + mock_create.return_value = (conn, cursor) + + # 调用函数 + dump_group_db(dump_data, group_table, cluster_analysis_output_path) + + # 验证函数调用 + mock_create_tables.assert_called_once_with(result_db, group_table) + mock_create.assert_called_once_with(result_db) + sql = "insert into {} values ({})".format(group_table, "?," * (len(dump_data[0]) - 1) + "?") + mock_executemany.assert_called_once_with(conn, sql, dump_data) + mock_destroy.assert_called_once_with(conn, cursor) + + @patch.object(logger, 'warning') + def test_dump_group_db_without_data(self, mock_warning): + # 准备测试数据 + dump_data = [] + group_table = 'test_table' + cluster_analysis_output_path = '/path/to/output' + + # 调用函数 + dump_group_db(dump_data, group_table, cluster_analysis_output_path) + + # 验证警告日志 + mock_warning.assert_called_once_with( + "[WARNING] The CommunicationGroup table won't be created because no data has been calculated.") + + +class TestCommunicationDBGroupOptimized(unittest.TestCase): + def setUp(self): + self.params = {} + self.analyzer = CommunicationDBGroupOptimized(self.params) + self.analyzer.adapter = MagicMock() + self.analyzer.rank_comm_dir_dict = [] + self.analyzer.collective_group_dict = {} + self.analyzer.p2p_group_dict = {} + self.analyzer.communication_ops = [] + self.analyzer.bandwidth_data = [] + self.analyzer.matrix_ops = [] + self.analyzer.communication_group = {} + self.analyzer.comm_group_parallel_info_df = MagicMock() + self.analyzer.cluster_analysis_output_path = 'test_path' + + def test_init(self): + self.assertEqual(self.analyzer.bandwidth_data, []) + self.assertEqual(self.analyzer.matrix_ops, []) + + def test_read_communication_func_insufficient_params(self): + params = (1,) + result = self.analyzer.read_communication_func(params) + self.assertEqual(result, (-1, {}, {})) + + @patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.get_communication_data') + def test_read_communication_func(self, mock_get_communication_data): + mock_get_communication_data.return_value = ([], [], []) + self.analyzer.adapter.transfer_matrix_from_db_to_json.return_value = {} + params = (1, 'db_path', None) + rank_id, comm_time_data, comm_matrix_data = self.analyzer.read_communication_func(params) + self.assertEqual(rank_id, 1) + self.assertEqual(comm_time_data, ([], [])) + self.assertEqual(comm_matrix_data, {}) + mock_get_communication_data.assert_called_once_with(1, 'db_path', None) + self.analyzer.adapter.transfer_matrix_from_db_to_json.assert_called_once_with([]) + + def test_set_group_rank_map_no_group_name(self): + time_data = [{Constant.TYPE: 'type'}] + self.analyzer.set_group_rank_map(1, time_data) + self.assertEqual(self.analyzer.collective_group_dict, {}) + self.assertEqual(self.analyzer.p2p_group_dict, {}) + + def test_set_group_rank_map_collective(self): + self.analyzer.collective_group_dict = {'group': set()} + time_data = [{Constant.TYPE: Constant.COLLECTIVE, Constant.GROUP_NAME: 'group'}] + self.analyzer.set_group_rank_map(1, time_data) + self.assertEqual(self.analyzer.collective_group_dict['group'], {1}) + + def test_set_group_rank_map_p2p(self): + self.analyzer.p2p_group_dict = {'group': set()} + time_data = [{Constant.TYPE: Constant.P2P, Constant.GROUP_NAME: 'group'}] + self.analyzer.set_group_rank_map(1, time_data) + self.assertEqual(self.analyzer.p2p_group_dict['group'], {1}) + + @patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.logger') + def test_analyze_communication_data_time_mode_empty(self, mock_logger): + self.analyzer.analysis_mode = Constant.ALL + self.analyzer.rank_comm_dir_dict = [(1, ([], []), {})] + self.analyzer.analyze_communication_data() + mock_logger.warning.assert_called_once_with('[WARNING] rank %s has error format in time data.', 1) + + @patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.logger') + @patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.' + 'CommunicationDBGroupOptimized.set_group_rank_map') + @patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.' + 'CommunicationDBGroupOptimized._merge_data_with_rank') + def test_analyze_communication_data_matrix_mode_empty(self, mock_merge_data, mock_group_rank, mock_logger): + self.analyzer.analysis_mode = Constant.ALL + self.analyzer.rank_comm_dir_dict = [(1, ([{}], []), None)] + self.analyzer.analyze_communication_data() + mock_logger.warning.assert_any_call('[WARNING] rank %s matrix data is null.', 1) + + def test_analyze_communication_data_invalid_matrix_format(self): + self.analyzer.analysis_mode = Constant.ALL + self.analyzer.rank_comm_dir_dict = [(1, ([{}], []), {1: 'invalid'})] + with patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.logger') as mock_logger: + self.analyzer.analyze_communication_data() + mock_logger.warning.assert_any_call('[WARNING] rank %s has error format in matrix data.', 1) + + def test_generate_collective_communication_group(self): + self.analyzer.collective_group_dict = {'group1': {1, 2}, 'group2': {3}} + self.analyzer.generate_collective_communication_group() + self.assertEqual(self.analyzer.communication_group[Constant.COLLECTIVE], [('group1', [1, 2]), ('group2', [3])]) + + def test_collect_comm_data(self): + self.analyzer.collective_group_dict = {'group': {1}} + self.analyzer.communication_ops = [1] + self.analyzer.bandwidth_data = [2] + self.analyzer.matrix_ops = [3] + self.analyzer.communication_group = {'type': 'data'} + result = self.analyzer.collect_comm_data() + expected = { + Constant.COLLECTIVE_GROUP: {'group': {1}}, + Constant.COMMUNICATION_OPS: ([1], [2]), + Constant.MATRIX_OPS: [3], + Constant.COMMUNICATION_GROUP: {'type': 'data'} + } + self.assertEqual(result, expected) + + @patch('msprof_analyze.cluster_analyse.communication_group.communication_db_group.dump_group_db') + def test_dump_data(self, mock_dump_group_db): + mock_df = MagicMock() + mock_df.values.tolist.return_value = [[1, 2]] + self.analyzer.comm_group_parallel_info_df = mock_df + self.analyzer.dump_data() + mock_df['rank_set'].apply.assert_called_once() + mock_df.values.tolist.assert_called_once() + mock_dump_group_db.assert_called_once_with([[1, 2]], 'CommunicationGroupMapping', 'test_path') + + def test__merge_data_with_rank(self): + data_list = [{'key': 'value'}] + result = self.analyzer._merge_data_with_rank(1, data_list) + expected = [{'key': 'value', Constant.RANK_ID: 1}] + self.assertEqual(result, expected) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_msprof_communication_matrix_adapter.py b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_msprof_communication_matrix_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..596cd98cb7c7745f564d6fb0ff08d24afdbc351e --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/communication_group/test_msprof_communication_matrix_adapter.py @@ -0,0 +1,106 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import patch + +from msprof_analyze.cluster_analyse.communication_group.msprof_communication_matrix_adapter import \ + MsprofCommunicationMatrixAdapter +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.file_manager import FileManager + + +class TestMsprofCommunicationMatrixAdapter(unittest.TestCase): + def setUp(self): + self.file_path = 'test.json' + self.adapter = MsprofCommunicationMatrixAdapter(self.file_path) + + @patch.object(FileManager, 'read_json_file') + def test_generate_comm_matrix_data(self, mock_read_json): + # 准备测试数据 + mock_comm_matrix_data = { + 'hcom_send_op1': {'link1': {'data': 'p2p_data'}}, + 'allreduce_op2': {'link2': {'data': 'collective_data'}}, + 'TOTAL_op3': {'link3': {'data': 'total_data'}} + } + mock_read_json.return_value = mock_comm_matrix_data + + # 模拟 get_comm_type 和 integrate_matrix_data 方法 + with patch.object(self.adapter, 'get_comm_type', ) as mock_get_comm_type, \ + patch.object(self.adapter, 'integrate_matrix_data') as mock_integrate_matrix_data: + mock_get_comm_type.side_effect = [{'p2p_data': []}, {'collective_data': []}] + mock_integrate_matrix_data.side_effect = [{'p2p_result': {}}, {'collective_result': {}}] + result = self.adapter.generate_comm_matrix_data() + + # 验证调用逻辑 + mock_read_json.assert_called_once_with(self.file_path) + self.assertEqual(mock_get_comm_type.call_count, 2) + self.assertEqual(mock_integrate_matrix_data.call_count, 2) + self.assertEqual(result, { + 'step': { + Constant.P2P: {'p2p_result': {}}, + Constant.COLLECTIVE: {'collective_result': {}} + } + }) + + def test_get_comm_type(self): + # 准备测试数据 + op_data = { + 'send_op1@step1': {'link1': {'Bandwidth(GB/s)': 10}}, + 'unknown_op2__extra@step2': {'link2': {'Bandwidth(GB/s)': 20}} + } + + with patch('msprof_analyze.cluster_analyse.communication_group.' + 'msprof_communication_matrix_adapter.logger.warning') as mock_warning: + result = self.adapter.get_comm_type(op_data) + + # 验证匹配到 HCCL 模式的情况 + self.assertIn(('send', 'step1', 'link1'), result) + self.assertEqual(result[('send', 'step1', 'link1')], [{'Bandwidth(GB/s)': 10, 'Op Name': 'send_op1'}]) + + # 验证未匹配到 HCCL 模式的情况 + self.assertIn(('unknown_op2', 'step2', 'link2'), result) + self.assertEqual(result[('unknown_op2', 'step2', 'link2')], + [{'Bandwidth(GB/s)': 20, 'Op Name': 'unknown_op2__extra'}]) + mock_warning.assert_called_once_with('Unknown communication op type: unknown_op2') + + def test_integrate_matrix_data(self): + # 准备测试数据 + new_comm_op_dict = { + ('send', 'step1', 'link1'): [ + {'Bandwidth(GB/s)': 30, 'Transport Type': 'type1', 'Transit Size(MB)': 100, 'Transit Time(ms)': 10}, + {'Bandwidth(GB/s)': 20, 'Transport Type': 'type2', 'Transit Size(MB)': 200, 'Transit Time(ms)': 20}, + {'Bandwidth(GB/s)': 10, 'Transport Type': 'type3', 'Transit Size(MB)': 300, 'Transit Time(ms)': 30} + ] + } + + result = self.adapter.integrate_matrix_data(new_comm_op_dict) + + # 验证排序和数据整合 + self.assertEqual(result['send-top1@step1']['link1'], new_comm_op_dict[('send', 'step1', 'link1')][0]) + self.assertEqual(result['send-middle@step1']['link1'], new_comm_op_dict[('send', 'step1', 'link1')][1]) + self.assertEqual(result['send-bottom1@step1']['link1'], new_comm_op_dict[('send', 'step1', 'link1')][2]) + self.assertEqual(result['send-bottom2@step1']['link1'], new_comm_op_dict[('send', 'step1', 'link1')][1]) + self.assertEqual(result['send-bottom3@step1']['link1'], new_comm_op_dict[('send', 'step1', 'link1')][0]) + self.assertEqual(result['send-total@step1']['link1'], { + 'Transport Type': 'type1', + 'Transit Size(MB)': 600, + 'Transit Time(ms)': 60, + 'Bandwidth(GB/s)': 10.0 + }) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_base_recipe_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_base_recipe_analysis.py new file mode 100644 index 0000000000000000000000000000000000000000..e31db347d738e68aa77df6f2c969e7102c5a0b20 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_base_recipe_analysis.py @@ -0,0 +1,249 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import unittest +from unittest.mock import patch, MagicMock + +import pandas as pd + +from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis +from msprof_analyze.prof_common.constant import Constant +from msprof_analyze.prof_common.db_manager import DBManager +from msprof_analyze.prof_common.file_manager import FileManager + + +class TestBaseRecipeAnalysis(unittest.TestCase): + def setUp(self): + self.params = { + Constant.COLLECTION_PATH: '/tmp/to/collection', + Constant.DATA_MAP: {0: '/tmp/to/data/0', 1: '/tmp/to/data/1'}, + Constant.RECIPE_NAME: 'test_recipe', + Constant.PARALLEL_MODE: 'parallel', + Constant.EXPORT_TYPE: 'csv', + Constant.IS_MSPROF: False, + Constant.IS_MINDSPORE: False, + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: '/tmp/to/output', + Constant.RANK_LIST: '0,1', + Constant.STEP_ID: 1, + Constant.EXTRA_ARGS: [] + } + + # 创建一个 BaseRecipeAnalysis 的子类用于测试 + class ConcreteRecipeAnalysis(BaseRecipeAnalysis): + @property + def base_dir(self): + return 'test_dir' + + def run(self, context): + pass + + self.analysis = ConcreteRecipeAnalysis(self.params) + + def test_enter_exit(self): + with self.analysis as instance: + self.assertEqual(instance, self.analysis) + + with patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.logger.error') as mock_logger, \ + patch('traceback.print_exc') as mock_traceback: + try: + with self.analysis: + raise ValueError('Test error') + except ValueError: + pass + mock_logger.assert_called_once_with('Failed to exit analysis: Test error') + mock_traceback.assert_called_once() + + def test_output_path_property(self): + self.assertEqual( + self.analysis.output_path, + os.path.join('/tmp/to/output', Constant.CLUSTER_ANALYSIS_OUTPUT, 'test_recipe') + ) + + def test_filter_data(self): + test_data = [(1, [1, 2, 3]), (2, []), (3, None), (4, [4, 5])] + result = BaseRecipeAnalysis._filter_data(test_data) + self.assertEqual(result, [(1, [1, 2, 3]), (4, [4, 5])]) + + @patch.object(DBManager, 'create_connect_db') + @patch.object(DBManager, 'destroy_db_connect') + def test_dump_data_to_db(self, mock_destroy, mock_create): + mock_conn = MagicMock() + mock_cursor = MagicMock() + mock_create.return_value = (mock_conn, mock_cursor) + data = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + + self.analysis.dump_data(data, 'test.db', 'test_table') + + mock_create.assert_called_once_with(os.path.join(self.analysis.output_path, 'test.db')) + mock_destroy.assert_called_once_with(mock_conn, mock_cursor) + + @patch.object(FileManager, 'create_csv_from_dataframe') + def test_dump_data_to_csv(self, mock_create_csv): + data = pd.DataFrame({'a': [1, 2], 'b': [3, 4]}) + with patch('msprof_analyze.cluster_analyse.common_func.utils.convert_unit', return_value=data): + self.analysis.dump_data(data, 'test.csv') + + @patch('shutil.copy') + @patch('os.chmod') + def test_create_notebook_without_replace(self, mock_chmod, mock_copy): + self.analysis.create_notebook('test.ipynb') + mock_copy.assert_called_once_with( + os.path.realpath(os.path.join(os.path.dirname(__file__), "..", "..", "..", "..", + "cluster_analyse", "recipes", 'test_dir', 'test.ipynb')), + os.path.join(self.analysis.output_path, 'test.ipynb') + ) + mock_chmod.assert_called_once_with( + os.path.join(self.analysis.output_path, 'test.ipynb'), + Constant.FILE_AUTHORITY + ) + + @patch('shutil.copy') + @patch('os.chmod') + def test_add_helper_file(self, mock_chmod, mock_copy): + # 准备测试数据 + helper_file = 'test_helper.txt' + mock_dirname = MagicMock(return_value='test_dir') + + with patch('os.path.dirname', mock_dirname): + # 调用函数 + self.analysis.add_helper_file(helper_file) + + # 验证 shutil.copy 被调用 + mock_copy.assert_called_once_with( + os.path.join('test_dir', helper_file), + os.path.join(self.analysis.output_path, helper_file) + ) + + # 验证 os.chmod 被调用 + mock_chmod.assert_called_once_with( + os.path.join(self.analysis.output_path, helper_file), + Constant.FILE_AUTHORITY + ) + + def test_map_rank_pp_stage(self): + # 测试用例 1: 默认参数 + distributed_args = {} + result = self.analysis.map_rank_pp_stage(distributed_args) + self.assertEqual(result, {0: 0}) + + # 测试用例 2: 仅设置 TP_SIZE + distributed_args = {self.analysis.TP_SIZE: 2} + result = self.analysis.map_rank_pp_stage(distributed_args) + self.assertEqual(result, {0: 0, 1: 0}) + + # 测试用例 3: 仅设置 PP_SIZE + distributed_args = {self.analysis.PP_SIZE: 2} + result = self.analysis.map_rank_pp_stage(distributed_args) + self.assertEqual(result, {0: 0, 1: 1}) + + # 测试用例 4: 设置所有参数 + distributed_args = { + self.analysis.TP_SIZE: 2, + self.analysis.PP_SIZE: 2, + self.analysis.DP_SIZE: 2 + } + result = self.analysis.map_rank_pp_stage(distributed_args) + self.assertEqual(result, { + 0: 0, 1: 0, 2: 0, 3: 0, + 4: 1, 5: 1, 6: 1, 7: 1 + }) + + @patch('os.path.exists') + @patch('json.loads') + def test_load_distributed_args_from_extra_args(self, mock_json_loads, mock_exists): + # 测试从 _extra_args 获取参数 + self.analysis._extra_args = {'tp': 2, 'pp': 2, 'dp': 2} + result = self.analysis.load_distributed_args() + self.assertEqual(result, { + self.analysis.TP_SIZE: 2, + self.analysis.PP_SIZE: 2, + self.analysis.DP_SIZE: 2 + }) + + @patch('os.path.exists') + @patch('json.loads') + @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DatabaseService') + def test_load_distributed_args_from_db(self, mock_service, mock_json_loads, mock_exists): + # 测试从数据库获取参数 + mock_exists.return_value = True + mock_df = MagicMock() + mock_df.loc.return_value = MagicMock(empty=False, values=[json.dumps({ + self.analysis.TP_SIZE: 1, + self.analysis.PP_SIZE: 1, + self.analysis.DP_SIZE: 1 + })]) + mock_service.return_value.query_data.return_value = {'META_DATA': mock_df} + result = self.analysis.load_distributed_args() + self.assertEqual(result, { + self.analysis.TP_SIZE: 1, + self.analysis.PP_SIZE: 1, + self.analysis.DP_SIZE: 1 + }) + + @patch('os.path.exists') + def test_get_rank_db(self, mock_exists): + # 测试 _get_rank_db 函数 + mock_exists.return_value = True + self.analysis._get_step_range = MagicMock(return_value={'id': 1}) + self.analysis._get_profiler_db_path = MagicMock(return_value='test_profiler.db') + self.analysis._get_analysis_db_path = MagicMock(return_value='test_analysis.db') + result = self.analysis._get_rank_db() + self.assertEqual(len(result), 2) + self.assertEqual(result[0][Constant.RANK_ID], 0) + self.assertEqual(result[0][Constant.PROFILER_DB_PATH], 'test_profiler.db') + self.assertEqual(result[0][Constant.ANALYSIS_DB_PATH], 'test_analysis.db') + self.assertEqual(result[0][Constant.STEP_RANGE], {'id': 1}) + + def test_get_profiler_db_path(self): + # 测试 _get_profiler_db_path 函数 + # 测试 PyTorch 情况 + result = self.analysis._get_profiler_db_path(0, 'test_path') + self.assertEqual(result, os.path.join('test_path', Constant.SINGLE_OUTPUT, 'ascend_pytorch_profiler_0.db')) + + # 测试 MindSpore 情况 + self.analysis._is_mindspore = True + result = self.analysis._get_profiler_db_path(0, 'test_path') + self.assertEqual(result, os.path.join('test_path', Constant.SINGLE_OUTPUT, 'ascend_mindspore_profiler_0.db')) + + def test_get_analysis_db_path(self): + # 测试 _get_analysis_db_path 函数 + # 测试 PyTorch 情况 + result = self.analysis._get_analysis_db_path('test_path') + self.assertEqual(result, os.path.join('test_path', Constant.SINGLE_OUTPUT, 'analysis.db')) + + # 测试 MindSpore 情况 + self.analysis._is_mindspore = True + result = self.analysis._get_analysis_db_path('test_path') + self.assertEqual(result, os.path.join('test_path', Constant.SINGLE_OUTPUT, 'communication_analyzer.db')) + + @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.create_connect_db') + @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.judge_table_exists') + @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.fetch_all_data') + @patch('msprof_analyze.cluster_analyse.recipes.base_recipe_analysis.DBManager.destroy_db_connect') + def test_get_step_range(self, mock_destroy, mock_fetch, mock_judge, mock_connect): + # 测试 _get_step_range 函数 + mock_conn, mock_cursor = MagicMock(), MagicMock() + mock_connect.return_value = (mock_conn, mock_cursor) + mock_judge.return_value = True + mock_fetch.return_value = [{'id': 1, 'startNs': 0, 'endNs': 100}] + self.analysis._step_id = 1 + result = self.analysis._get_step_range('test.db') + self.assertEqual(result, {'id': 1, 'startNs': 0, 'endNs': 100}) + + +if __name__ == '__main__': + unittest.main() diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py index 5a21c032bed2d57606a0cd2e739d0510231f8236..a6a77431f6b268d02a4ec47dc4e8acaa06fd4296 100644 --- a/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/recipes/test_pp_chart.py @@ -52,9 +52,9 @@ class TestClusterTimeCompareSummary(unittest.TestCase): ['17F+3B', 3], ['8F+14B', 3], ['18F+4B', 3], ['9F+15B', 3], ['19F+5B', 3], ['16B', 5], ['6B', 5], ['17B', 5], ['7B', 5], ['18B', 6], ['8B', 6], ['19B', 6], ['9B', 6]] } - with (mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.load_distributed_args", - return_value={PPChart.PP_SIZE: 4}), - mock.patch(NAMESPACE + ".pp_chart.pp_chart.PPChart.load_pp_info")): + with mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.load_distributed_args", + return_value={PPChart.PP_SIZE: 4}), \ + mock.patch(NAMESPACE + ".pp_chart.pp_chart.PPChart.load_pp_info"): pp_chart_instance = PPChart({}) pp_chart_instance.micro_batch_num = 10 pp_chart_instance.calculate_micro_batch_id_for_dualpipev() @@ -63,7 +63,7 @@ class TestClusterTimeCompareSummary(unittest.TestCase): def test_pp_chart_should_generate_table_when_pp_info_not_existed(self): df = pd.DataFrame({"step": [0, 0], "msg": ["forward_step", "backward_step"], "startNs": [1, 4], - "endNs": [2, 5]}) + "endNs": [2, 5]}) with mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.load_distributed_args", return_value={}), \ mock.patch(NAMESPACE + ".base_recipe_analysis.BaseRecipeAnalysis.dump_data"), \