From 86cc0e37b199a127d91a1ce67f39c3890d7088b1 Mon Sep 17 00:00:00 2001 From: fanglanyue Date: Mon, 25 Aug 2025 15:59:17 +0800 Subject: [PATCH] add ut for cluster analysis --- .../test_msprof_step_trace_time_adapter.py | 101 +++++ .../analysis/test_stage_group_analysis.py | 309 +++++++++++++++ .../analysis/test_step_trace_time_analysis.py | 358 ++++++++++++++++++ .../test_step_trace_time_analysis.py | 81 ---- 4 files changed, 768 insertions(+), 81 deletions(-) create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_msprof_step_trace_time_adapter.py create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_stage_group_analysis.py create mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_step_trace_time_analysis.py delete mode 100644 profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_step_trace_time_analysis.py diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_msprof_step_trace_time_adapter.py b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_msprof_step_trace_time_adapter.py new file mode 100644 index 00000000000..1be5db3a636 --- /dev/null +++ b/profiler/msprof_analyze/test/ut/cluster_analyse/analysis/test_msprof_step_trace_time_adapter.py @@ -0,0 +1,101 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import unittest
from unittest import mock

from msprof_analyze.cluster_analyse.analysis.msprof_step_trace_time_adapter import (MsprofStepTraceTimeAdapter,
                                                                                    MsprofStepTraceTimeDBAdapter)
from msprof_analyze.cluster_analyse.common_func.time_range_calculator import TimeRange
from msprof_analyze.prof_common.constant import Constant


class TestMsprofStepTraceTimeAdapter(unittest.TestCase):
    """Tests for the msprof text (json) step-trace-time adapter."""

    def test_generate_step_trace_time_data_when_json_has_events_then_aggregates_correctly(self):
        """Basic aggregation for MsprofStepTraceTimeAdapter with mocked json input."""

        mocked_events = [
            {"name": MsprofStepTraceTimeAdapter.COMMUNICATION, "dur": "30"},
            {"name": MsprofStepTraceTimeAdapter.COMMUNICATION, "dur": "20"},
            {"name": MsprofStepTraceTimeAdapter.COMPUTE, "dur": "100"},
            {"name": MsprofStepTraceTimeAdapter.FREE, "dur": "50"},
            {"name": MsprofStepTraceTimeAdapter.COMM_NOT_OVERLAP, "dur": "15"},
            {"name": "hcom_receive_287", "dur": "5"}
        ]

        with mock.patch(
                "msprof_analyze.prof_common.file_manager.FileManager.read_json_file",
                return_value=mocked_events,
        ):
            adapter = MsprofStepTraceTimeAdapter(["test.json"])
            beans = adapter.generate_step_trace_time_data()

        # Ensure we created exactly one bean and captured data
        expect_headers = ['Step', 'Type', 'Index', 'Computing', 'Communication(Not Overlapped)',
                          'Overlapped', 'Communication', 'Free', 'Stage', 'Bubble',
                          'Communication(Not Overlapped and Exclude Receive)', 'Preparing']
        expect_row = [100.0, 15.0, 35.0, 50.0, 50.0, 160.0, 5.0, 10.0, 0.0]
        self.assertEqual(len(beans), 1)
        step_bean = beans[0]
        self.assertEqual(step_bean.all_headers, expect_headers)
        # BUGFIX: assertAlmostEqual on two lists only passes via the `first == second`
        # shortcut and raises TypeError (round(list - list)) on any real mismatch;
        # compare element-wise so a failing run produces a clean assertion message.
        self.assertEqual(len(step_bean.row), len(expect_row))
        for actual, expected in zip(step_bean.row, expect_row):
            self.assertAlmostEqual(actual, expected)


class TestMsprofStepTraceTimeDBAdapter(unittest.TestCase):
    """Tests for the msprof DB step-trace-time adapter.

    Fixture layout (injected directly, bypassing the real DB):
      - communication_op_info rows: [string_id, start_ts, end_ts]
      - compute_task_info rows:     [start_ts, end_ts]
      - string_id_map maps string_id -> op name; names containing "receive"
        are treated as bubble (see assertions below).
    """

    def setUp(self):
        self.communication_op_info = [[0, 10, 20], [1, 25, 30], [2, 35, 45], [3, 45, 50]]
        self.compute_task_info = [[0, 10], [15, 20], [25, 35]]
        self.string_id_map = {0: "hcom_send_0", 1: "hcom_receive_0_", 2: "hcom_broadcast_0", 3: "hcom_reduce_0"}

    def test_get_compute_data_when_valid_compute_data_then_return_list_time_range(self):
        """Each compute row becomes one TimeRange."""
        db_adapter = MsprofStepTraceTimeDBAdapter("test.db")
        db_adapter.compute_task_info = self.compute_task_info
        res = db_adapter._get_compute_data()
        self.assertEqual(len(res), 3)
        self.assertIsInstance(res[0], TimeRange)

    def test_get_communication_data_when_valid_data_then_return_comm_bubble_time_range(self):
        """All comm ops are returned; only the receive op counts as bubble."""
        db_adapter = MsprofStepTraceTimeDBAdapter("test.db")
        db_adapter.communication_op_info = self.communication_op_info
        db_adapter.string_id_map = self.string_id_map
        comm_data, bubble_data = db_adapter._get_communication_data()
        self.assertEqual(len(comm_data), 4)
        self.assertEqual(len(bubble_data), 1)
        self.assertIsInstance(comm_data[0], TimeRange)
        self.assertIsInstance(bubble_data[0], TimeRange)
        self.assertEqual(bubble_data[0].start_ts, 25)

    def test_generate_step_trace_time_data(self):
        """End-to-end aggregation using injected task info (DB init mocked out)."""
        with mock.patch.object(MsprofStepTraceTimeDBAdapter, "_init_task_info_from_db", return_value=None):
            db_adapter = MsprofStepTraceTimeDBAdapter("test.db")
            # Directly inject prepared data
            db_adapter.communication_op_info = self.communication_op_info
            db_adapter.compute_task_info = self.compute_task_info
            db_adapter.string_id_map = self.string_id_map

            result = db_adapter.generate_step_trace_time_data()

        self.assertEqual(len(result), 1)
        row = result[0]
        comm_total_ns = (20 - 10) + (30 - 25) + (50 - 45) + (45 - 35)
        bubble_ns = (30 - 25)  # only the receive op
        self.assertAlmostEqual(row[4], comm_total_ns / Constant.NS_TO_US)  # Communication total
        self.assertAlmostEqual(row[7], bubble_ns / Constant.NS_TO_US)  # Bubble

        self.assertAlmostEqual(row[3], row[4] - row[2])  # overlapped = communication - comm_not_overlap
        self.assertAlmostEqual(row[8], row[2] - row[7])  # comm_not_overlap_excl_recv = comm_not_overlap - bubble
# Copyright (c) 2025, Huawei Technologies Co., Ltd
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import unittest
from unittest.mock import patch, MagicMock, mock_open
import pandas as pd

from msprof_analyze.cluster_analyse.analysis.stage_group_analysis import StageInfoAnalysis
from msprof_analyze.cluster_analyse.common_func.table_constant import TableConstant
from msprof_analyze.prof_common.constant import Constant


class TestStageInfoAnalysis(unittest.TestCase):
    """Unit tests for StageInfoAnalysis class"""

    def setUp(self):
        """Set up test fixtures before each test method"""
        self.base_param = {
            Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: "/test/path",
            Constant.DATA_TYPE: Constant.TEXT,
            Constant.DATA_SIMPLIFICATION: False,
            Constant.COMM_DATA_DICT: {}
        }

        # Sample communication group data: 3 collective groups + 2 p2p groups
        self.sample_comm_group_data = [
            {
                TableConstant.TYPE: Constant.COLLECTIVE,
                TableConstant.RANK_SET: {0, 1, 2, 3},
                TableConstant.GROUP_NAME: "group1",
                TableConstant.GROUP_ID: "g1",
                TableConstant.PG_NAME: "default_group"
            },
            {
                TableConstant.TYPE: Constant.P2P,
                TableConstant.RANK_SET: {0, 2},
                TableConstant.GROUP_NAME: "group2",
                TableConstant.GROUP_ID: "g2",
                TableConstant.PG_NAME: "pp"
            },
            {
                TableConstant.TYPE: Constant.P2P,
                TableConstant.RANK_SET: {1, 3},
                TableConstant.GROUP_NAME: "group3",
                TableConstant.GROUP_ID: "g3",
                TableConstant.PG_NAME: "pp"
            },
            {
                TableConstant.TYPE: Constant.COLLECTIVE,
                TableConstant.RANK_SET: {0, 1},
                TableConstant.GROUP_NAME: "group4",
                TableConstant.GROUP_ID: "g4",
                TableConstant.PG_NAME: "dp"
            },
            {
                TableConstant.TYPE: Constant.COLLECTIVE,
                TableConstant.RANK_SET: {2, 3},
                TableConstant.GROUP_NAME: "group5",
                TableConstant.GROUP_ID: "g5",
                TableConstant.PG_NAME: "dp"
            }
        ]

    def test_init_when_valid_params_then_initialize_correctly(self):
        """Test initialization with valid parameters"""
        param = self.base_param.copy()
        stage_analysis = StageInfoAnalysis(param)
        self.assertEqual(stage_analysis.cluster_analysis_output_path, "/test/path")
        self.assertEqual(stage_analysis.data_type, Constant.TEXT)
        self.assertFalse(stage_analysis.simplified_mode)
        self.assertEqual(stage_analysis.communication_data_dict, {})
        self.assertEqual(stage_analysis.collective_group_dict, {})
        self.assertEqual(stage_analysis.p2p_link, [])
        self.assertEqual(stage_analysis.p2p_union_group, [])
        self.assertEqual(stage_analysis.stage_group, [])

    def test_prepare_data_when_comm_data_in_dict_then_extract_successfully(self):
        """Test prepare_data when communication data is provided in dict"""
        param = self.base_param.copy()
        param[Constant.COMM_DATA_DICT] = {
            Constant.KEY_COMM_GROUP_PARALLEL_INFO: self.sample_comm_group_data
        }
        stage_analysis = StageInfoAnalysis(param)
        result = stage_analysis.prepare_data()

        self.assertTrue(result)
        self.assertEqual(len(stage_analysis.collective_group_dict), 3)
        self.assertEqual(len(stage_analysis.p2p_link), 2)

    def test_prepare_data_when_no_comm_data_then_return_false(self):
        """Test prepare_data when no communication data available"""
        # Given no communication data
        stage_analysis = StageInfoAnalysis(self.base_param)

        # When calling prepare_data with mocked load_communication_group_df returning None
        with patch.object(stage_analysis, 'load_communication_group_df', return_value=None):
            result = stage_analysis.prepare_data()

        # Then should return False
        self.assertFalse(result)

    def test_extract_infos_when_valid_dataframe_then_extract_correctly(self):
        """Test extract_infos with valid dataframe"""
        # Given valid dataframe
        df = pd.DataFrame(self.sample_comm_group_data)
        stage_analysis = StageInfoAnalysis(self.base_param)

        # When calling extract_infos
        result = stage_analysis.extract_infos(df)

        # Then should extract collective and p2p groups correctly
        self.assertTrue(result)
        self.assertEqual(len(stage_analysis.collective_group_dict), 3)
        self.assertEqual(len(stage_analysis.p2p_link), 2)
        self.assertIn("group1", stage_analysis.collective_group_dict)
        self.assertEqual(stage_analysis.collective_group_dict["group1"], {0, 1, 2, 3})

    def test_extract_infos_when_no_p2p_groups_then_return_false(self):
        """Test extract_infos when no p2p groups found"""
        data = [
            {
                TableConstant.TYPE: Constant.COLLECTIVE,
                TableConstant.RANK_SET: {0, 1, 2, 3},
                TableConstant.GROUP_NAME: "group1",
                TableConstant.GROUP_ID: "g1",
                TableConstant.PG_NAME: "default_group"
            }
        ]
        df = pd.DataFrame(data)
        stage_analysis = StageInfoAnalysis(self.base_param)
        result = stage_analysis.extract_infos(df)
        # Then should return False due to no p2p groups
        self.assertFalse(result)
        self.assertEqual(len(stage_analysis.collective_group_dict), 1)
        self.assertEqual(len(stage_analysis.p2p_link), 0)

    def test_extract_infos_when_none_dataframe_then_return_false(self):
        """Test extract_infos with None dataframe"""
        # Given None dataframe, should return False
        stage_analysis = StageInfoAnalysis(self.base_param)
        result = stage_analysis.extract_infos(None)
        self.assertFalse(result)

    def test_generate_p2p_union_group_when_disconnected_groups_then_create_separate_groups(self):
        """Test generate_p2p_union_group with disconnected p2p groups"""
        stage_analysis = StageInfoAnalysis(self.base_param)
        stage_analysis.p2p_link = [{0, 1}, {2, 3}, {4, 5}]
        stage_analysis.generate_p2p_union_group()
        self.assertEqual(len(stage_analysis.p2p_union_group), 3)
        self.assertIn({0, 1}, stage_analysis.p2p_union_group)
        self.assertIn({2, 3}, stage_analysis.p2p_union_group)
        self.assertIn({4, 5}, stage_analysis.p2p_union_group)

    def test_generate_p2p_union_group_when_connected_groups_then_merge_correctly(self):
        """Test generate_p2p_union_group with connected p2p groups"""
        stage_analysis = StageInfoAnalysis(self.base_param)
        stage_analysis.p2p_link = [{0, 1}, {1, 2}, {3, 4}]
        stage_analysis.generate_p2p_union_group()
        self.assertEqual(len(stage_analysis.p2p_union_group), 2)
        # {0,1} and {1,2} should be merged into {0,1,2}
        self.assertIn({0, 1, 2}, stage_analysis.p2p_union_group)
        self.assertIn({3, 4}, stage_analysis.p2p_union_group)

    def test_generate_stage_group_when_valid_collective_groups_then_generate_stages(self):
        """Test generate_stage_group with valid collective groups"""
        stage_analysis = StageInfoAnalysis(self.base_param)
        stage_analysis.collective_group_dict = {
            "group1": {4, 5},
            "group2": {5, 7},
            "group3": {0, 1, 2, 3, 4, 5, 6, 7},
            "group4": {0, 1},
            "group5": {1, 3},
            "group6": {2, 3},
            "group7": {6, 7},
        }
        stage_analysis.p2p_union_group = [{0, 4}, {1, 5}, {2, 6}, {3, 7}]
        stage_analysis.generate_stage_group()
        self.assertEqual(len(stage_analysis.stage_group), 2)
        # Each collective group should become a stage
        self.assertIn([0, 1, 2, 3], stage_analysis.stage_group)
        self.assertIn([4, 5, 6, 7], stage_analysis.stage_group)

    def test_whether_valid_comm_group_when_valid_group_then_return_true(self):
        """Test whether_valid_comm_group with valid communication group"""
        stage_analysis = StageInfoAnalysis(self.base_param)
        stage_analysis.p2p_union_group = [{0, 1}, {2, 3}]
        rank_set = {0, 4, 5}  # Only intersects with one p2p group
        result = stage_analysis.whether_valid_comm_group(rank_set)
        self.assertTrue(result)

    def test_whether_valid_comm_group_when_invalid_group_then_return_false(self):
        """Test whether_valid_comm_group with invalid communication group"""
        # Given invalid communication group
        stage_analysis = StageInfoAnalysis(self.base_param)
        stage_analysis.p2p_union_group = [{0, 1}, {2, 3}]
        rank_set = {0, 1, 2}  # Intersects with multiple p2p groups
        result = stage_analysis.whether_valid_comm_group(rank_set)
        self.assertFalse(result)

    @patch('os.path.exists')
    def test_load_communication_group_df_when_cluster_output_not_exist_then_return_none(self, mock_exists):
        """Test load_communication_group_df when the cluster output path is absent"""
        # NOTE: renamed from ..._when_return_none to match the when/then
        # naming convention used by every other test in this class.
        mock_exists.return_value = False
        stage_analysis = StageInfoAnalysis(self.base_param)
        result = stage_analysis.load_communication_group_df()
        self.assertIsNone(result)

    @patch('os.path.exists')
    @patch('msprof_analyze.prof_common.file_manager.FileManager.read_json_file')
    def test_load_communication_group_df_for_text_when_valid_file_then_load_successfully(self, mock_json_load,
                                                                                         mock_exists):
        """Test load_communication_group_df_for_text with valid file"""
        # Given valid file and data
        mock_exists.return_value = True
        mock_json_load.return_value = {
            Constant.KEY_COMM_GROUP_PARALLEL_INFO: self.sample_comm_group_data
        }
        stage_analysis = StageInfoAnalysis(self.base_param)
        result = stage_analysis.load_communication_group_df_for_text()

        self.assertIsNotNone(result)
        self.assertEqual(len(result), 5)
        self.assertIn(TableConstant.TYPE, result.columns)

    @patch('os.path.exists')
    @patch('msprof_analyze.prof_common.file_manager.FileManager.read_json_file')
    def test_load_communication_group_df_for_text_when_file_not_exists_then_return_none(self, mock_json_load,
                                                                                        mock_exists):
        """Test load_communication_group_df_for_text when file doesn't exist"""
        # Case 1: path doesn't exist
        mock_exists.return_value = False
        stage_analysis = StageInfoAnalysis(self.base_param)
        result = stage_analysis.load_communication_group_df_for_text()
        self.assertIsNone(result)
        # Case 2: path exists but the json carries an empty parallel-info list
        mock_exists.return_value = True
        mock_json_load.return_value = {Constant.KEY_COMM_GROUP_PARALLEL_INFO: []}
        stage_analysis = StageInfoAnalysis(self.base_param)
        result = stage_analysis.load_communication_group_df_for_text()
        self.assertIsNone(result)

    @patch('os.path.exists')
    @patch('msprof_analyze.prof_common.database_service.DatabaseService.query_data')
    def test_load_communication_group_df_for_db_when_valid_db_then_load_successfully(self, mock_query, mock_exists):
        """Test load_communication_group_df_for_db with valid database"""
        mock_exists.return_value = True
        comm_group_df = pd.DataFrame(self.sample_comm_group_data)
        comm_group_df["rank_set"] = comm_group_df["rank_set"].apply(lambda x: "(" + ",".join(str(i) for i in x) + ")")
        mock_query.return_value = {
            Constant.TABLE_COMMUNICATION_GROUP: comm_group_df
        }
        param = self.base_param.copy()
        param[Constant.DATA_TYPE] = Constant.DB
        stage_analysis = StageInfoAnalysis(param)
        result = stage_analysis.load_communication_group_df_for_db()
        self.assertIsNotNone(result)
        self.assertEqual(len(result), 5)

    @patch('os.path.exists')
    @patch('msprof_analyze.prof_common.database_service.DatabaseService.query_data')
    def test_load_communication_group_df_for_db_when_dir_not_exists_then_return_none(self, mock_query, mock_exists):
        """Test load_communication_group_df_for_db when directory doesn't exist"""
        # Case 1: directory doesn't exist
        mock_exists.return_value = False
        param = self.base_param.copy()
        param[Constant.DATA_TYPE] = Constant.DB
        stage_analysis = StageInfoAnalysis(param)
        result = stage_analysis.load_communication_group_df_for_db()
        self.assertIsNone(result)
        # Case 2: path exists but the query returns no communication-group table
        mock_exists.return_value = True
        mock_query.return_value = {}
        param = self.base_param.copy()
        param[Constant.DATA_SIMPLIFICATION] = True
        stage_analysis = StageInfoAnalysis(param)
        result = stage_analysis.load_communication_group_df_for_db()
        self.assertIsNone(result)

    def test_run_when_prepare_data_succeeds_then_return_stage_group(self):
        """Test run method when prepare_data succeeds"""
        # Given successful data preparation
        param = self.base_param.copy()
        param[Constant.COMM_DATA_DICT] = {
            Constant.KEY_COMM_GROUP_PARALLEL_INFO: self.sample_comm_group_data
        }
        stage_analysis = StageInfoAnalysis(param)

        result = stage_analysis.run()
        self.assertIsInstance(result, list)
        self.assertEqual(len(result), 2)  # One collective group becomes one stage

    def test_run_when_prepare_data_fails_then_return_empty_list(self):
        """Test run method when prepare_data fails"""
        stage_analysis = StageInfoAnalysis(self.base_param)
        with patch.object(stage_analysis, 'prepare_data', return_value=False):
            result = stage_analysis.run()
        self.assertEqual(result, [])
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import tempfile +from unittest.mock import patch + +from msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis import StepTraceTimeAnalysis +from msprof_analyze.cluster_analyse.prof_bean.step_trace_time_bean import StepTraceTimeBean +from msprof_analyze.prof_common.constant import Constant + + +def _build_analysis(**kwargs): + # Build analysis instance with defaults + params = { + Constant.COLLECTION_PATH: str(kwargs.get("collection_path", "")), + Constant.CLUSTER_ANALYSIS_OUTPUT_PATH: str(kwargs.get("output_path", "")), + Constant.DATA_MAP: kwargs.get("data_map", {}), + Constant.COMM_DATA_DICT: kwargs.get("comm_data_dict", {}), + Constant.DATA_TYPE: kwargs.get("data_type", Constant.TEXT), + Constant.DATA_SIMPLIFICATION: kwargs.get("data_simplification", False), + Constant.IS_MSPROF: kwargs.get("is_msprof", False), + Constant.IS_MINDSPORE: kwargs.get("is_mindspore", False), + } + return StepTraceTimeAnalysis(params) + + +class TestStepTraceTimeAnalysis(unittest.TestCase): + DIR_PATH = '' + + def test_get_max_data_row_when_given_data_return_max_rows(self): + check = StepTraceTimeAnalysis({}) + ls = [ + [1, 3, 5, 7, 10], + [2, 4, 6, 8, 11], + [1000, -1, -1, -1, -1] + ] + ret = check.get_max_data_row(ls) + self.assertEqual([1000, 4, 6, 8, 11], ret) + + def test_get_max_data_when_given_row_single_ls_return_this_row(self): + check = StepTraceTimeAnalysis({}) + ls = [ + [1, 3, 5, 7, 10] + ] + ret = check.get_max_data_row(ls) + self.assertEqual([1, 3, 5, 7, 10], ret) + + def 
test_analyze_step_time_when_give_normal_expect_stage(self): + check = StepTraceTimeAnalysis({}) + check.data_type = Constant.TEXT + check.step_time_dict = { + 0: [ + StepTraceTimeBean({"Step": 0, "time1": 1, "time2": 2}), + StepTraceTimeBean({"Step": 1, "time1": 1, "time2": 2}), + ], + 1: [ + StepTraceTimeBean({"Step": 0, "time1": 10, "time2": 20}), + StepTraceTimeBean({"Step": 1, "time1": 10, "time2": 20}) + ] + } + check.communication_data_dict = {Constant.STAGE: [[0, 1]]} + check.analyze_step_time() + self.assertIn([0, 'stage', (0, 1), 10.0, 20.0], check.step_data_list) + + def test_analyze_step_time_when_given_none_step_expect_stage_and_rank_row(self): + check = StepTraceTimeAnalysis({}) + check.data_type = Constant.TEXT + check.step_time_dict = { + 0: [ + StepTraceTimeBean({"Step": None, "time1": 1, "time2": 2}) + ], + 1: [ + StepTraceTimeBean({"Step": None, "time1": 10, "time2": 20}), + ], + 2: [ + StepTraceTimeBean({"Step": None, "time1": 2, "time2": 3}), + ], + 3: [ + StepTraceTimeBean({"Step": None, "time1": 1, "time2": 1}), + ], + } + check.communication_data_dict = {Constant.STAGE: [[0, 1], [2, 3]]} + check.analyze_step_time() + self.assertIn([None, 'stage', (2, 3), 2.0, 3.0], check.step_data_list) + self.assertIn([None, 'rank', 0, 1.0, 2.0], check.step_data_list) + + def test_find_msprof_json_when_multi_msprof_json_timestamps_then_return_latest_files(self): + # Create two timestamped files and expect the latest one + with tempfile.TemporaryDirectory() as d: + older = os.path.join(d, "msprof_20240101010101.json") + newer = os.path.join(d, "msprof_20250101010101.json") + with open(older, "w", encoding="utf-8") as f: + f.write("{}") + with open(newer, "w", encoding="utf-8") as f: + f.write("{}") + analysis = _build_analysis() + ret = analysis.find_msprof_json(d) + self.assertEqual(len(ret), 1) + self.assertEqual(os.path.basename(ret[0]), "msprof_20250101010101.json") + + def test_find_msprof_json_when_multi_msprof_slice_json_then_return_latest_files(self): 
+ # Create two timestamped files and expect the latest one + with tempfile.TemporaryDirectory() as d: + older = os.path.join(d, "msprof_slice_0_20240101010101.json") + newer = os.path.join(d, "msprof_slice_0_20250101010101.json") + with open(older, "w", encoding="utf-8") as f: + f.write("{}") + with open(newer, "w", encoding="utf-8") as f: + f.write("{}") + analysis = _build_analysis() + ret = analysis.find_msprof_json(d) + self.assertEqual(len(ret), 1) + self.assertEqual(os.path.basename(ret[0]), "msprof_slice_0_20250101010101.json") + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.ParallelStrategyCalculator") + def test_partition_ranks_data_when_parallel_map_available_then_append_parallel_columns(self, mock_calc_cls): + # Simulate parallel strategy result and validate appended columns + analysis = _build_analysis(data_type=Constant.TEXT) + analysis.distributed_args = {"dummy": True} + analysis.step_time_dict = { + 0: [StepTraceTimeBean({"Step": 1, "time1": 7, "time2": 8})], + 1: [StepTraceTimeBean({"Step": 1, "time1": 5, "time2": 10})] + } + analysis.step_data_list = [ + [1, Constant.RANK, 0, 1], + [1, Constant.RANK, 1, 2], + ] + + instance = mock_calc_cls.return_value + instance.run.return_value = {0: (0, 1, 2), 1: (3, 4, 5)} + + analysis.partition_ranks_data() + # Each rank row should be extended by 3 parallel columns + self.assertEqual(analysis.step_data_list[0][-3:], [0, 1, 2]) + self.assertEqual(analysis.step_data_list[1][-3:], [3, 4, 5]) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.ParallelStrategyCalculator") + def test_partition_ranks_data_when_parallel_map_not_available_then_return(self, mock_calc_cls): + analysis = _build_analysis(data_type=Constant.TEXT) + # distributed_args is None + analysis.partition_ranks_data() + mock_calc_cls.assert_not_called() + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.FileManager.create_csv_file") + def 
test_dump_data_when_text_then_write_csv_with_headers(self, mock_create_csv): + # Verify CSV writer called with proper headers + with tempfile.TemporaryDirectory() as d: + output_dir = os.path.join(d, "out") + os.makedirs(output_dir, exist_ok=True) + analysis = _build_analysis(data_type=Constant.TEXT, output_path=str(output_dir)) + + analysis.step_data_list = [[1, Constant.RANK, 0, 1, 2, 3]] + fake_bean = StepTraceTimeBean({"Step": 1, "time1": 7, "time2": 8}) + analysis.step_time_dict = {0: [fake_bean]} + analysis.dump_data() + + self.assertTrue(mock_create_csv.called) + args, _ = mock_create_csv.call_args # args: (path, rows, filename, headers) + self.assertEqual(args[2], analysis.CLUSTER_TRACE_TIME_CSV) + self.assertEqual(args[3], fake_bean.all_headers) + self.assertEqual(args[1], analysis.step_data_list) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.MsprofStepTraceTimeAdapter") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.StepTraceTimeAnalysis.find_msprof_json") + def test_load_step_trace_time_data_when_text_msprof_then_populates_dict(self, mock_find_json, mock_adapter_cls): + # Should populate step_time_dict for TEXT + is_msprof flow + with tempfile.TemporaryDirectory() as d: + # Prepare data_map: rank_id -> profiling dir + profiling_dir = os.path.join(d, "rank0") + os.makedirs(os.path.join(profiling_dir, "mindstudio_profiler_output"), exist_ok=True) + data_map = {0: profiling_dir} + + analysis = _build_analysis(data_type=Constant.TEXT, is_msprof=True, data_map=data_map) + + # Mock json discovery and adapter return + mock_find_json.return_value = [os.path.join(profiling_dir, "mindstudio_profiler_output", + "msprof_20240101010101.json")] + adapter_instance = mock_adapter_cls.return_value + adapter_instance.generate_step_trace_time_data.return_value = [StepTraceTimeBean({"Step": 1, "time1": 7, + "time2": 8})] + + analysis.load_step_trace_time_data() + self.assertIn(0, analysis.step_time_dict) + 
self.assertEqual(len(analysis.step_time_dict[0]), 1) + self.assertIsInstance(analysis.step_time_dict[0][0], StepTraceTimeBean) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.FileManager.read_csv_file") + def test_load_step_trace_time_data_when_text_plain_then_reads_csv(self, mock_read_csv): + # Should read csv for TEXT non-msprof when file exists + with tempfile.TemporaryDirectory() as d: + profiling_dir = os.path.join(d, "rank0") + single_output = os.path.join(profiling_dir, Constant.SINGLE_OUTPUT) + os.makedirs(single_output, exist_ok=True) + step_time_csv = os.path.join(single_output, Constant.STEP_TIME_CSV) + with open(step_time_csv, "w", encoding="utf-8") as f: + f.write("Step,time1,time2\n") + f.write("1,7,8\n") + + mock_read_csv.return_value = [StepTraceTimeBean({"Step": 1, "time1": 7, "time2": 8})] + + data_map = {0: profiling_dir} + analysis = _build_analysis(data_type=Constant.TEXT, is_msprof=False, data_map=data_map) + analysis.load_step_trace_time_data() + + self.assertIn(0, analysis.step_time_dict) + self.assertEqual(len(analysis.step_time_dict[0]), 1) + self.assertIsInstance(analysis.step_time_dict[0][0], StepTraceTimeBean) + + def test_load_step_trace_time_data_when_text_plain_file_missing_then_empty(self): + # Should not populate when csv missing + with tempfile.TemporaryDirectory() as d: + profiling_dir = os.path.join(d, "rank0") + data_map = {0: profiling_dir} + analysis = _build_analysis(data_type=Constant.TEXT, is_msprof=False, data_map=data_map) + analysis.load_step_trace_time_data() + self.assertNotIn(0, analysis.step_time_dict) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.MsprofDataPreprocessor." 
+ "get_msprof_profiler_db_path") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.MsprofStepTraceTimeDBAdapter") + def test_load_step_trace_time_data_when_db_msprof_then_use_db_adapter(self, mock_db_adapter_cls, mock_get_db_path): + # Should use DB adapter for DB + is_msprof + with tempfile.TemporaryDirectory() as d: + profiling_dir = os.path.join(d, "rank0") + data_map = {0: profiling_dir} + analysis = _build_analysis(data_type=Constant.DB, is_msprof=True, data_map=data_map) + + mock_get_db_path.return_value = os.path.join(profiling_dir, "profiler.db") + adapter_instance = mock_db_adapter_cls.return_value + adapter_instance.generate_step_trace_time_data.return_value = [(1, 2, 3)] + + analysis.load_step_trace_time_data() + self.assertIn(0, analysis.step_time_dict) + self.assertEqual(analysis.step_time_dict[0], [(1, 2, 3)]) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.MsprofStepTraceTimeDBAdapter") + def test_load_step_trace_time_data_when_db_mindspore_then_use_db_adapter(self, mock_db_adapter_cls): + # Should use DB adapter for DB + is_mindspore + with tempfile.TemporaryDirectory() as d: + profiling_dir = os.path.join(d, "rank0") + single_output = os.path.join(profiling_dir, Constant.SINGLE_OUTPUT) + os.makedirs(single_output, exist_ok=True) + data_map = {0: profiling_dir} + analysis = _build_analysis(data_type=Constant.DB, is_mindspore=True, data_map=data_map) + + adapter_instance = mock_db_adapter_cls.return_value + adapter_instance.generate_step_trace_time_data.return_value = [(4, 5, 6)] + + analysis.load_step_trace_time_data() + self.assertIn(0, analysis.step_time_dict) + self.assertEqual(analysis.step_time_dict[0], [(4, 5, 6)]) + self.assertTrue(mock_db_adapter_cls.called) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.check_tables_in_db") + def test_load_step_trace_time_data_when_db_plain_no_table_then_empty(self, mock_check): + # Should not populate when table 
missing or file missing + with tempfile.TemporaryDirectory() as d: + profiling_dir = os.path.join(d, "rank0") + single_output = os.path.join(profiling_dir, Constant.SINGLE_OUTPUT) + os.makedirs(single_output, exist_ok=True) + analysis_db = os.path.join(single_output, Constant.DB_COMMUNICATION_ANALYZER) + mock_check.return_value = False + + data_map = {0: profiling_dir} + analysis = _build_analysis(data_type=Constant.DB, data_map=data_map) + analysis.load_step_trace_time_data() + self.assertFalse(len(analysis.step_time_dict)) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.FileManager.read_json_file") + def test_load_step_trace_time_data_when_metadata_present_then_set_distributed_args(self, mock_read_json): + # Should set distributed_args from profiler_metadata.json when present + with tempfile.TemporaryDirectory() as d: + profiling_dir = os.path.join(d, "rank0") + os.makedirs(profiling_dir, exist_ok=True) + # Create metadata file path + metadata_path = os.path.join(profiling_dir, StepTraceTimeAnalysis.PROFILER_METADATA_JSON) + with open(metadata_path, "w", encoding="utf-8") as f: + f.write("{}") + + dist_args = {"dp": 2, "pp": 1, "tp": 4} + mock_read_json.return_value = {Constant.DISTRIBUTED_ARGS: dist_args} + + data_map = {0: profiling_dir} + analysis = _build_analysis(data_type=Constant.TEXT, data_map=data_map) + self.assertIsNone(analysis.distributed_args) + + analysis.load_step_trace_time_data() + self.assertEqual(analysis.distributed_args, dist_args) + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.destroy_db_connect") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.executemany_sql") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.create_connect_db") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.get_table_column_count") + 
@patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.create_tables") + def test_dump_data_when_db_and_rows_short_then_pad_and_insert(self, mock_create_tables, mock_col_count, + mock_connect, mock_execmany, mock_destroy): + # Validate that rows are padded to match table columns and inserted + with tempfile.TemporaryDirectory() as d: + output_dir = d + analysis = _build_analysis(data_type=Constant.DB, output_path=str(output_dir)) + # Two rows with length 5; table expects 8 + analysis.step_data_list = [ + [1, Constant.RANK, 0, 1, 2], + [1, Constant.RANK, 1, 3, 4], + ] + + mock_col_count.return_value = 8 + mock_connect.return_value = (object(), object()) + analysis.dump_data() + + # Verify padding happened before insert + args, _ = mock_execmany.call_args + conn_arg, sql_arg, data_arg = args + self.assertIsNotNone(conn_arg) + self.assertIn("values (?,?,?,?,?,?,?,?)", sql_arg) # Expect 8 placeholders for insert + self.assertTrue(all(len(row) == 8 for row in data_arg)) # Each row should be length 8 now + + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.destroy_db_connect") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.executemany_sql") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.create_connect_db") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.get_table_column_count") + @patch("msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis.DBManager.create_tables") + def test_dump_data_when_db_and_rows_long_enough_then_insert_without_padding(self, mock_create_tables, + mock_col_count, mock_connect, + mock_execmany, mock_destroy): + # Validate no padding when row length >= column count + with tempfile.TemporaryDirectory() as d: + output_dir = d + analysis = _build_analysis(data_type=Constant.DB, output_path=str(output_dir)) + # Rows length 6; table expects 6 + analysis.step_data_list = [ 
+ [1, Constant.RANK, 0, 1, 2, 3], + [1, Constant.RANK, 1, 4, 5, 6], + ] + + mock_col_count.return_value = 6 + mock_connect.return_value = (object(), object()) + + analysis.dump_data() + + args, _ = mock_execmany.call_args + _, sql_arg, data_arg = args + self.assertIn("values (?,?,?,?,?,?)", sql_arg) + self.assertEqual(data_arg, analysis.step_data_list) diff --git a/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_step_trace_time_analysis.py b/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_step_trace_time_analysis.py deleted file mode 100644 index 067886ec201..00000000000 --- a/profiler/msprof_analyze/test/ut/cluster_analyse/cluster_data_preprocess/test_step_trace_time_analysis.py +++ /dev/null @@ -1,81 +0,0 @@ -# Copyright (c) 2025, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import unittest - -from msprof_analyze.cluster_analyse.analysis.step_trace_time_analysis import StepTraceTimeAnalysis -from msprof_analyze.cluster_analyse.prof_bean.step_trace_time_bean import StepTraceTimeBean -from msprof_analyze.prof_common.constant import Constant - - -class TestStepTraceTimeAnalysis(unittest.TestCase): - DIR_PATH = '' - - def test_get_max_data_row_when_given_data_return_max_rows(self): - check = StepTraceTimeAnalysis({}) - ls = [ - [1, 3, 5, 7, 10], - [2, 4, 6, 8, 11], - [1000, -1, -1, -1, -1] - ] - ret = check.get_max_data_row(ls) - self.assertEqual([1000, 4, 6, 8, 11], ret) - - def test_get_max_data_when_given_row_single_ls_return_this_row(self): - check = StepTraceTimeAnalysis({}) - ls = [ - [1, 3, 5, 7, 10] - ] - ret = check.get_max_data_row(ls) - self.assertEqual([1, 3, 5, 7, 10], ret) - - def test_analyze_step_time_when_give_normal_expect_stage(self): - check = StepTraceTimeAnalysis({}) - check.data_type = Constant.TEXT - check.step_time_dict = { - 0: [ - StepTraceTimeBean({"Step": 0, "time1": 1, "time2": 2}), - StepTraceTimeBean({"Step": 1, "time1": 1, "time2": 2}), - ], - 1: [ - StepTraceTimeBean({"Step": 0, "time1": 10, "time2": 20}), - StepTraceTimeBean({"Step": 1, "time1": 10, "time2": 20}) - ] - } - check.communication_data_dict = {Constant.STAGE: [[0, 1]]} - check.analyze_step_time() - self.assertIn([0, 'stage', (0, 1), 10.0, 20.0], check.step_data_list) - - def test_analyze_step_time_when_given_none_step_expect_stage_and_rank_row(self): - check = StepTraceTimeAnalysis({}) - check.data_type = Constant.TEXT - check.step_time_dict = { - 0: [ - StepTraceTimeBean({"Step": None, "time1": 1, "time2": 2}) - ], - 1: [ - StepTraceTimeBean({"Step": None, "time1": 10, "time2": 20}), - ], - 2: [ - StepTraceTimeBean({"Step": None, "time1": 2, "time2": 3}), - ], - 3: [ - StepTraceTimeBean({"Step": None, "time1": 1, "time2": 1}), - ], - } - check.communication_data_dict = {Constant.STAGE: [[0, 1], [2, 3]]} - check.analyze_step_time() - 
self.assertIn([None, 'stage', (2, 3), 2.0, 3.0], check.step_data_list) - self.assertIn([None, 'rank', 0, 1.0, 2.0], check.step_data_list) \ No newline at end of file -- Gitee