From 416da00c02b5980b0f0059a635d01a258d88fa04 Mon Sep 17 00:00:00 2001
From: xieanran <694099604@qq.com>
Date: Mon, 25 Aug 2025 09:48:02 +0800
Subject: [PATCH] mstt: add unit tests for cluster_analysis and data_transfer_adapter

---
 .../test_data_transfer_adapter.py            | 403 +++++++++++++++
 .../cluster_analyse/test_cluster_analysis.py | 482 ++++++++++++++++++
 2 files changed, 885 insertions(+)
 create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/cluster_utils/test_data_transfer_adapter.py
 create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/test_cluster_analysis.py

diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_utils/test_data_transfer_adapter.py b/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_utils/test_data_transfer_adapter.py
new file mode 100644
index 0000000000..a9a93336f5
--- /dev/null
+++ b/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_utils/test_data_transfer_adapter.py
@@ -0,0 +1,403 @@
+# Copyright (c) 2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+from msprof_analyze.cluster_analyse.cluster_utils.data_transfer_adapter import DataTransferAdapter
+from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant
+from msprof_analyze.prof_common.constant import Constant
+
+
+class TestDataTransferAdapter(unittest.TestCase):
+    """
+    Unit tests for DataTransferAdapter.
+    DataTransferAdapter mainly converts communication data between database and json formats.
+    """
+
+    def setUp(self):
+
+        self.adapter = DataTransferAdapter()
+
+        self.mock_time_info = [
+            {
+                TableConstant.STEP: "step_0",
+                TableConstant.TYPE: "forward",
+                TableConstant.HCCL_OP_NAME: "AllReduce",
+                TableConstant.GROUP_NAME: "group_0",
+                TableConstant.START_TIMESTAMP: 1000,
+                TableConstant.ELAPSED_TIME: 100,
+                TableConstant.TRANSIT_TIME: 50,
+                TableConstant.WAIT_TIME: 30,
+                TableConstant.SYNCHRONIZATION_TIME: 20,
+                TableConstant.IDLE_TIME: 10,
+                TableConstant.SYNCHRONIZATION_TIME_RATIO: 0.2,
+                TableConstant.WAIT_TIME_RATIO: 0.3
+            }
+        ]
+
+        self.mock_bandwidth_info = [
+            {
+                TableConstant.STEP: "step_0",
+                TableConstant.TYPE: "forward",
+                TableConstant.HCCL_OP_NAME: "AllReduce",
+                TableConstant.GROUP_NAME: "group_0",
+                TableConstant.TRANSPORT_TYPE: "RDMA",
+                TableConstant.TRANSIT_SIZE: 1024,
+                TableConstant.TRANSIT_TIME: 50,
+                TableConstant.BANDWIDTH: 20.48,
+                TableConstant.LARGE_PACKET_RATIO: 0.8,
+                TableConstant.PACKAGE_SIZE: "1MB",
+                TableConstant.COUNT: 10,
+                TableConstant.TOTAL_DURATION: 500
+            }
+        ]
+
+        self.mock_matrix_data = [
+            {
+                TableConstant.STEP: "step_0",
+                TableConstant.TYPE: "forward",
+                TableConstant.HCCL_OP_NAME: "AllReduce",
+                TableConstant.GROUP_NAME: "group_0",
+                TableConstant.SRC_RANK: "0",
+                TableConstant.DST_RANK: "1",
+                TableConstant.TRANSIT_SIZE: 1024,
+                TableConstant.TRANSIT_TIME: 50,
+                TableConstant.BANDWIDTH: 20.48,
+                TableConstant.TRANSPORT_TYPE: "RDMA",
+                TableConstant.OPNAME: "AllReduce"
+            }
+        ]
+
+    def test_init(self):
+        """
+        test DataTransferAdapter initialization
+        """
+        adapter = DataTransferAdapter()
+        self.assertIsInstance(adapter, DataTransferAdapter)
+
+        self.assertIsInstance(adapter.COMM_TIME_TABLE_COLUMN, list)
+        self.assertIsInstance(adapter.COMM_TIME_JSON_COLUMN, list)
+        self.assertIsInstance(adapter.MATRIX_TABLE_COLUMN, list)
+        self.assertIsInstance(adapter.MATRIX_JSON_COLUMN, list)
+        self.assertIsInstance(adapter.COMM_BD_TABLE_COLUMN, list)
+        self.assertIsInstance(adapter.COMM_BD_JSON_COLUMN, list)
+
+    def test_transfer_comm_from_db_to_json_empty_data_success(self):
+        """
+        test comm transfer from database to json with empty data
+        """
+        result = self.adapter.transfer_comm_from_db_to_json([], [])
+        self.assertEqual(result, {})
+
+        result = self.adapter.transfer_comm_from_db_to_json(None, None)
+        self.assertEqual(result, {})
+
+
+    def test_transfer_comm_from_db_to_json_both_info(self):
+        """
+        test comm transfer from database to json with both time and bandwidth info
+        """
+        result = self.adapter.transfer_comm_from_db_to_json(self.mock_time_info, self.mock_bandwidth_info)
+
+        expected_hccl_name = "AllReduce@group_0"
+
+        # verify time info
+        self.assertIn(Constant.COMMUNICATION_TIME_INFO, result["step_0"]["forward"][expected_hccl_name])
+
+        # verify bandwidth info
+        self.assertIn(Constant.COMMUNICATION_BANDWIDTH_INFO, result["step_0"]["forward"][expected_hccl_name])
+
+    def test_transfer_comm_from_json_to_db_empty_data_success(self):
+        """
+        test comm transfer from json to db with empty data
+        """
+        comm_data, bd_data = self.adapter.transfer_comm_from_json_to_db({})
+        self.assertEqual(comm_data, [])
+        self.assertEqual(bd_data, [])
+
+    def test_transfer_comm_from_json_to_db_with_data_success(self):
+        """
+        test comm transfer from json to db with data
+        """
+        json_data = {
+            "rank_set_0": {
+                "step_0": {
+                    "AllReduce@group_0": {
+                        "rank0": {
+                            Constant.COMMUNICATION_TIME_INFO: {
+                                Constant.START_TIMESTAMP: 1000,
+                                Constant.ELAPSE_TIME_MS: 100,
+                                Constant.TRANSIT_TIME_MS: 50,
+                                Constant.WAIT_TIME_MS: 30,
+                                Constant.SYNCHRONIZATION_TIME_MS: 20,
+                                Constant.IDLE_TIME_MS: 10,
+                                Constant.SYNCHRONIZATION_TIME_RATIO: 0.2,
+                                Constant.WAIT_TIME_RATIO: 0.3
+                            },
+                            Constant.COMMUNICATION_BANDWIDTH_INFO: {
+                                "RDMA": {
+                                    Constant.TRANSIT_SIZE_MB: 1024,
+                                    Constant.TRANSIT_TIME_MS: 50,
+                                    Constant.BANDWIDTH_GB_S: 20.48,
+                                    Constant.LARGE_PACKET_RATIO: 0.8,
+                                    Constant.SIZE_DISTRIBUTION: {
+                                        "1MB": [10, 500]
+                                    }
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        comm_data, bd_data = self.adapter.transfer_comm_from_json_to_db(json_data)
+
+        # verify communication time data
+        self.assertEqual(len(comm_data), 1)
+        comm_record = comm_data[0]
+        self.assertEqual(comm_record[TableConstant.RANK_SET], "rank_set_0")
+        self.assertEqual(comm_record[TableConstant.STEP], "step_0")
+        self.assertEqual(comm_record[TableConstant.HCCL_OP_NAME], "AllReduce")
+        self.assertEqual(comm_record[TableConstant.GROUP_NAME], "group_0")
+        self.assertEqual(comm_record[TableConstant.START_TIMESTAMP], 1000)
+        self.assertEqual(comm_record[TableConstant.ELAPSED_TIME], 100)
+
+        # verify bandwidth data
+        self.assertEqual(len(bd_data), 1)
+        bd_record = bd_data[0]
+        self.assertEqual(bd_record[TableConstant.RANK_SET], "rank_set_0")
+        self.assertEqual(bd_record[TableConstant.STEP], "step_0")
+        self.assertEqual(bd_record[TableConstant.HCCL_OP_NAME], "AllReduce")
+        self.assertEqual(bd_record[TableConstant.GROUP_NAME], "group_0")
+        self.assertEqual(bd_record[TableConstant.TRANSPORT_TYPE], "RDMA")
+        self.assertEqual(bd_record[TableConstant.TRANSIT_SIZE], 1024)
+        self.assertEqual(bd_record[TableConstant.PACKAGE_SIZE], "1MB")
+        self.assertEqual(bd_record[TableConstant.COUNT], 10)
+        self.assertEqual(bd_record[TableConstant.TOTAL_DURATION], 500)
+
+    def test_set_value_by_key(self):
+        """
+        test set value by key
+        """
+        src_dict = {}
+        dst_dict = {
+            TableConstant.TRANSIT_SIZE: 1024,
+            TableConstant.TRANSIT_TIME: 50,
+            TableConstant.BANDWIDTH: 20.48
+        }
+        key_dict = {
+            Constant.TRANSIT_SIZE_MB: TableConstant.TRANSIT_SIZE,
+            Constant.TRANSIT_TIME_MS: TableConstant.TRANSIT_TIME,
+            Constant.BANDWIDTH_GB_S: TableConstant.BANDWIDTH
+        }
+
+        self.adapter.set_value_by_key(src_dict, dst_dict, key_dict)
+
+        expected = {
+            Constant.TRANSIT_SIZE_MB: 1024,
+            Constant.TRANSIT_TIME_MS: 50,
+            Constant.BANDWIDTH_GB_S: 20.48
+        }
+        self.assertEqual(src_dict, expected)
+
+    def test_set_value_by_key_with_missing_values(self):
+        """
+        test set value by key with missing values
+        """
+        src_dict = {}
+        dst_dict = {
+            TableConstant.TRANSIT_SIZE: 1024
+        }
+        key_dict = {
+            Constant.TRANSIT_SIZE_MB: TableConstant.TRANSIT_SIZE,
+            Constant.TRANSIT_TIME_MS: TableConstant.TRANSIT_TIME,
+            Constant.BANDWIDTH_GB_S: TableConstant.BANDWIDTH
+        }
+
+        self.adapter.set_value_by_key(src_dict, dst_dict, key_dict)
+
+        expected = {
+            Constant.TRANSIT_SIZE_MB: 1024,
+            Constant.TRANSIT_TIME_MS: 0,
+            Constant.BANDWIDTH_GB_S: 0
+        }
+        self.assertEqual(src_dict, expected)
+
+    def test_transfer_matrix_from_db_to_json_empty_data(self):
+        """
+        test transfer matrix from db to json with empty data
+        """
+        result = self.adapter.transfer_matrix_from_db_to_json([])
+        self.assertEqual(result, {})
+
+        result = self.adapter.transfer_matrix_from_db_to_json(None)
+        self.assertEqual(result, {})
+
+    def test_transfer_matrix_from_db_to_json_with_data(self):
+        """
+        test transfer matrix from db to json with data
+        """
+        result = self.adapter.transfer_matrix_from_db_to_json(self.mock_matrix_data)
+
+        expected_hccl_name = "AllReduce@group_0"
+        expected_key = "0-1"
+        expected_matrix_data = {
+            Constant.TRANSIT_SIZE_MB: 1024,
+            Constant.TRANSIT_TIME_MS: 50,
+            Constant.BANDWIDTH_GB_S: 20.48,
+            Constant.TRANSPORT_TYPE: "RDMA",
+            Constant.OP_NAME: "AllReduce"
+        }
+
+        self.assertIn("step_0", result)
+        self.assertIn("forward", result["step_0"])
+        self.assertIn(expected_hccl_name, result["step_0"]["forward"])
+        self.assertIn(expected_key, result["step_0"]["forward"][expected_hccl_name])
+        self.assertEqual(result["step_0"]["forward"][expected_hccl_name][expected_key], expected_matrix_data)
+
+    def test_transfer_matrix_from_json_to_db_empty_data(self):
+        """
+        test transfer matrix from json to db with empty data
+        """
+        result = self.adapter.transfer_matrix_from_json_to_db({})
+        self.assertEqual(result, [])
+
+    def test_transfer_matrix_from_json_to_db_with_data(self):
+        """
+        test transfer matrix from json to db with data
+        """
+        json_data = {
+            "rank_set_0": {
+                "step_0": {
+                    "AllReduce@group_0": {
+                        "0-1": {
+                            Constant.TRANSIT_SIZE_MB: 1024,
+                            Constant.TRANSIT_TIME_MS: 50,
+                            Constant.BANDWIDTH_GB_S: 20.48,
+                            Constant.TRANSPORT_TYPE: "RDMA",
+                            Constant.OP_NAME: "AllReduce"
+                        }
+                    }
+                }
+            }
+        }
+
+        result = self.adapter.transfer_matrix_from_json_to_db(json_data)
+
+        self.assertEqual(len(result), 1)
+        matrix_record = result[0]
+
+        self.assertEqual(matrix_record[TableConstant.RANK_SET], "rank_set_0")
+        self.assertEqual(matrix_record[TableConstant.STEP], "step_0")
+        self.assertEqual(matrix_record[TableConstant.HCCL_OP_NAME], "AllReduce")
+        self.assertEqual(matrix_record[TableConstant.GROUP_NAME], "group_0")
+        self.assertEqual(matrix_record[TableConstant.SRC_RANK], "0")
+        self.assertEqual(matrix_record[TableConstant.DST_RANK], "1")
+        self.assertEqual(matrix_record[TableConstant.TRANSIT_SIZE], 1024)
+        self.assertEqual(matrix_record[TableConstant.TRANSIT_TIME], 50)
+        self.assertEqual(matrix_record[TableConstant.BANDWIDTH], 20.48)
+        self.assertEqual(matrix_record[TableConstant.TRANSPORT_TYPE], "RDMA")
+        self.assertEqual(matrix_record[TableConstant.OPNAME], "AllReduce")
+
+    def test_transfer_matrix_from_json_to_db_without_group_name_success(self):
+        """
+        test matrix transfer from json to db without group name
+        """
+        json_data = {
+            "rank_set_0": {
+                "step_0": {
+                    "AllReduce": {  # no @group_0 suffix
+                        "0-1": {
+                            Constant.TRANSIT_SIZE_MB: 1024,
+                            Constant.TRANSIT_TIME_MS: 50
+                        }
+                    }
+                }
+            }
+        }
+
+        result = self.adapter.transfer_matrix_from_json_to_db(json_data)
+
+        self.assertEqual(len(result), 1)
+        matrix_record = result[0]
+        self.assertEqual(matrix_record[TableConstant.HCCL_OP_NAME], "AllReduce")
+        self.assertEqual(matrix_record[TableConstant.GROUP_NAME], "")
+
+    def test_transfer_comm_from_json_to_db_without_ratio_fields_success(self):
+        """
+        test comm transfer from json to db without ratio fields
+        """
+        json_data = {
+            "rank_set_0": {
+                "step_0": {
+                    "rank_0": {
+                        "AllReduce@group_0": {
+                            Constant.COMMUNICATION_TIME_INFO: {
+                                Constant.START_TIMESTAMP: 1000,
+                                Constant.ELAPSE_TIME_MS: 100
+                                # no other params
+                            }
+                        }
+                    }
+                }
+            }
+        }
+
+        comm_data, bd_data = self.adapter.transfer_comm_from_json_to_db(json_data)
+
+        self.assertEqual(len(comm_data), 1)
+        comm_record = comm_data[0]
+        self.assertEqual(comm_record[TableConstant.START_TIMESTAMP], 1000)
+        self.assertEqual(comm_record[TableConstant.ELAPSED_TIME], 100)
+        # other fields should default to 0
+        self.assertEqual(comm_record[TableConstant.TRANSIT_TIME], 0)
+        self.assertEqual(comm_record[TableConstant.WAIT_TIME], 0)
+
+    def test_transfer_comm_from_db_to_json_multiple_steps(self):
+        """
+        test comm transfer from db to json with multiple steps
+        """
+        multi_step_time_info = [
+            {
+                TableConstant.STEP: "step_0",
+                TableConstant.TYPE: "forward",
+                TableConstant.HCCL_OP_NAME: "AllReduce",
+                TableConstant.GROUP_NAME: "group_0",
+                TableConstant.START_TIMESTAMP: 1000,
+                TableConstant.ELAPSED_TIME: 100
+            },
+            {
+                TableConstant.STEP: "step_1",
+                TableConstant.TYPE: "backward",
+                TableConstant.HCCL_OP_NAME: "AllGather",
+                TableConstant.GROUP_NAME: "group_1",
+                TableConstant.START_TIMESTAMP: 2000,
+                TableConstant.ELAPSED_TIME: 200
+            }
+        ]
+
+        result = self.adapter.transfer_comm_from_db_to_json(multi_step_time_info, [])
+
+        # verify both steps exist
+        self.assertIn("step_0", result)
+        self.assertIn("step_1", result)
+        self.assertIn("forward", result["step_0"])
+        self.assertIn("backward", result["step_1"])
+
+        # verify the different HCCL operations
+        self.assertIn("AllReduce@group_0", result["step_0"]["forward"])
+        self.assertIn("AllGather@group_1", result["step_1"]["backward"])
+
diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/test_cluster_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/test_cluster_analysis.py
new file mode 100644
index 0000000000..fccc43bf02
--- /dev/null
+++ b/profiler/msprof_analyze/test/ut/cluster_analyse/test_cluster_analysis.py
@@ -0,0 +1,482 @@
+# Copyright (c) 2025-2025, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+import sys
+import tempfile
+import unittest
+from unittest import mock
+from unittest.mock import MagicMock, patch
+
+from msprof_analyze.prof_common.constant import Constant
+from msprof_analyze.cluster_analyse.cluster_analysis import cluster_analysis_main
+from msprof_analyze.cluster_analyse.cluster_analysis import Interface
+
+
+NAMESPACE = "msprof_analyze.cluster_analyse"
+
+
+class TestClusterAnalyseClusterAnalysis(unittest.TestCase):
+    """
+    test cluster analysis
+    Note: cluster_analysis.py is the entrance of cluster analysis;
+    its main job is to parse argv and run the requested analysis task.
+    Running the whole task in a unit test is not practical, so these tests mainly check failure returns.
+    """
+
+    def setUp(self):
+        # argv backup
+        self._orig_argv = sys.argv
+
+        self.test_dir = tempfile.mkdtemp()
+        self.profiling_path = os.path.join(self.test_dir, "profiling_data")
+        self.output_path = os.path.join(self.test_dir, "output")
+
+        os.makedirs(self.profiling_path, exist_ok=True)
+        os.makedirs(self.output_path, exist_ok=True)
+
+        self.ascend_pt_dir = os.path.join(self.profiling_path, "test_ascend_pt")
+        self.ascend_ms_dir = os.path.join(self.profiling_path, "test_ascend_ms")
+        self.prof_dir = os.path.join(self.profiling_path, "PROF_114514")
+
+        os.makedirs(self.ascend_pt_dir, exist_ok=True)
+        os.makedirs(self.ascend_ms_dir, exist_ok=True)
+        os.makedirs(self.prof_dir, exist_ok=True)
+
+    def tearDown(self):
+        # restore argv to avoid polluting other tests
+        sys.argv = self._orig_argv
+
+        # remove temp directory
+        if os.path.exists(self.test_dir):
+            shutil.rmtree(self.test_dir)
+
+    def test_interface_data_map_initialization(self):
+        """
+        test Interface class initialization
+        """
+        params = {
+            Constant.PROFILING_PATH: self.profiling_path,
+            Constant.MODE: "all"
+        }
+
+        interface = Interface(params)
+
+        # verify initial data maps are empty
+        self.assertEqual(interface.data_map, {})
+        self.assertEqual(interface.communication_group, {})
+        self.assertEqual(interface.collective_group_dict, {})
+        self.assertEqual(interface.communication_ops, [])
+        self.assertEqual(interface.matrix_ops, [])
+
+    def test_cluster_analysis_main_should_run_success_and_handle_correct_parameter(self):
+        """
+        test main entrance with basic parameters
+        """
+        with mock.patch(NAMESPACE + ".cluster_analysis.Interface") as mock_if:
+            sys.argv = [
+                "cluster_analysis.py",
+                "-d", "./tmp/prof",
+                "-o", "./tmp/out",
+                "-m", "all",
+                "--data_simplification",
+                "--force",
+            ]
+
+            # execute the cluster analysis entrance
+            cluster_analysis_main()
+
+            # assert Interface is called exactly once
+            self.assertEqual(mock_if.call_count, 1)
+            kwargs = mock_if.call_args[0][0]  # first arg is the parameter dict
+            self.assertEqual(kwargs["profiling_path"], "./tmp/prof")
+            self.assertEqual(kwargs["mode"], "all")
+            self.assertEqual(kwargs["output_path"], "./tmp/out")
+            self.assertTrue(kwargs["data_simplification"])
+            self.assertTrue(kwargs["force"])
+
+        # restore original argv to avoid pollution
+        sys.argv = self._orig_argv
+
+    def test_cluster_analysis_main_all_parameters_success(self):
+        """
+        test main entrance with all parameters
+        """
+        with patch(NAMESPACE +
'.cluster_analysis.Interface') as mock_interface: + # mock class Interface + mock_interface_instance = MagicMock() + mock_interface.return_value = mock_interface_instance + + # set all parameters + sys.argv = [ + "cluster_analysis.py", + "-d", self.profiling_path, + "-o", self.output_path, + "-m", "communication_time", + "--data_simplification", + "--force", + "--parallel_mode", "sequential", + "--export_type", "notebook", + "--rank_list", "0,1,2", + "--step_id", "100", + Constant.EXTRA_ARGS, "--bp", "/data2" + ] + + cluster_analysis_main() + + # test Interface + mock_interface.assert_called_once() + call_args = mock_interface.call_args[0][0] + + self.assertEqual(call_args["profiling_path"], self.profiling_path) + self.assertEqual(call_args["output_path"], self.output_path) + self.assertEqual(call_args["mode"], "communication_time") + self.assertTrue(call_args["data_simplification"]) + self.assertTrue(call_args["force"]) + self.assertEqual(call_args["parallel_mode"], "sequential") + self.assertEqual(call_args["export_type"], "notebook") + self.assertEqual(call_args["rank_list"], "0,1,2") + self.assertEqual(call_args["step_id"], 100) + + def test_allocate_prof_data_pytorch_only_will_success(self): + """ + test data pytorch only + """ + with patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # mock PyTorch data preprocessor + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {"rank0": "data0", "rank1": "data1"} + mock_pt_instance.get_data_type.return_value = "db" + mock_pt.return_value = mock_pt_instance + + # mock mindspore data preprocessor will return empty + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {} + mock_ms.return_value = mock_ms_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "all" + } + + interface = Interface(params) + result = interface.allocate_prof_data() + + expected = { + Constant.DATA_MAP: {"rank0": "data0", "rank1": "data1"}, + Constant.DATA_TYPE: "db", + Constant.IS_MSPROF: False + } + self.assertEqual(result, expected) + + def test_allocate_prof_data_mindspore_only_will_success(self): + """ + test data mindspore only + """ + with patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # mock PyTorch data preprocessor will return empty + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {} + mock_pt.return_value = mock_pt_instance + + # mock Mindspore data preprocessor + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {"rank0": "data0"} + mock_ms_instance.get_data_type.return_value = "db" + mock_ms.return_value = mock_ms_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "all" + } + + interface = Interface(params) + result = interface.allocate_prof_data() + + expected = { + Constant.DATA_MAP: {"rank0": "data0"}, + Constant.DATA_TYPE: "db", + Constant.IS_MSPROF: False, + Constant.IS_MINDSPORE: True + } + self.assertEqual(result, expected) + + def test_allocate_prof_data_msprof_only_will_success(self): + """ + test data msprof only + """ + with patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') 
as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # mock PyTorch and Mindspore preprocessor return empty + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {} + mock_pt.return_value = mock_pt_instance + + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {} + mock_ms.return_value = mock_ms_instance + + # mock msprof data + mock_msprof_instance = MagicMock() + mock_msprof_instance.get_data_map.return_value = {"rank0": "prof_data"} + mock_msprof_instance.get_data_type.return_value = "db" + mock_msprof.return_value = mock_msprof_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "all" + } + + interface = Interface(params) + result = interface.allocate_prof_data() + + expected = { + Constant.DATA_MAP: {"rank0": "prof_data"}, + Constant.DATA_TYPE: "db", + Constant.IS_MSPROF: True + } + self.assertEqual(result, expected) + + def test_allocate_prof_data_both_frameworks_will_return_error(self): + """ + test data both-frameworks error + """ + with patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms: + + # mock both PyTorch and Mindspore return data + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {"rank0": "pt_data"} + mock_pt.return_value = mock_pt_instance + + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {"rank0": "ms_data"} + mock_ms.return_value = mock_ms_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "all" + } + + interface = Interface(params) + result = interface.allocate_prof_data() + + # assert return empty dict for data will not be process + self.assertEqual(result, {}) + + def test_run_failure_no_data_map(self): + """ + test Interface.run method failure when no data map + """ + with patch(NAMESPACE + '.cluster_analysis.PathManager') as mock_path_manager, \ + patch(NAMESPACE + '.cluster_analysis.logger') as mock_logger, \ + patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # Mock path manager checks + mock_path_manager.check_input_directory_path.return_value = None + mock_path_manager.check_path_owner_consistent.return_value = None + + # Mock all data preprocessors return empty + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {} + mock_pt.return_value = mock_pt_instance + + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {} + mock_ms.return_value = mock_ms_instance + + mock_msprof_instance = MagicMock() + mock_msprof_instance.get_data_map.return_value = {} + mock_msprof.return_value = mock_msprof_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "all", + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.output_path + } + + interface = Interface(params) + interface.run() + + # Verify warning log for no data + mock_logger.warning.assert_called_with("Can not get rank info or profiling data.") + + def test_run_failure_text_data_with_recipe_mode(self): + """ + test Interface.run method failure when text data with recipe mode + """ + with patch(NAMESPACE + '.cluster_analysis.PathManager') 
as mock_path_manager, \ + patch(NAMESPACE + '.cluster_analysis.logger') as mock_logger, \ + patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # Mock path manager checks + mock_path_manager.check_input_directory_path.return_value = None + mock_path_manager.check_path_owner_consistent.return_value = None + + # Mock data preprocessor returns text data type + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {"rank0": "data0"} + mock_pt_instance.get_data_type.return_value = "text" + mock_pt.return_value = mock_pt_instance + + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {} + mock_ms.return_value = mock_ms_instance + + mock_msprof_instance = MagicMock() + mock_msprof_instance.get_data_map.return_value = {} + mock_msprof.return_value = mock_msprof_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "freq_analysis", # recipe mode + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.output_path + } + + interface = Interface(params) + interface.run() + + # Verify error log for text data with recipe mode + mock_logger.error.assert_called_with("The current analysis node only supports DB as input data. Please check.") + + def test_run_with_data_simplification(self): + """ + test Interface.run method with data simplification enabled + """ + with patch(NAMESPACE + '.cluster_analysis.PathManager') as mock_path_manager, \ + patch(NAMESPACE + '.cluster_analysis.logger') as mock_logger, \ + patch(NAMESPACE + '.cluster_analysis.CommunicationGroupGenerator') as mock_comm_generator, \ + patch(NAMESPACE + '.cluster_analysis.AnalysisFacade') as mock_analysis_facade, \ + patch(NAMESPACE + '.cluster_analysis.FileManager') as mock_file_manager, \ + patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # Mock path manager checks + mock_path_manager.check_input_directory_path.return_value = None + mock_path_manager.check_path_owner_consistent.return_value = None + mock_path_manager.check_path_writeable.return_value = None + + # Mock data preprocessors + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {"rank0": "data0"} + mock_pt_instance.get_data_type.return_value = "db" + mock_pt.return_value = mock_pt_instance + + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {} + mock_ms.return_value = mock_ms_instance + + mock_msprof_instance = MagicMock() + mock_msprof_instance.get_data_map.return_value = {} + mock_msprof.return_value = mock_msprof_instance + + # Mock file manager + mock_file_manager.create_output_dir.return_value = None + + # Mock analysis facade + mock_analysis_facade_instance = MagicMock() + mock_analysis_facade.return_value = mock_analysis_facade_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "communication_time", + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.output_path, + Constant.DATA_SIMPLIFICATION: True + } + + interface = Interface(params) + interface.run() + + # Verify communication group generator is NOT called when data simplification is enabled + mock_comm_generator.assert_not_called() + + # Verify analysis facade is called + 
mock_analysis_facade.assert_called() + mock_analysis_facade_instance.cluster_analyze.assert_called() + + def test_run_with_all_mode(self): + """ + test Interface.run method with 'all' mode + """ + with patch(NAMESPACE + '.cluster_analysis.PathManager') as mock_path_manager, \ + patch(NAMESPACE + '.cluster_analysis.logger') as mock_logger, \ + patch(NAMESPACE + '.cluster_analysis.CommunicationGroupGenerator') as mock_comm_generator, \ + patch(NAMESPACE + '.cluster_analysis.AnalysisFacade') as mock_analysis_facade, \ + patch(NAMESPACE + '.cluster_analysis.FileManager') as mock_file_manager, \ + patch(NAMESPACE + '.cluster_analysis.PytorchDataPreprocessor') as mock_pt, \ + patch(NAMESPACE + '.cluster_analysis.MindsporeDataPreprocessor') as mock_ms, \ + patch(NAMESPACE + '.cluster_analysis.MsprofDataPreprocessor') as mock_msprof: + + # Mock path manager checks + mock_path_manager.check_input_directory_path.return_value = None + mock_path_manager.check_path_owner_consistent.return_value = None + mock_path_manager.check_path_writeable.return_value = None + + # Mock data preprocessors + mock_pt_instance = MagicMock() + mock_pt_instance.get_data_map.return_value = {"rank0": "data0"} + mock_pt_instance.get_data_type.return_value = "db" + mock_pt.return_value = mock_pt_instance + + mock_ms_instance = MagicMock() + mock_ms_instance.get_data_map.return_value = {} + mock_ms.return_value = mock_ms_instance + + mock_msprof_instance = MagicMock() + mock_msprof_instance.get_data_map.return_value = {} + mock_msprof.return_value = mock_msprof_instance + + # Mock file manager + mock_file_manager.create_output_dir.return_value = None + + # Mock communication group generator + mock_comm_generator_instance = MagicMock() + mock_comm_generator_instance.generate.return_value = {"comm_data": "test"} + mock_comm_generator.return_value = mock_comm_generator_instance + + # Mock analysis facade + mock_analysis_facade_instance = MagicMock() + mock_analysis_facade.return_value = mock_analysis_facade_instance + + params = { + Constant.PROFILING_PATH: self.profiling_path, + Constant.MODE: "all", + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: self.output_path + } + + interface = Interface(params) + interface.run() + + # Verify communication group generation for 'all' mode + mock_comm_generator.assert_called() + mock_comm_generator_instance.generate.assert_called() + + # Verify analysis facade + mock_analysis_facade.assert_called() + mock_analysis_facade_instance.cluster_analyze.assert_called() -- Gitee
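
Note on running the new tests: assuming the msprof_analyze package is importable (installed or on PYTHONPATH) and pytest is available, the two new files can be executed directly, for example:

    python -m pytest profiler/msprof_analyze/test/ut/cluster_analyse/cluster_utils/test_data_transfer_adapter.py \
        profiler/msprof_analyze/test/ut/cluster_analyse/test_cluster_analysis.py -v

Plain unittest discovery (python -m unittest discover -s profiler/msprof_analyze/test/ut/cluster_analyse) should also work, since both files rely only on the standard unittest and unittest.mock modules.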