From 7bc5fde61894d09cbe807e12a381022bdd0229d8 Mon Sep 17 00:00:00 2001 From: zhouxianqi <13165993773@163.com> Date: Tue, 11 Mar 2025 17:23:42 +0800 Subject: [PATCH] add_cluster_st --- ...luster_analyse_msprof_db_simplification.py | 94 +++++++++++++++++++ ...uster_analyse_pytorch_db_simplification.py | 93 ++++++++++++++++++ .../test_cluster_analyze_step_id_param.py | 0 3 files changed, 187 insertions(+) create mode 100644 profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_msprof_db_simplification.py create mode 100644 profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db_simplification.py create mode 100644 profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyze_step_id_param.py diff --git a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_msprof_db_simplification.py b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_msprof_db_simplification.py new file mode 100644 index 0000000000..a72a7c0d59 --- /dev/null +++ b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_msprof_db_simplification.py @@ -0,0 +1,94 @@ +# Copyright (c) 2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class TestClusterAnalysePytorchDbSimplification(TestCase):
    """
    Test cluster analyse msprof db in data simplification

    The ``msprof-analyze cluster`` command is executed once for the whole class
    (``-m all`` with ``--data_simplification``); every test case then queries the
    generated ``cluster_analysis.db`` through one shared connection.

    NOTE(review): the class name says "Pytorch" although CLUSTER_PATH points at the
    msprof db data set - it looks copied from the pytorch db test. The name is left
    unchanged here; consider renaming it in a follow-up.
    NOTE(review): the expected values below are identical to those of the pytorch db
    test - confirm both data sets really produce the same results.
    """
    # Root directory of the system-test data, overridable through the environment.
    ST_DATA_PATH = os.getenv("MSTT_PROFILER_ST_DATA_PATH",
                             "/home/dcs-50/smoke_project_for_msprof_analyze/mstt_profiler/st_data/")
    # Input passed to the command with "-d".
    CLUSTER_PATH = os.path.join(ST_DATA_PATH, "msprof_db_cluster_data")
    # Own output directory next to this file, so the results cannot collide with the
    # pytorch db test, which writes to "TestClusterAnalysePytorchDbSimplification".
    OUTPUT_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "TestClusterAnalyseMsprofDbSimplification")
    COMMAND_SUCCESS = 0

    @classmethod
    def setup_class(cls):
        """
        Run the cluster analysis once and open a connection to the produced database.

        Raises:
            AssertionError: (``failureException``) when the command returns a non-zero
                code or the output directory does not exist afterwards.
        """
        # generate db data
        PathManager.make_dir_safety(cls.ST_DATA_PATH)
        cmd = ["msprof-analyze", "cluster", "-d", cls.CLUSTER_PATH, "-m", "all",
               "--output_path", cls.OUTPUT_PATH, "--force", "--data_simplification"]
        if execute_cmd(cmd) != cls.COMMAND_SUCCESS or not os.path.exists(cls.OUTPUT_PATH):
            # This hook receives the class, not an instance, so TestCase.fail() cannot
            # be used here; raise the same exception type directly instead.
            raise cls.failureException("msprof db cluster analyse task failed.")
        cls.db_path = os.path.join(cls.OUTPUT_PATH, "cluster_analysis_output", "cluster_analysis.db")
        cls.conn, cls.cursor = DBManager.create_connect_db(cls.db_path)

    @classmethod
    def teardown_class(cls):
        """Close the shared db connection and remove the generated output directory."""
        try:
            DBManager.destroy_db_connect(cls.conn, cls.cursor)
        finally:
            # The command wrote its results to OUTPUT_PATH (see "--output_path" above),
            # so that is what has to be deleted; ST_DATA_PATH only holds the input data.
            PathManager.remove_path_safety(cls.OUTPUT_PATH)

    def test_host_info_data(self):
        """HostInfo contains exactly the expected host name."""
        query = "select hostName from HostInfo"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(data["hostName"].tolist(), ["n122-120-121"])

    def test_rank_device_map_data(self):
        """RankDeviceMap has 16 rows."""
        query = "select * from RankDeviceMap"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 16)

    def test_step_trace_time_data(self):
        """ClusterStepTraceTime has 16 rows and contains the expected computing value."""
        query = "select * from ClusterStepTraceTime"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 16)
        self.assertIn(14945901.524, data["computing"].tolist())

    def test_comm_group_map_data(self):
        """CommunicationGroupMapping has 33 rows and the expected rank set for one group."""
        query = "select * from CommunicationGroupMapping"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 33)
        data = data[data["group_name"] == '7519234732706649132']
        self.assertEqual(data["rank_set"].tolist(), ["(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)"])

    def test_comm_matrix_data(self):
        """ClusterCommunicationMatrix: check the 'Total Op Info' row count and one rank pair."""
        query = "SELECT * FROM ClusterCommunicationMatrix WHERE hccl_op_name = 'Total Op Info' "
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 144)
        query = ("SELECT transport_type, transit_size, transit_time, bandwidth FROM ClusterCommunicationMatrix WHERE "
                 "hccl_op_name='Total Op Info' and group_name='1046397798680881114' and src_rank=12 and dst_rank=4")
        data = pd.read_sql(query, self.conn)
        self.assertEqual(data.iloc[0].tolist(), ['RDMA', 59341.69862400028, 17684.277734, 3.3556190146182354])

    def test_comm_time_data(self):
        """ClusterCommunicationTime holds 4 'Total Op Info' rows for each of 16 ranks."""
        query = ("select rank_id, count(0) cnt from ClusterCommunicationTime where hccl_op_name = "
                 "'Total Op Info' group by rank_id")
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 16)
        self.assertEqual(data["cnt"].tolist(), [4] * 16)

    def test_comm_bandwidth_data(self):
        """ClusterCommunicationBandwidth: the selected group has two rows with counts 2 and 36."""
        query = ("select * from ClusterCommunicationBandwidth where hccl_op_name = 'Total Op Info' and "
                 "group_name='12703750860003234865' order by count")
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 2)
        self.assertEqual(data["count"].tolist(), [2, 36])
class TestClusterAnalysePytorchDbSimplification(TestCase):
    """
    Test cluster analyse pytorch db in data simplification

    The ``msprof-analyze cluster`` command is executed once for the whole class
    (``-m all`` with ``--data_simplification``); every test case then queries the
    generated ``cluster_analysis.db`` through one shared connection.
    """
    # Root directory of the system-test data, overridable through the environment.
    ST_DATA_PATH = os.getenv("MSTT_PROFILER_ST_DATA_PATH",
                             "/home/dcs-50/smoke_project_for_msprof_analyze/mstt_profiler/st_data/")
    # Input passed to the command with "-d".
    CLUSTER_PATH = os.path.join(ST_DATA_PATH, "cluster_data_2_db")
    # Output directory created next to this file.
    OUTPUT_PATH = os.path.join(os.path.abspath(os.path.dirname(__file__)), "TestClusterAnalysePytorchDbSimplification")
    COMMAND_SUCCESS = 0

    @classmethod
    def setup_class(cls):
        """
        Run the cluster analysis once and open a connection to the produced database.

        Raises:
            AssertionError: (``failureException``) when the command returns a non-zero
                code or the output directory does not exist afterwards.
        """
        # generate db data
        PathManager.make_dir_safety(cls.ST_DATA_PATH)
        cmd = ["msprof-analyze", "cluster", "-d", cls.CLUSTER_PATH, "-m", "all",
               "--output_path", cls.OUTPUT_PATH, "--force", "--data_simplification"]
        if execute_cmd(cmd) != cls.COMMAND_SUCCESS or not os.path.exists(cls.OUTPUT_PATH):
            # This hook receives the class, not an instance, so TestCase.fail() cannot
            # be used here; raise the same exception type directly instead.
            raise cls.failureException("pytorch db cluster analyse task failed.")
        cls.db_path = os.path.join(cls.OUTPUT_PATH, "cluster_analysis_output", "cluster_analysis.db")
        cls.conn, cls.cursor = DBManager.create_connect_db(cls.db_path)

    @classmethod
    def teardown_class(cls):
        """Close the shared db connection and remove the generated output directory."""
        try:
            DBManager.destroy_db_connect(cls.conn, cls.cursor)
        finally:
            # The command wrote its results to OUTPUT_PATH (see "--output_path" above),
            # so that is what has to be deleted; ST_DATA_PATH only holds the input data.
            PathManager.remove_path_safety(cls.OUTPUT_PATH)

    def test_host_info_data(self):
        """HostInfo contains exactly the expected host name."""
        query = "select hostName from HostInfo"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(data["hostName"].tolist(), ["n122-120-121"])

    def test_rank_device_map_data(self):
        """RankDeviceMap has 16 rows."""
        query = "select * from RankDeviceMap"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 16)

    def test_step_trace_time_data(self):
        """ClusterStepTraceTime has 16 rows and contains the expected computing value."""
        query = "select * from ClusterStepTraceTime"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 16)
        self.assertIn(14945901.524, data["computing"].tolist())

    def test_comm_group_map_data(self):
        """CommunicationGroupMapping has 33 rows and the expected rank set for one group."""
        query = "select * from CommunicationGroupMapping"
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 33)
        data = data[data["group_name"] == '7519234732706649132']
        self.assertEqual(data["rank_set"].tolist(), ["(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)"])

    def test_comm_matrix_data(self):
        """ClusterCommunicationMatrix: check the 'Total Op Info' row count and one rank pair."""
        query = "SELECT * FROM ClusterCommunicationMatrix WHERE hccl_op_name = 'Total Op Info' "
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 144)
        query = ("SELECT transport_type, transit_size, transit_time, bandwidth FROM ClusterCommunicationMatrix WHERE "
                 "hccl_op_name='Total Op Info' and group_name='1046397798680881114' and src_rank=12 and dst_rank=4")
        data = pd.read_sql(query, self.conn)
        self.assertEqual(data.iloc[0].tolist(), ['RDMA', 59341.69862400028, 17684.277734, 3.3556190146182354])

    def test_comm_time_data(self):
        """ClusterCommunicationTime holds 4 'Total Op Info' rows for each of 16 ranks."""
        query = ("select rank_id, count(0) cnt from ClusterCommunicationTime where hccl_op_name = "
                 "'Total Op Info' group by rank_id")
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 16)
        self.assertEqual(data["cnt"].tolist(), [4] * 16)

    def test_comm_bandwidth_data(self):
        """ClusterCommunicationBandwidth: the selected group has two rows with counts 2 and 36."""
        query = ("select * from ClusterCommunicationBandwidth where hccl_op_name = 'Total Op Info' and "
                 "group_name='12703750860003234865' order by count")
        data = pd.read_sql(query, self.conn)
        self.assertEqual(len(data), 2)
        self.assertEqual(data["count"].tolist(), [2, 36])