From 4ff0f290dccd660175a17a4540ded9aebdaffc80 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Wed, 13 Mar 2024 11:19:11 +0800 Subject: [PATCH] =?UTF-8?q?=E5=BD=93=E8=BE=93=E5=85=A5DB=E6=95=B0=E6=8D=AE?= =?UTF-8?q?=E4=B8=BA=E7=A9=BA=E6=88=96=E8=80=85=E8=AE=A1=E7=AE=97=E7=BB=93?= =?UTF-8?q?=E6=9E=9C=E4=B8=BA=E7=A9=BA=E6=97=B6=EF=BC=8C=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E6=8F=90=E7=A4=BA=E6=97=A5=E5=BF=97=E6=89=93=E5=8D=B0=EF=BC=8C?= =?UTF-8?q?=E4=B8=8D=E7=94=9F=E6=88=90DB=E6=88=96=E8=80=85=E5=AF=B9?= =?UTF-8?q?=E5=BA=94=E8=A1=A8?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../communication_analysis_db.py | 5 ++++- .../comm_matrix_analysis_db.py | 5 ++++- .../analysis/step_trace_time_analysis.py | 8 +++++--- .../cluster_analyse/common_func/db_manager.py | 6 +++--- .../base_communication_group.py | 4 ++-- .../communication_db_group.py | 19 +++++++++++-------- 6 files changed, 29 insertions(+), 18 deletions(-) diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py index 3559a9a28..ff371cf7a 100644 --- a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py @@ -32,6 +32,9 @@ class CommunicationAnalysisDB: self.dump_data() def dump_data(self): + if not self.res_comm_time and not self.res_comm_bandwidth: + print("[WARNING] There is no final communication data generated") + return output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) DBManager.create_tables(result_db, self.COMMUNICATION_TIME_TABLE, self.COMMUNICATION_BANDWIDTH_TABLE) @@ -66,7 +69,7 @@ class CommunicationAnalysisDB: if data[TableConstant.TYPE] == Constant.P2P: rank_tuple = Constant.P2P else: - rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME])) + rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME], [])) res_dict.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], []).append(data) def compute_total_info(self): diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py index df58fcecf..dbee80deb 100644 --- a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py @@ -24,6 +24,9 @@ class CommMatrixAnalysisDB: self.dump_data() def dump_data(self): + if not self.res_comm_matrix: + print("[WARNING] There is no final communication_matrix data generated") + return output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) DBManager.create_tables(result_db, self.COMMUNICATION_MATRIX_TABLE) @@ -128,6 +131,6 @@ class CommMatrixAnalysisDB: if data[TableConstant.STEP] == Constant.P2P: rank_tuple = Constant.P2P else: - rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME])) + rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME], [])) self.comm_matrix_struct.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], {}). \ setdefault(op_name, []).append(data) diff --git a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py index 20a71df3c..f570deee1 100644 --- a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py +++ b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py @@ -53,6 +53,7 @@ class StepTraceTimeAnalysis: def dump_data(self): if not self.step_data_list: print("[WARNING] Can't get step time info!") + return if self.data_type == Constant.TEXT: headers = self.get_headers() FileManager.create_csv_file(self.collection_path, self.step_data_list, self.CLUSTER_TRACE_TIME_CSV, headers) @@ -70,19 +71,20 @@ class StepTraceTimeAnalysis: for rank_id, profiling_dir_path in self.data_map.items(): if self.data_type == Constant.TEXT: step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) - if step_time_file: + if os.path.exists(step_time_file): self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) else: step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER) - if step_time_file and DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE): + if (os.path.exists(step_time_file) and + DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE)): conn, cursor = DBManager.create_connect_db(step_time_file) sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE) data = DBManager.fetch_all_data(cursor, sql, is_dict=False) self.step_time_dict[rank_id] = data DBManager.destroy_db_connect(conn, cursor) if not self.step_time_dict.get(rank_id): - print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time.json.") + print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time data in {self.data_type} file.") def analyze_step_time(self): for rank_id, data_bean_list in self.step_time_dict.items(): diff --git a/profiler/cluster_analyse/common_func/db_manager.py b/profiler/cluster_analyse/common_func/db_manager.py index f19bc15dc..bdee49be6 100644 --- a/profiler/cluster_analyse/common_func/db_manager.py +++ b/profiler/cluster_analyse/common_func/db_manager.py @@ -35,7 +35,7 @@ class DBManager: """ create and connect database """ - if check_db_path_valid(db_path): + if check_db_path_valid(db_path, is_create=True): try: conn = sqlite3.connect(db_path) except sqlite3.Error as err: @@ -100,7 +100,7 @@ class DBManager: @classmethod def check_tables_in_db(cls, db_path: any, *tables: any) -> bool: - if check_db_path_valid(db_path, True): + if check_db_path_valid(db_path): conn, curs = cls.create_connect_db(db_path) if not (conn and curs): return False @@ -114,7 +114,7 @@ class DBManager: return False @classmethod - def create_tables(cls, db_path: any, *tables: any) -> bool: + def create_tables(cls, db_path: any, *tables: any): conn, curs = cls.create_connect_db(db_path) for table_name in tables: if not cls.judge_table_exists(curs, table_name): diff --git a/profiler/cluster_analyse/communication_group/base_communication_group.py b/profiler/cluster_analyse/communication_group/base_communication_group.py index 515c77c93..a275fefe7 100644 --- a/profiler/cluster_analyse/communication_group/base_communication_group.py +++ b/profiler/cluster_analyse/communication_group/base_communication_group.py @@ -43,11 +43,11 @@ class BaseCommunicationGroup: else: comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER) matrix_dir = comm_dir - if comm_dir and matrix_dir: + if os.path.exists(comm_dir) or os.path.exists(matrix_dir): comm_op_dirs.append((rank_id, comm_dir, matrix_dir)) else: print( - f"[WARNING] Rank {rank_id} does not have a valid communication.json or communication_matrix.json.") + f"[WARNING] Rank {rank_id} does not have valid communication data and communication_matrix data.") with Pool() as p: self.rank_comm_dir_dict = p.map(self.read_communication_func, comm_op_dirs) diff --git a/profiler/cluster_analyse/communication_group/communication_db_group.py b/profiler/cluster_analyse/communication_group/communication_db_group.py index 7dcc8f9c2..c61411eda 100644 --- a/profiler/cluster_analyse/communication_group/communication_db_group.py +++ b/profiler/cluster_analyse/communication_group/communication_db_group.py @@ -23,17 +23,18 @@ class CommunicationDBGroup(BaseCommunicationGroup): time_data = {} bandwidth_data = {} matrix_data = {} - if DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME, - Constant.TABLE_COMM_ANALYZER_BANDWIDTH, - Constant.TABLE_COMM_ANALYZER_MATRIX): + if os.path.exists(db_path): conn, cursor = DBManager.create_connect_db(db_path) time_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_TIME) bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH) matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX) - if self.analysis_mode in ["all", "communication_time"]: + if (DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME, + Constant.TABLE_COMM_ANALYZER_BANDWIDTH) + and self.analysis_mode in ["all", "communication_time"]): time_data = DBManager.fetch_all_data(cursor, time_info_sql) bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql) - if self.analysis_mode in ["all", "communication_matrix"]: + if (DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_MATRIX) + and self.analysis_mode in ["all", "communication_matrix"]): matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql) DBManager.destroy_db_connect(conn, cursor) return rank_id, (self.data_group_by_step(time_data), self.data_group_by_step(bandwidth_data), @@ -49,19 +50,21 @@ class CommunicationDBGroup(BaseCommunicationGroup): def dump_data(self): output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) - DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE) res = [] - conn, cursor = DBManager.create_connect_db(result_db) for data_type, data_list in self.communication_group.items(): for data in data_list: rank_set = "(" + ",".join(str(i) for i in data) + ")" data = [data_type, rank_set] res.append(data) if res: + DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE) + conn, cursor = DBManager.create_connect_db(result_db) sql = "insert into {} values ({value})".format(self.COMMUNICATION_GROUP_TABLE, value="?," * (len(res[0]) - 1) + "?") DBManager.executemany_sql(conn, sql, res) - DBManager.destroy_db_connect(conn, cursor) + DBManager.destroy_db_connect(conn, cursor) + else: + print("[WARNING] The CommunicationGroup table won't be created because no data has been calculated.") comm_data_dict = { Constant.COLLECTIVE_GROUP: self.collective_group_dict, Constant.COMMUNICATION_TIME_INFO: self.communication_time_info, -- Gitee