From aa6eed77f22913d7f3a6a93192e4a30a070b7b2a Mon Sep 17 00:00:00 2001 From: panzhaohu Date: Tue, 2 Sep 2025 14:44:12 +0800 Subject: [PATCH] cluster_ut --- .../analysis/test_base_analysis.py | 142 +++++++++++ .../test_cluster_base_info_analysis.py | 119 +++++++++ .../analysis/test_comm_matrix_analysis.py | 199 +++++++++++++++ .../analysis/test_communication_analysis.py | 230 ++++++++++++++++++ .../analysis/test_host_info_analysis.py | 220 +++++++++++++++++ 5 files changed, 910 insertions(+) create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_base_analysis.py create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_cluster_base_info_analysis.py create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_comm_matrix_analysis.py create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_communication_analysis.py create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_host_info_analysis.py diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_base_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_base_analysis.py new file mode 100644 index 0000000000..f3410b5ef9 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_base_analysis.py @@ -0,0 +1,142 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import os
+import unittest
+from unittest.mock import patch
+from msprof_analyze.cluster_analyse.analysis.base_analysis import BaseAnalysis
+from msprof_analyze.prof_common.file_manager import FileManager
+from msprof_analyze.prof_common.constant import Constant
+from msprof_analyze.prof_common.logger import get_logger
+
+logger = get_logger()
+
+
+class ConcreteBaseAnalysis(BaseAnalysis):
+    """Concrete BaseAnalysis subclass used as the test subject; supplies compute_total_info."""
+    def compute_total_info(self, communication_ops):
+        for op_name, rank_dict in communication_ops.items():
+            total_info = {}
+            for op_info in rank_dict.values():
+                for key, value in op_info.items():
+                    if self.check_add_op(key):
+                        total_info[key] = total_info.get(key, 0) + value
+            communication_ops[op_name]["total"] = total_info
+
+
+class TestBaseAnalysis(unittest.TestCase):
+    """Unit tests for BaseAnalysis helpers, exercised through ConcreteBaseAnalysis."""
+    def setUp(self):
+        self.param = {
+            Constant.COLLECTION_PATH: "/fake/path",
+            Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: "/fake/output",
+            Constant.DATA_MAP: {},
+            Constant.DATA_TYPE: "text",
+            Constant.COMM_DATA_DICT: {
+                Constant.COLLECTIVE_GROUP: {
+                    "group0": [0, 1, 2]
+                }
+            },
+            Constant.DATA_SIMPLIFICATION: False
+        }
+        self.analysis = ConcreteBaseAnalysis(self.param)
+
+    def test_compute_ratio_when_various_inputs(self):
+        self.assertEqual(self.analysis.compute_ratio(1.0, 2.0), 0.5)
+        self.assertEqual(self.analysis.compute_ratio(1.0, 0.0), 0)
+        self.assertEqual(self.analysis.compute_ratio(-1.0, 2.0), -0.5)
+        self.assertEqual(self.analysis.compute_ratio(1.0, 1e-16), 0)
+
+    def test_check_add_op_when_input_is_op_total_or_total(self):
+        self.assertTrue(self.analysis.check_add_op("op_total"))
+        self.assertTrue(self.analysis.check_add_op("total"))
+
+    def test_split_op_by_group_when_contains_p2p_and_collective_ops(self):
+        self.analysis.communication_ops = [
+            {
+                Constant.COMM_OP_TYPE: "p2p",
+                Constant.RANK_ID: 0,
+                Constant.STEP_ID: 1,
+                Constant.COMM_OP_NAME: "P2P_op",
+                Constant.COMM_OP_INFO: {"bytes": 100},
+                Constant.GROUP_NAME: "group0"
+            },
+            {
+                Constant.COMM_OP_TYPE: "collective",
+                Constant.GROUP_NAME: "group0",
+                Constant.RANK_ID: 0,
+                Constant.STEP_ID: 1,
+                Constant.COMM_OP_NAME: "AllReduce",
+                Constant.COMM_OP_INFO: {"bytes": 200}
+            },
+            {
+                Constant.COMM_OP_TYPE: "collective",
+                Constant.GROUP_NAME: "group0",
+                Constant.RANK_ID: 1,
+                Constant.STEP_ID: 1,
+                Constant.COMM_OP_NAME: "AllReduce",
+                Constant.COMM_OP_INFO: {"bytes": 300}
+            }
+        ]
+
+        self.analysis.split_op_by_group()
+        p2p_group = self.analysis.comm_ops_struct[Constant.P2P]
+        self.assertIn(1, p2p_group)
+        self.assertIn("P2P_op", p2p_group[1])
+        self.assertIn(0, p2p_group[1]["P2P_op"])
+        group0 = (0, 1, 2)  # key asserted below; same ranks as "group0" in setUp
+        self.assertIn(group0, self.analysis.comm_ops_struct)
+        collective_group = self.analysis.comm_ops_struct[group0]
+        self.assertIn(1, collective_group)
+        self.assertIn(0, collective_group[1]["AllReduce"])
+        self.assertEqual(collective_group[1]["AllReduce"][0]["bytes"], 200)
+        self.assertEqual(collective_group[1]["AllReduce"][1]["bytes"], 300)
+
+    def test_combine_ops_total_info_when_ops_contain_allreduce_multi_rank(self):
+        self.analysis.comm_ops_struct = {
+            (0, 1, 2): {
+                1: {
+                    "AllReduce": {
+                        0: {"bytes": 200, "middle_bytes": 50},
+                        1: {"bytes": 300, "middle_bytes": 60}
+                    }
+                }
+            }
+        }
+
+        self.analysis.combine_ops_total_info()
+        group_data = self.analysis.comm_ops_struct[(0, 1, 2)][1]["AllReduce"]
+        self.assertIn("total", group_data)
+        total_info = group_data["total"]
+        self.assertEqual(total_info["bytes"], 500)
+        self.assertNotIn("middle_bytes", total_info)
+
+    def test_dump_data_when_comm_ops_struct_no_value(self):
+        self.analysis.data_type = Constant.TEXT
+        self.analysis.comm_ops_struct = {}
+
+        with patch.object(logger, 'warning') as mock_warning:
+            self.analysis.dump_data()
+            mock_warning.assert_called_once_with("There is no final comm ops data generated.")
+
+    def test_dump_data_when_too_many_ranks_but_with_simplification(self):
+        self.analysis.data_type = "db"
+        self.analysis.data_map = {i: f"rank_{i}" for i in range(self.analysis.MAX_RANKS + 1)}
+        self.analysis.data_simplification = True
+        self.analysis.comm_ops_struct = {"p2p": {"step4": {"hcom": "communication"}}}
+
+        with patch.object(self.analysis, 'dump_db') as mock_dump_db:
+            self.analysis.dump_data()
+            mock_dump_db.assert_called_once()
diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_cluster_base_info_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_cluster_base_info_analysis.py
new file mode 100644
index 0000000000..571aa2697e
--- /dev/null
+++ b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_cluster_base_info_analysis.py
@@ -0,0 +1,119 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch, MagicMock
+import json
+import os
+from msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis import ClusterBaseInfoAnalysis
+from msprof_analyze.prof_common.constant import Constant
+
+
+class TestClusterBaseInfoAnalysis(unittest.TestCase):
+    """Unit tests for ClusterBaseInfoAnalysis: run, dump_db, extract_base_info, get_profiler_metadata_file."""
+    def setUp(self):
+        self.param = {
+            Constant.COLLECTION_PATH: "/fake/collection/path",
+            Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: "/fake/output/path",
+            Constant.DATA_MAP: {},
+            Constant.DATA_TYPE: Constant.DB,
+            Constant.COMM_DATA_DICT: {},
+            Constant.DATA_SIMPLIFICATION: False
+        }
+        self.analysis = ClusterBaseInfoAnalysis(self.param)
+
+    @patch('msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.increase_shared_value')
+    def test_run_when_data_type_is_text(self, mock_increase):
+        with patch.object(self.analysis, 'extract_base_info'):  # stubbed so no real metadata scan can run
+            self.analysis.data_type = "text"
+            completed_processes = MagicMock()
+            lock = MagicMock()
+            self.analysis.run(completed_processes, lock)
+            mock_increase.assert_called_once_with(completed_processes, lock)
+
+    @patch('msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.increase_shared_value')
+    def test_run_when_extract_base_info_returns_true(self, mock_increase):
+        with patch.object(self.analysis, 'extract_base_info', return_value=True), \
+                patch.object(self.analysis, 'dump_db'):  # stubbed so no database is touched
+            completed_processes = MagicMock()
+            lock = MagicMock()
+            self.analysis.run(completed_processes, lock)
+            mock_increase.assert_called_once_with(completed_processes, lock)
+
+    def test_dump_db_when_has_distributed_args(self):
+        self.analysis.distributed_args = {"world_size": 8}
+
+        mock_db = MagicMock()
+        mock_conn = MagicMock()
+        mock_curs = MagicMock()
+        mock_db.create_connect_db.return_value = (mock_conn, mock_curs)
+
+        with patch('msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.DBManager', mock_db), \
+                patch('msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.PathManager.make_dir_safety'):
+            self.analysis.dump_db()
+            self.assertTrue(mock_db.create_connect_db.called)
+            self.assertTrue(mock_db.create_tables.called)
+            self.assertTrue(mock_db.executemany_sql.called)
+
+    def test_dump_db_when_no_distributed_args(self):
+        self.analysis.distributed_args = {}
+        mock_db = MagicMock()
+        with patch('msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.DBManager', mock_db):
+            self.analysis.dump_db()
+            mock_db.create_connect_db.assert_not_called()
+            mock_db.create_tables.assert_not_called()
+            mock_db.executemany_sql.assert_not_called()
+
+    def test_extract_base_info_when_metadata_contains_distributed_args(self):
+        path = "msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis"
+        with patch(f"{path}.PathManager.limited_depth_walk") as mock_walk, \
+                patch(f"{path}.FileManager.read_json_file") as mock_read:
+            mock_walk.return_value = [("/path/rank0", [], [Constant.PROFILER_METADATA])]
+            mock_read.return_value = {Constant.DISTRIBUTED_ARGS: {"world_size": 8}}
+            result = self.analysis.extract_base_info()
+            self.assertTrue(result)
+            self.assertEqual(self.analysis.distributed_args, {"world_size": 8})
+
+    def test_extract_base_info_when_no_distributed_args_in_metadata(self):
+        path = "msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis"
+        with patch(f"{path}.PathManager.limited_depth_walk") as mock_walk, \
+                patch(f"{path}.FileManager.read_json_file") as mock_read:
+            mock_walk.return_value = [("/path", [], [Constant.PROFILER_METADATA])]
+            mock_read.return_value = {"other": "data"}
+            result = self.analysis.extract_base_info()
+            self.assertFalse(result)
+            self.assertEqual(self.analysis.distributed_args, {})
+
+    def test_extract_base_info_when_no_metadata_files_found(self):
+        path = "msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.PathManager.limited_depth_walk"
+        with patch(path) as mock_walk:
+            mock_walk.return_value = []
+            result = self.analysis.extract_base_info()
+            self.assertFalse(result)
+
+    def test_get_profiler_metadata_file_when_returns_correct_paths(self):
+        path = "msprof_analyze.cluster_analyse.analysis.cluster_base_info_analysis.PathManager.limited_depth_walk"
+        with patch(path) as mock_walk:
+            mock_walk.return_value = [
+                ("/path/rank0", [], [Constant.PROFILER_METADATA, "other.txt"]),
+                ("/path/rank1", [], ["other.txt"]),
+                ("/path/rank2", [], [Constant.PROFILER_METADATA])
+            ]
+            result = self.analysis.get_profiler_metadata_file()
+            expected = [
+                os.path.join("/path/rank0", Constant.PROFILER_METADATA),
+                os.path.join("/path/rank2", Constant.PROFILER_METADATA)
+            ]
+            self.assertEqual(result, expected)
diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_comm_matrix_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_comm_matrix_analysis.py
new file mode 100644
index 0000000000..d8b5e41933
--- /dev/null
+++ b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_comm_matrix_analysis.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch, MagicMock
+import copy
+import os
+import shutil
+from msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis import CommMatrixAnalysis
+from msprof_analyze.prof_common.constant import Constant
+
+
+class TestCommMatrixAnalysis(unittest.TestCase):  # unit tests for CommMatrixAnalysis
+    test_dir = os.path.join(os.path.dirname(__file__), 'DT_CLUSTER_PREPROCESS')  # scratch dir, rebuilt per test
+
+    def setUp(self):
+        if os.path.exists(self.test_dir):
+            shutil.rmtree(self.test_dir)
+        self.output_path = os.path.join(self.test_dir, "cluster_analysis_output")
+        os.makedirs(self.output_path, exist_ok=True)
+
+        self.param = {
+            Constant.COMM_DATA_DICT: {
+                Constant.MATRIX_OPS: {
+                    'op1@group1': {
+                        '0': {'0-1': {'transport_type': 'nccl', 'transit_time_ms': 10, 'transit_size_mb': 5}},
+                        '1': {'0-1': {'transport_type': 'nccl', 'transit_time_ms': 15, 'transit_size_mb': 8}}
+                    }
+                }
+            },
+            'cluster_analysis_output_path': self.output_path
+        }
+        self.analysis = CommMatrixAnalysis(self.param)
+
+    def tearDown(self):
+        if os.path.exists(self.test_dir):
+            shutil.rmtree(self.test_dir, ignore_errors=True)
+
+    def test_combine_link_when_same_transport_type_and_op_name(self):
+        link_info = {
+            Constant.TRANSPORT_TYPE: 'nccl',
+            Constant.TRANSIT_TIME_MS: 10,
+            Constant.TRANSIT_SIZE_MB: 5,
+            Constant.OP_NAME: 'op1'
+        }
+        single_link = {
+            Constant.TRANSPORT_TYPE: 'nccl',
+            Constant.TRANSIT_TIME_MS: 15,
+            Constant.TRANSIT_SIZE_MB: 8,
+            Constant.OP_NAME: 'op1'
+        }
+
+        CommMatrixAnalysis.combine_link(link_info, single_link)
+
+        self.assertEqual(link_info[Constant.TRANSPORT_TYPE], 'nccl')
+        self.assertEqual(link_info[Constant.TRANSIT_TIME_MS], 25)
+        self.assertEqual(link_info[Constant.TRANSIT_SIZE_MB], 13)
+        self.assertEqual(link_info[Constant.OP_NAME], 'op1')
+
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.increase_shared_value')
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.logger')
+    def test_run_when_no_comm_ops(self, mock_logger, mock_increase):
+        param_no_ops = copy.deepcopy(self.param)
+        param_no_ops[Constant.COMM_DATA_DICT][Constant.MATRIX_OPS] = None
+        analysis = CommMatrixAnalysis(param_no_ops)
+
+        completed_processes = MagicMock()
+        lock = MagicMock()
+
+        analysis.run(completed_processes, lock)
+
+        mock_increase.assert_called_once_with(completed_processes, lock)
+        mock_logger.info.assert_called_with("CommMatrixAnalysis completed")
+
+    @patch.object(CommMatrixAnalysis, 'split_op_by_group')
+    @patch.object(CommMatrixAnalysis, 'combine_ops_total_info')
+    @patch.object(CommMatrixAnalysis, 'dump_data')
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.increase_shared_value')
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.logger')
+    def test_run_with_comm_ops(self, mock_logger, mock_increase, mock_dump, mock_combine, mock_split):
+        completed_processes = MagicMock()
+        lock = MagicMock()
+
+        self.analysis.run(completed_processes, lock)
+
+        mock_split.assert_called_once()
+        mock_combine.assert_called_once()
+        mock_dump.assert_called_once()
+        mock_increase.assert_called_with(completed_processes, lock)
+        mock_logger.info.assert_called_with("CommMatrixAnalysis completed")
+
+    @patch.object(CommMatrixAnalysis, 'compute_ratio')
+    def test_merge_same_links_when_same_op_group_and_link(self, mock_compute_ratio):
+        mock_compute_ratio.return_value = 4.16
+
+        step_dict = {
+            'op1@group1': {
+                '0': {'0-1': {
+                    Constant.TRANSPORT_TYPE: 'nccl',
+                    Constant.TRANSIT_TIME_MS: 10,
+                    Constant.TRANSIT_SIZE_MB: 5,
+                    Constant.OP_NAME: 'op1'
+                }},
+                '1': {'0-1': {
+                    Constant.TRANSPORT_TYPE: 'nccl',
+                    Constant.TRANSIT_TIME_MS: 15,
+                    Constant.TRANSIT_SIZE_MB: 8,
+                    Constant.OP_NAME: 'op1'
+                }}
+            }
+        }
+        with patch.object(self.analysis, 'get_parallel_group_info') as mock_group_info:
+            mock_group_info.return_value = {'group1': {'0': 0, '1': 1}}
+
+            self.analysis.merge_same_links(step_dict)
+            self.assertIn('op1@group1', step_dict)
+            self.assertIn('0-1', step_dict['op1@group1'])
+            link_info = step_dict['op1@group1']['0-1']
+            self.assertEqual(link_info[Constant.TRANSIT_TIME_MS], 25)
+            self.assertEqual(link_info[Constant.TRANSIT_SIZE_MB], 13)
+            self.assertEqual(link_info[Constant.BANDWIDTH_GB_S], 4.16)
+            mock_compute_ratio.assert_called_with(13, 25)
+
+    @patch.object(CommMatrixAnalysis, 'compute_ratio')
+    def test_combine_link_info_when_multiple_ops_share_same_link_and_group(self, mock_compute_ratio):
+        mock_compute_ratio.return_value = 4.0888888888888895
+
+        step_dict = {
+            'op1@group1': {
+                '0-1': {
+                    Constant.TRANSPORT_TYPE: 'nccl',
+                    Constant.TRANSIT_TIME_MS: 25,
+                    Constant.TRANSIT_SIZE_MB: 13,
+                    Constant.OP_NAME: 'op1'
+                }
+            },
+            'op2@group1': {
+                '0-1': {
+                    Constant.TRANSPORT_TYPE: 'nccl',
+                    Constant.TRANSIT_TIME_MS: 20,
+                    Constant.TRANSIT_SIZE_MB: 10,
+                    Constant.OP_NAME: 'op2'
+                }
+            }
+        }
+        with patch.object(self.analysis, 'check_add_op') as mock_check:
+            mock_check.return_value = True
+            self.analysis.combine_link_info(step_dict)
+            self.assertIn(Constant.TOTAL_OP_INFO, step_dict)
+            total_info = step_dict.get(Constant.TOTAL_OP_INFO)
+            self.assertIsNotNone(total_info)
+            self.assertIn('0-1', total_info)
+            link_info = total_info.get('0-1')
+            self.assertIsNotNone(link_info)
+            self.assertEqual(link_info.get(Constant.BANDWIDTH_GB_S), 4.0888888888888895)
+            mock_compute_ratio.assert_called_with(23, 45)
+
+    def test_compute_ratio_when_input_int(self):
+        # Call the real implementation: patching compute_ratio itself would only
+        # verify the mock's canned return value, never the division.
+        # NOTE(review): 0.5 assumes the same semantics the BaseAnalysis tests pin (1.0 / 2.0 == 0.5) - confirm.
+        result = self.analysis.compute_ratio(100, 200)
+        self.assertAlmostEqual(result, 0.5)
+
+    def test_compute_ratio_when_input_zero_time(self):
+        result = self.analysis.compute_ratio(100, 0)
+        self.assertEqual(result, 0)
+
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.DBManager')
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.os.path.join')
+    @patch('msprof_analyze.cluster_analyse.analysis.comm_matrix_analysis.os.path.exists')
+    def test_dump_db_when_db_exists_and_matrix_converts_successfully(self, mock_exists, mock_join, mock_db_manager):
+        mock_exists.return_value = True
+        mock_join.return_value = "/mock/fixed/path"
+        mock_adapter = MagicMock()
+        mock_adapter.transfer_matrix_from_json_to_db.return_value = [
+            {'field1': 'value1', 'field2': 'value2'}
+        ]
+        self.analysis.adapter = mock_adapter
+        mock_conn = MagicMock()
+        mock_cursor = MagicMock()
+        mock_db_manager.create_connect_db.return_value = (mock_conn, mock_cursor)
+        self.analysis.cluster_analysis_output_path = "/mock/path"
+        self.analysis.dump_db()
+        mock_db_manager.create_tables.assert_called_once()
+        mock_db_manager.executemany_sql.assert_called_once()
+        mock_db_manager.destroy_db_connect.assert_called_once_with(mock_conn, mock_cursor)
diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_communication_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_communication_analysis.py
new file mode 100644
index 0000000000..8da9c963d1
--- /dev/null
+++ b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_communication_analysis.py
@@ -0,0 +1,230 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+from unittest.mock import patch, MagicMock
+import copy
+import os
+from collections import defaultdict
+import shutil
+from msprof_analyze.cluster_analyse.analysis.communication_analysis import CommunicationAnalysis
+from msprof_analyze.prof_common.constant import Constant
+
+
+class TestCommunicationAnalysis(unittest.TestCase):  # unit tests for CommunicationAnalysis
+    test_dir = os.path.join(os.path.dirname(__file__), 'DT_CLUSTER_PREPROCESS')  # scratch dir, rebuilt per test
+
+    def setUp(self):
+        if os.path.exists(self.test_dir):
+            shutil.rmtree(self.test_dir)
+        self.output_path = os.path.join(self.test_dir, "cluster_analysis_output")
+        os.makedirs(self.output_path, exist_ok=True)
+
+        self.param = {
+            Constant.COMM_DATA_DICT: {
+                Constant.COMMUNICATION_OPS: {
+                    'op1': {
+                        '0': {
+                            Constant.COMMUNICATION_TIME_INFO: {
+                                Constant.WAIT_TIME_MS: 10,
+                                Constant.TRANSIT_TIME_MS: 20,
+                                Constant.SYNCHRONIZATION_TIME_MS: 5
+                            },
+                            Constant.COMMUNICATION_BANDWIDTH_INFO: {
+                                'nccl': {
+                                    Constant.TRANSIT_TIME_MS: 15,
+                                    Constant.TRANSIT_SIZE_MB: 8,
+                                    Constant.SIZE_DISTRIBUTION: {'1KB': [5, 10], '10KB': [3, 6]}
+                                }
+                            }
+                        }
+                    }
+                }
+            },
+            'cluster_analysis_output_path': self.output_path
+        }
+        self.analysis = CommunicationAnalysis(self.param)
+
+    def tearDown(self):
+        shutil.rmtree(self.test_dir, ignore_errors=True)
+
+    def test_combine_size_distribution_when_data_contains_various_value(self):
+        op_dict = {'1KB': [5, 10], '10KB': [3, 6]}
+        total_dict = defaultdict(lambda: [0, 0])
+        total_dict['1KB'] = [2, 4]
+        total_dict['10KB'] = [1, 2]
+
+        CommunicationAnalysis.combine_size_distribution(op_dict, total_dict)
+
+        self.assertEqual(total_dict['1KB'], [7, 14])
+        self.assertEqual(total_dict['10KB'], [4, 8])
+
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.increase_shared_value')
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.logger')
+    def test_run_when_no_comm_ops(self, mock_logger, mock_increase):
+        param_no_ops = copy.deepcopy(self.param)
+        param_no_ops[Constant.COMM_DATA_DICT][Constant.COMMUNICATION_OPS] = None
+        analysis = CommunicationAnalysis(param_no_ops)
+        completed_processes = MagicMock()
+        lock = MagicMock()
+        analysis.run(completed_processes, lock)
+        mock_increase.assert_called_once_with(completed_processes, lock)
+        mock_logger.info.assert_called_with("CommunicationAnalysis completed")
+
+    @patch.object(CommunicationAnalysis, 'split_op_by_group')
+    @patch.object(CommunicationAnalysis, 'combine_ops_total_info')
+    @patch.object(CommunicationAnalysis, 'dump_data')
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.increase_shared_value')
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.logger')
+    def test_run_when_has_comm_ops(self, mock_logger, mock_increase, mock_dump, mock_combine, mock_split):
+        completed_processes = MagicMock()
+        lock = MagicMock()
+        self.analysis.run(completed_processes, lock)
+        mock_split.assert_called_once()
+        mock_combine.assert_called_once()
+        mock_dump.assert_called_once()
+        mock_increase.assert_called_with(completed_processes, lock)
+        mock_logger.info.assert_called_with("CommunicationAnalysis completed")
+
+    def test_combine_time_info_when_contains_wait_transit_sync_time(self):
+        com_info_dict = {
+            Constant.WAIT_TIME_MS: 10,
+            Constant.TRANSIT_TIME_MS: 20,
+            Constant.SYNCHRONIZATION_TIME_MS: 5,
+            Constant.START_TIMESTAMP: 1000
+        }
+        total_time_info_dict = {
+            Constant.WAIT_TIME_MS: 15,
+            Constant.TRANSIT_TIME_MS: 25,
+            Constant.SYNCHRONIZATION_TIME_MS: 8
+        }
+
+        self.analysis.combine_time_info(com_info_dict, total_time_info_dict)
+        self.assertEqual(total_time_info_dict[Constant.WAIT_TIME_MS], 25)
+        self.assertEqual(total_time_info_dict[Constant.SYNCHRONIZATION_TIME_MS], 13)
+        self.assertNotIn(Constant.START_TIMESTAMP, total_time_info_dict)
+
+    def test_combine_bandwidth_info_when_contains_nccl_with_size_distribution(self):
+        com_info_dict = {
+            'nccl': {
+                Constant.TRANSIT_TIME_MS: 15,
+                Constant.TRANSIT_SIZE_MB: 8,
+                Constant.SIZE_DISTRIBUTION: {'1KB': [5, 10], '10KB': [3, 6]}
+            }
+        }
+        total_bandwidth_info_dict = {
+            'nccl': {
+                Constant.TRANSIT_TIME_MS: 10,
+                Constant.TRANSIT_SIZE_MB: 5,
+                Constant.SIZE_DISTRIBUTION: {'1KB': [2, 4], '10KB': [1, 2]}
+            }
+        }
+        self.analysis.combine_bandwidth_info(com_info_dict, total_bandwidth_info_dict)
+        self.assertEqual(total_bandwidth_info_dict['nccl'][Constant.SIZE_DISTRIBUTION]['10KB'], [4, 8])
+
+    def test_combine_bandwidth_info_when_has_new_transport(self):
+        com_info_dict = {
+            'hccs': {
+                Constant.TRANSIT_TIME_MS: 15,
+                Constant.TRANSIT_SIZE_MB: 8,
+                Constant.SIZE_DISTRIBUTION: {'1KB': [5, 10]}
+            }
+        }
+        total_bandwidth_info_dict = {}
+        self.analysis.combine_bandwidth_info(com_info_dict, total_bandwidth_info_dict)
+        hccs_info = total_bandwidth_info_dict.get('hccs')
+        self.assertIsNotNone(hccs_info)
+        self.assertEqual(hccs_info.get(Constant.TRANSIT_TIME_MS), 15)
+        self.assertEqual(hccs_info.get(Constant.TRANSIT_SIZE_MB), 8)
+
+    @patch.object(CommunicationAnalysis, 'compute_ratio')
+    def test_compute_time_ratio_when_contains_wait_transit_sync_time(self, mock_compute_ratio):
+        mock_compute_ratio.side_effect = [0.333, 0.2]
+
+        total_time_info_dict = {
+            Constant.WAIT_TIME_MS: 10,
+            Constant.TRANSIT_TIME_MS: 20,
+            Constant.SYNCHRONIZATION_TIME_MS: 5
+        }
+        self.analysis.compute_time_ratio(total_time_info_dict)
+        self.assertEqual(total_time_info_dict.get(Constant.WAIT_TIME_RATIO), 0.333)
+        self.assertEqual(total_time_info_dict.get(Constant.SYNCHRONIZATION_TIME_RATIO), 0.2)
+        self.assertEqual(mock_compute_ratio.call_count, 2)
+        mock_compute_ratio.assert_any_call(10, 30)
+        mock_compute_ratio.assert_any_call(5, 25)
+
+    @patch.object(CommunicationAnalysis, 'compute_ratio')
+    def test_compute_bandwidth_ratio_when_contains_nccl_transit_data(self, mock_compute_ratio):
+        mock_compute_ratio.return_value = 0.533
+        total_bandwidth_info_dict = {
+            'nccl': {
+                Constant.TRANSIT_TIME_MS: 15,
+                Constant.TRANSIT_SIZE_MB: 8,
+                Constant.BANDWIDTH_GB_S: 0.0,  # stale seed; differs from the mocked ratio so the write is observable
+                Constant.SIZE_DISTRIBUTION: {'1KB': [5, 10]}
+            }
+        }
+
+        self.analysis.compute_bandwidth_ratio(total_bandwidth_info_dict)
+        self.assertEqual(total_bandwidth_info_dict['nccl'][Constant.BANDWIDTH_GB_S], 0.533)
+        mock_compute_ratio.assert_called_with(8, 15)
+
+    def test_compute_total_info_when_contains_communication_time_and_bandwidth(self):
+        comm_ops = {
+            'op1': {
+                '0': {
+                    Constant.COMMUNICATION_TIME_INFO: {
+                        Constant.WAIT_TIME_MS: 10,
+                        Constant.TRANSIT_TIME_MS: 20,
+                        Constant.SYNCHRONIZATION_TIME_MS: 5
+                    },
+                    Constant.COMMUNICATION_BANDWIDTH_INFO: {
+                        'nccl': {
+                            Constant.TRANSIT_TIME_MS: 15,
+                            Constant.TRANSIT_SIZE_MB: 8,
+                            Constant.SIZE_DISTRIBUTION: {'1KB': [5, 10]}
+                        }
+                    }
+                }
+            }
+        }
+
+        with patch.object(self.analysis, 'compute_time_ratio') as mock_time_ratio, \
+                patch.object(self.analysis, 'compute_bandwidth_ratio') as mock_bandwidth_ratio:
+            self.analysis.compute_total_info(comm_ops)
+            total_info = comm_ops.get(Constant.TOTAL_OP_INFO)
+            self.assertIsNotNone(total_info)
+            rank_info = total_info.get('0')
+            self.assertIsNotNone(rank_info)
+            mock_time_ratio.assert_called_once()
+            mock_bandwidth_ratio.assert_called_once()
+
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.DBManager')
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.os.path.join')
+    @patch('msprof_analyze.cluster_analyse.analysis.communication_analysis.os.path.exists')
+    def test_dump_db_when_db_exists_and_adapter_converts_successfully(self, mock_exists, mock_join, mock_db_manager):
+        mock_exists.return_value = True
+        mock_join.return_value = "/mock/fixed/path"
+        mock_adapter = MagicMock()
+        mock_adapter.transfer_comm_from_json_to_db.return_value = ([{'time': 'data'}], [{'bandwidth': 'data'}])
+        self.analysis.adapter = mock_adapter
+        mock_conn = MagicMock()
+        mock_cursor = MagicMock()
+        mock_db_manager.create_connect_db.return_value = (mock_conn, mock_cursor)
+
+        self.analysis.dump_db()
+        mock_db_manager.create_tables.assert_called_once()
+        self.assertEqual(mock_db_manager.executemany_sql.call_count, 2)
+        mock_db_manager.destroy_db_connect.assert_called_once_with(mock_conn, mock_cursor)
diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_host_info_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_host_info_analysis.py
new file mode 100644
index 0000000000..ce029568c1
--- /dev/null
+++ b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_host_info_analysis.py
@@ -0,0 +1,220 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.

import os
import shutil
import unittest
from unittest.mock import patch, MagicMock

from msprof_analyze.cluster_analyse.analysis.host_info_analysis import HostInfoAnalysis
from msprof_analyze.prof_common.constant import Constant


class TestHostInfoAnalysis(unittest.TestCase):
    """Unit tests for ``HostInfoAnalysis``.

    Collaborators that live in the ``host_info_analysis`` module (``DBManager``,
    ``logger``, ``increase_shared_value`` and the data preprocessors) are
    replaced with mocks, so the tests only observe how ``HostInfoAnalysis``
    drives them. A scratch directory tree is created next to this file for
    every test and removed afterwards.
    """

    # Scratch directory, rebuilt in setUp and deleted in tearDown.
    test_dir = os.path.join(os.path.dirname(__file__), 'DT_CLUSTER_PREPROCESS')

    def setUp(self):
        """Create a clean scratch tree and a DB-mode ``HostInfoAnalysis`` for two ranks."""
        # Start from a clean slate in case a previous run left files behind.
        if os.path.exists(self.test_dir):
            shutil.rmtree(self.test_dir)
        self.output_path = os.path.join(self.test_dir, "cluster_analysis_output")
        os.makedirs(self.output_path, exist_ok=True)

        self.profiling_dir_0 = os.path.join(self.test_dir, 'profiling_0')
        self.profiling_dir_1 = os.path.join(self.test_dir, 'profiling_1')
        os.makedirs(self.profiling_dir_0, exist_ok=True)
        os.makedirs(self.profiling_dir_1, exist_ok=True)

        self.param = {
            'data_type': Constant.DB,
            'cluster_analysis_output_path': self.output_path,
            Constant.IS_MSPROF: False,
            Constant.IS_MINDSPORE: False,
            'data_map': {
                '0': self.profiling_dir_0,
                '1': self.profiling_dir_1
            }
        }
        self.analysis = HostInfoAnalysis(self.param)

    def tearDown(self):
        """Remove the scratch tree; errors are ignored so cleanup never fails a test."""
        shutil.rmtree(self.test_dir, ignore_errors=True)

    def mock_join_function(self, *args):
        """Stand-in for ``os.path.join`` used as a mock side effect.

        Drops ``None`` components, converts the remaining ones to ``str`` and
        joins them underneath ``/mock``.
        """
        filtered_args = [str(arg) for arg in args if arg is not None]
        return os.path.join("/mock", *filtered_args)

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.increase_shared_value')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_run_when_no_db_data_type_and_with_process_lock(self, mock_logger, mock_increase):
        """With a non-DB data type and a lock, run() bumps the shared counter once and logs completion."""
        analysis = HostInfoAnalysis({'data_type': 'json'})
        completed_processes = MagicMock()
        lock = MagicMock()

        analysis.run(completed_processes, lock)

        mock_increase.assert_called_once_with(completed_processes, lock)
        mock_logger.info.assert_called_with("HostInfoAnalysis completed")

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.increase_shared_value')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_run_when_no_db_data_type_no_lock(self, mock_logger, mock_increase):
        """With a non-DB data type and no lock, run() leaves the shared counter alone and logs completion."""
        analysis = HostInfoAnalysis({'data_type': 'json'})

        analysis.run()

        mock_increase.assert_not_called()
        mock_logger.info.assert_called_with("HostInfoAnalysis completed")

    @patch.object(HostInfoAnalysis, 'analyze_host_info')
    @patch.object(HostInfoAnalysis, 'dump_db')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.increase_shared_value')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_run_when_db_data_type_and_with_process_lock(self, mock_logger, mock_increase, mock_dump_db, mock_analyze):
        """In DB mode with a lock, run() analyses, dumps, bumps the shared counter and logs completion."""
        completed_processes = MagicMock()
        lock = MagicMock()

        self.analysis.run(completed_processes, lock)

        mock_analyze.assert_called_once()
        mock_dump_db.assert_called_once()
        mock_increase.assert_called_with(completed_processes, lock)
        mock_logger.info.assert_called_with("HostInfoAnalysis completed")

    @patch.object(HostInfoAnalysis, 'analyze_host_info')
    @patch.object(HostInfoAnalysis, 'dump_db')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.increase_shared_value')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_run_when_db_data_type_and_no_lock_mode(self, mock_logger, mock_increase, mock_dump_db, mock_analyze):
        """In DB mode without a lock, run() analyses and dumps but never touches the shared counter."""
        self.analysis.run()

        # Mirrors the with-lock variant: the analysis step is expected to run
        # whether or not a lock is supplied.
        mock_analyze.assert_called_once()
        mock_dump_db.assert_called_once()
        mock_increase.assert_not_called()
        mock_logger.info.assert_called_with("HostInfoAnalysis completed")

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.os.path.join')
    def test_dump_host_info_when_host_info_is_not_empty(self, mock_join, mock_db_manager):
        """dump_host_info() creates the table and bulk-inserts once when host info is present."""
        mock_join.side_effect = self.mock_join_function
        self.analysis.all_rank_host_info = {'host1': 'hostname1', 'host2': 'hostname2'}
        mock_conn = MagicMock()
        mock_db_manager.create_connect_db.return_value = (mock_conn, MagicMock())

        self.analysis.dump_host_info('/mock/db', mock_conn)

        mock_db_manager.create_tables.assert_called_once()
        mock_db_manager.executemany_sql.assert_called_once()

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.os.path.join')
    def test_dump_rank_device_map_when_data_is_not_empty(self, mock_join, mock_db_manager):
        """dump_rank_device_map() creates the table and bulk-inserts once when device info is present."""
        mock_join.side_effect = self.mock_join_function
        self.analysis.all_rank_device_info = [['0', 'device0'], ['1', 'device1']]
        mock_conn = MagicMock()
        mock_db_manager.create_connect_db.return_value = (mock_conn, MagicMock())

        self.analysis.dump_rank_device_map('/mock/db', mock_conn)

        mock_db_manager.create_tables.assert_called_once()
        mock_db_manager.executemany_sql.assert_called_once()

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.os.path.join')
    def test_dump_rank_device_map_when_data_is_empty(self, mock_join, mock_db_manager):
        """dump_rank_device_map() performs no DB writes when no device info was set on the analysis."""
        mock_join.side_effect = self.mock_join_function
        mock_conn = MagicMock()
        mock_db_manager.create_connect_db.return_value = (mock_conn, MagicMock())

        self.analysis.dump_rank_device_map('/mock/db', mock_conn)

        mock_db_manager.create_tables.assert_not_called()
        mock_db_manager.executemany_sql.assert_not_called()

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.MsprofDataPreprocessor')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.MindsporeDataPreprocessor')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_analyze_host_info_msprof_when_mode_is_msprof_and_info_exists(self, mock_logger,
                                                                          mock_mindspore, mock_msprof,
                                                                          mock_db_manager):
        """In msprof mode, analyze_host_info() collects host and device info for both ranks."""
        self.analysis.is_msprof = True
        mock_db_path = os.path.join(self.test_dir, 'test.db')
        mock_db_manager.check_tables_in_db.return_value = True
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_db_manager.create_connect_db.return_value = (mock_conn, mock_cursor)
        # Successive fetch results, consumed in call order: a host-info row
        # followed by a device-info row, once for each of the two ranks.
        mock_db_manager.fetch_all_data.side_effect = [
            [['host_uid_0', 'host_name_0']],
            [['0', 'device0']],
            [['host_uid_1', 'host_name_1']],
            [['1', 'device1']]
        ]

        mock_msprof.get_device_id.side_effect = ['device0', 'device1']
        mock_msprof.get_msprof_profiler_db_path.return_value = mock_db_path

        with patch('os.path.exists', return_value=True):
            self.analysis.analyze_host_info()

        expected_host_info = {
            'host_uid_0': 'host_name_0',
            'host_uid_1': 'host_name_1'
        }
        self.assertEqual(self.analysis.all_rank_host_info, expected_host_info)
        self.assertEqual(len(self.analysis.all_rank_device_info), 2)
        expected_device_info_0 = ['0', 'device0', 'host_uid_0', self.profiling_dir_0]
        expected_device_info_1 = ['1', 'device1', 'host_uid_1', self.profiling_dir_1]
        self.assertIn(expected_device_info_0, self.analysis.all_rank_device_info)
        self.assertIn(expected_device_info_1, self.analysis.all_rank_device_info)

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.MsprofDataPreprocessor')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.MindsporeDataPreprocessor')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_analyze_host_info_when_no_host_info(self, mock_logger, mock_mindspore, mock_msprof, mock_db_manager):
        """When the host-info fetch returns nothing, no info is collected and a warning is logged."""
        mock_db_manager.check_tables_in_db.return_value = True
        mock_conn = MagicMock()
        mock_cursor = MagicMock()
        mock_db_manager.create_connect_db.return_value = (mock_conn, mock_cursor)

        # Host-info fetches come back empty; device-info rows are still queued.
        mock_db_manager.fetch_all_data.side_effect = [
            [],
            [['0', 'device0']],
            [],
            [['1', 'device1']]
        ]
        with patch('os.path.exists', return_value=True):
            self.analysis.analyze_host_info()

        self.assertEqual(self.analysis.all_rank_host_info, {})
        self.assertEqual(self.analysis.all_rank_device_info, [])
        mock_logger.warning.assert_called()

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.MsprofDataPreprocessor')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.MindsporeDataPreprocessor')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_analyze_host_info_when_db_not_exist(self, mock_logger, mock_mindspore, mock_msprof, mock_db_manager):
        """When ``os.path.exists`` reports False, no info is collected and a warning is logged."""
        mock_db_manager.check_tables_in_db.return_value = True

        with patch('os.path.exists', return_value=False):
            self.analysis.analyze_host_info()

        self.assertEqual(self.analysis.all_rank_host_info, {})
        self.assertEqual(self.analysis.all_rank_device_info, [])
        mock_logger.warning.assert_called()

    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.DBManager')
    @patch('msprof_analyze.cluster_analyse.analysis.host_info_analysis.logger')
    def test_analyze_host_info_when_no_tables(self, mock_logger, mock_db_manager):
        """When ``check_tables_in_db`` reports False, no host or device info is collected."""
        mock_db_manager.check_tables_in_db.return_value = False

        with patch('os.path.exists', return_value=True):
            self.analysis.analyze_host_info()

        self.assertEqual(self.analysis.all_rank_host_info, {})
        self.assertEqual(self.analysis.all_rank_device_info, [])