diff --git a/profiler/cluster_analyse/common_func/db_manager.py b/profiler/cluster_analyse/common_func/db_manager.py index 85bcf975d65a85b39640bfecd12e821c5e2a6a4b..be790e9298b4546902ab99a8f0a811e784363c69 100644 --- a/profiler/cluster_analyse/common_func/db_manager.py +++ b/profiler/cluster_analyse/common_func/db_manager.py @@ -132,11 +132,11 @@ class DBManager: conn, curs = cls.create_connect_db(db_path) if not (conn and curs): return 0 - sql = "SELECT COUNT(*) FROM pragma_table_info('{}')".format(table) + sql = f"PRAGMA table_info({table})" res = 0 try: curs.execute(sql) - res = curs.fetchone()[0] + res = len(curs.fetchall()) except sqlite3.Error as err: print("[ERROR] {}".format(err)) finally: diff --git a/profiler/test/st/cluster_analyse/__init__.py b/profiler/test/st/cluster_analyse/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/test/st/cluster_analyse/cluster_communication_matrixDb.py b/profiler/test/st/cluster_analyse/cluster_communication_matrixDb.py new file mode 100644 index 0000000000000000000000000000000000000000..337bb7891159445ddf013aaf56f0a796ddd0f87e --- /dev/null +++ b/profiler/test/st/cluster_analyse/cluster_communication_matrixDb.py @@ -0,0 +1,121 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class ClusterCommunicationMatrixDb:
    """Plain record holding the values of one cluster communication matrix row.

    Instances are built positionally (``ClusterCommunicationMatrixDb(*row)``),
    so the parameter order below is part of the contract and must not change.
    NOTE(review): the order is presumably the column order of the
    ``ClusterCommunicationMatrix`` table -- confirm against the table schema.

    Every field is an ordinary read/write attribute and defaults to ``None``.
    """

    # Field names in constructor order; used to build a readable ``repr``.
    _FIELD_NAMES = (
        "step",
        "hccl_op_name",
        "group_name",
        "src_rank",
        "dst_rank",
        "transit_size",
        "transit_time",
        "bandwidth",
        "transport_type",
        "op_name",
    )

    def __init__(self,
                 step=None,
                 hccl_op_name=None,
                 group_name=None,
                 src_rank=None,
                 dst_rank=None,
                 transit_size=None,
                 transit_time=None,
                 bandwidth=None,
                 transport_type=None,
                 op_name=None):
        # Plain attributes: the former property pairs only forwarded to a
        # private field, so direct attributes expose the same public API.
        self.step = step
        self.hccl_op_name = hccl_op_name
        self.group_name = group_name
        self.src_rank = src_rank
        self.dst_rank = dst_rank
        self.transit_size = transit_size
        self.transit_time = transit_time
        self.bandwidth = bandwidth
        self.transport_type = transport_type
        self.op_name = op_name

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` so assertion failures are readable."""
        fields = ", ".join(f"{name}={getattr(self, name)!r}" for name in self._FIELD_NAMES)
        return f"{self.__class__.__name__}({fields})"
class ClusterStepTraceTimeDb:
    """Plain record holding the values of one cluster step trace time row.

    Instances are built positionally (``ClusterStepTraceTimeDb(*row)``), so
    the parameter order below is part of the contract and must not change.
    NOTE(review): the order is presumably the column order of the
    ``ClusterStepTraceTime`` table / ``cluster_step_trace_time.csv`` --
    confirm against the producer.

    Every field is an ordinary read/write attribute and defaults to ``None``.
    """

    # Field names in constructor order; used to build a readable ``repr``.
    _FIELD_NAMES = (
        "step",
        "type",
        "index",
        "computing",
        "communication_not_overlapped",
        "overlapped",
        "communication",
        "free",
        "stage",
        "bubble",
        "communication_not_overlapped_and_exclude_receive",
        "preparing",
        "dp_index",
        "pp_index",
        "tp_index",
    )

    def __init__(self,
                 step=None,
                 type=None,  # shadows the builtin; name kept so keyword callers keep working
                 index=None,
                 computing=None,
                 communication_not_overlapped=None,
                 overlapped=None,
                 communication=None,
                 free=None,
                 stage=None,
                 bubble=None,
                 communication_not_overlapped_and_exclude_receive=None,
                 preparing=None,
                 dp_index=None,
                 pp_index=None,
                 tp_index=None):
        # Plain attributes: the former property pairs only forwarded to a
        # private field, so direct attributes expose the same public API.
        self.step = step
        self.type = type
        self.index = index
        self.computing = computing
        self.communication_not_overlapped = communication_not_overlapped
        self.overlapped = overlapped
        self.communication = communication
        self.free = free
        self.stage = stage
        self.bubble = bubble
        self.communication_not_overlapped_and_exclude_receive = communication_not_overlapped_and_exclude_receive
        self.preparing = preparing
        self.dp_index = dp_index
        self.pp_index = pp_index
        self.tp_index = tp_index

    def __repr__(self):
        """Return ``ClassName(field=value, ...)`` so assertion failures are readable."""
        fields = ", ".join(f"{name}={getattr(self, name)!r}" for name in self._FIELD_NAMES)
        return f"{self.__class__.__name__}({fields})"
class TestClusterAnalysePytorchDb(TestCase):
    """
    Test cluster analyse pytorch db

    Runs ``msprof-analyze cluster`` once for the whole class to produce
    ``cluster_analysis.db`` and then compares selected values in that database
    with the pre-generated text outputs (CSV / JSON) stored under
    ``ST_DATA_PATH``.
    """
    # Root of the system-test data; overridable through the environment.
    ST_DATA_PATH = os.getenv("MSTT_PROFILER_ST_DATA_PATH",
                             "/home/dcs-50/smoke_project_for_msprof_analyze/mstt_profiler/st_data/")
    CLUSTER_PATH = os.path.join(ST_DATA_PATH, "cluster_data_2_db")
    DB_DIR_PATH = ST_DATA_PATH
    # Filled in by setup_class once the analysis command has succeeded.
    DB_PATH = ""
    STEP_TRACE_TIME_PATH = os.path.join(ST_DATA_PATH, "cluster_data_2_db", "cluster_analysis_output_text",
                                        "cluster_analysis_output", "cluster_step_trace_time.csv")
    COMMUNICATION_MATRIX_PATH = os.path.join(ST_DATA_PATH, "cluster_data_2_db", "cluster_analysis_output_text",
                                             "cluster_analysis_output", "cluster_communication_matrix.json")
    COMMAND_SUCCESS = 0

    @classmethod
    def setup_class(cls):
        """
        Generate the db data once for all tests of this class.

        Raises:
            AssertionError: if the analysis command exits with a non-zero code
                or the output directory does not exist afterwards.
        """
        PathManager.make_dir_safety(cls.DB_DIR_PATH)
        cmd = ["msprof-analyze", "cluster", "-d", cls.CLUSTER_PATH, "-m", "all",
               "--output_path", cls.DB_DIR_PATH, "--data_simplification", "--force"]
        logging.info(cmd)
        # check=False: the return code is inspected below so that a failure is
        # reported with the task-specific message and the captured stderr.
        completed_process = subprocess.run(cmd, capture_output=True, shell=False, check=False)
        if completed_process.returncode != cls.COMMAND_SUCCESS or not os.path.exists(cls.DB_DIR_PATH):
            # No TestCase instance exists at class-setup time, so raise
            # directly instead of calling an assert* method.
            raise AssertionError(
                "pytorch db cluster analyse task failed. returncode={}, stderr={}".format(
                    completed_process.returncode,
                    completed_process.stderr.decode(errors="replace")))
        cls.DB_PATH = os.path.join(cls.DB_DIR_PATH, "cluster_analysis_output", "cluster_analysis.db")

    @classmethod
    def teardown_class(cls):
        """Nothing to clean up; the generated output is left in place."""

    def test_msprof_analyze_text_db_trace_time_compare(self):
        """
        Test case to compare the cluster step trace time from text file and database.
        """
        df = pd.read_csv(self.STEP_TRACE_TIME_PATH)
        query_count = "SELECT count(*) FROM ClusterStepTraceTime"
        self.assertEqual(len(df), self._select_count(self.DB_PATH, query_count),
                         "Cluster step trace time count wrong.")
        query = "SELECT * FROM ClusterStepTraceTime where type= 'rank' and [index] = 7"
        db_cluster_step_trace_time = self._select_by_query(self.DB_PATH, query, ClusterStepTraceTimeDb)
        # NOTE(review): assumes the first CSV row is the type='rank', index=7
        # record selected from the database above -- confirm against the data.
        text_cluster_step_trace_time = ClusterStepTraceTimeDb(*df.iloc[0])
        self.assertEqual(text_cluster_step_trace_time.type, db_cluster_step_trace_time.type,
                         "Cluster step trace time db vs text 'type' property wrong.")
        self.assertEqual(text_cluster_step_trace_time.index, db_cluster_step_trace_time.index,
                         "Cluster step trace time db vs text 'index' property wrong.")
        self.assertEqual(round(text_cluster_step_trace_time.computing), round(db_cluster_step_trace_time.computing),
                         "Cluster step trace time db vs text 'computing' property wrong.")
        # NOTE(review): this field is compared as int(text) + 1 against
        # round(db) rather than round-vs-round; presumably tuned to the
        # reference data set -- confirm before changing.
        self.assertEqual(int(text_cluster_step_trace_time.communication_not_overlapped) + 1,
                         round(db_cluster_step_trace_time.communication_not_overlapped),
                         "Cluster step trace time db vs text 'communication_not_overlapped' property wrong.")
        self.assertEqual(round(text_cluster_step_trace_time.overlapped), round(db_cluster_step_trace_time.overlapped),
                         "Cluster step trace time db vs text 'overlapped' property wrong.")
        self.assertEqual(round(text_cluster_step_trace_time.communication),
                         round(db_cluster_step_trace_time.communication),
                         "Cluster step trace time db vs text 'communication' property wrong.")
        self.assertEqual(round(text_cluster_step_trace_time.free), round(db_cluster_step_trace_time.free),
                         "Cluster step trace time db vs text 'free' property wrong.")
        self.assertEqual(round(text_cluster_step_trace_time.stage), round(db_cluster_step_trace_time.stage),
                         "Cluster step trace time db vs text 'stage' property wrong.")
        self.assertEqual(round(text_cluster_step_trace_time.bubble), round(db_cluster_step_trace_time.bubble),
                         "Cluster step trace time db vs text 'bubble' property wrong.")
        # Same int(text) + 1 versus round(db) comparison as above.
        self.assertEqual(int(text_cluster_step_trace_time.communication_not_overlapped_and_exclude_receive) + 1,
                         round(db_cluster_step_trace_time.communication_not_overlapped_and_exclude_receive),
                         "Cluster step trace time db vs text 'communication_not_overlapped_and_exclude_receive' "
                         "property wrong.")

    def test_msprof_analyze_text_db_communication_matrix_compare(self):
        """
        Test case to compare the cluster communication matrix from text file and database.
        """
        query = ("SELECT * FROM ClusterCommunicationMatrix WHERE hccl_op_name = 'Total Op Info' and src_rank = 7 "
                 "and group_name = '15244899533746605158' and dst_rank = 4 and step = 'step'")
        db_cluster_communication_matrix = self._select_by_query(self.DB_PATH, query, ClusterCommunicationMatrixDb)
        query_count = ("SELECT count(*) FROM ClusterCommunicationMatrix WHERE hccl_op_name = 'Total Op Info' and "
                       "group_name = '15244899533746605158'")
        communication_matrix_json = FileManager.read_json_file(self.COMMUNICATION_MATRIX_PATH)
        # The JSON is nested as group -> step -> op name -> "src-dst" link.
        total_op_info = communication_matrix_json.get('(4, 5, 6, 7)').get('step').get('Total Op Info')
        self.assertEqual(self._select_count(self.DB_PATH, query_count),
                         len(total_op_info),
                         "Cluster communication matrix db vs text count wrong.")
        text_cluster_communication_matrix = total_op_info.get('7-4')
        self.assertEqual(text_cluster_communication_matrix.get('Transport Type'),
                         db_cluster_communication_matrix.transport_type,
                         "Cluster communication matrix db vs text 'Transport Type' property wrong.")
        self.assertEqual(round(text_cluster_communication_matrix.get('Transit Time(ms)')),
                         round(db_cluster_communication_matrix.transit_time),
                         "Cluster communication matrix db vs text 'Transit Time' property wrong.")
        self.assertEqual(round(text_cluster_communication_matrix.get('Transit Size(MB)')),
                         round(db_cluster_communication_matrix.transit_size),
                         "Cluster communication matrix db vs text 'Transit Size' property wrong.")
        self.assertEqual(round(text_cluster_communication_matrix.get('Bandwidth(GB/s)')),
                         round(db_cluster_communication_matrix.bandwidth),
                         "Cluster communication matrix db vs text 'Bandwidth' property wrong.")

    def _select_count(self, db_path: str, query: str):
        """
        Execute a SQL count query and return the first column of its first row.

        Fails the test if the database cannot be opened. The connection is
        closed even when the query raises.
        """
        conn, cursor = self._create_connect_db(db_path)
        if cursor is None:
            self.fail("Unable to connect to database: {}".format(db_path))
        try:
            cursor.execute(query)
            count = cursor.fetchone()
        finally:
            self._destroy_db_connect(conn, cursor)
        return count[0]

    def _select_by_query(self, db_path: str, query: str, db_class):
        """
        Execute a SQL query and return the first record as an instance of db_class.

        The row is passed positionally (``db_class(*row)``). Fails the test if
        the database cannot be opened or the query returns no rows. The
        connection is closed even when the query raises.
        """
        conn, cursor = self._create_connect_db(db_path)
        if cursor is None:
            self.fail("Unable to connect to database: {}".format(db_path))
        try:
            cursor.execute(query)
            # Only the first record is needed, so do not fetch the rest.
            row = cursor.fetchone()
        finally:
            self._destroy_db_connect(conn, cursor)
        if row is None:
            self.fail("Query returned no rows: {}".format(query))
        return db_class(*row)

    def _create_connect_db(self, db_file: str) -> tuple:
        """
        Create a connection to the SQLite database.

        Returns:
            ``(connection, cursor)`` on success, ``(None, None)`` if sqlite3
            reports an error (the error is logged).
        """
        conn = None
        try:
            conn = sqlite3.connect(db_file)
            return conn, conn.cursor()
        except sqlite3.Error as e:
            logging.error("Unable to connect to database: %s", e)
            # Do not leak a connection whose cursor could not be created.
            if conn is not None:
                conn.close()
            return None, None

    def _destroy_db_connect(self, conn: object, curs: object) -> None:
        """
        Close the database cursor and connection.

        Arguments that are not sqlite3 objects (for example ``None``) are
        ignored; sqlite3 errors raised while closing are logged, not raised.
        """
        try:
            if isinstance(curs, sqlite3.Cursor):
                curs.close()
        except sqlite3.Error as err:
            logging.error("%s", err)
        try:
            if isinstance(conn, sqlite3.Connection):
                conn.close()
        except sqlite3.Error as err:
            logging.error("%s", err)