diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py
index 3559a9a28f52c42c0e62e7a53191a5276ab7354d..ff371cf7a8878d3384e4321d0ab8922018c6325c 100644
--- a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py
+++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py
@@ -32,6 +32,9 @@ class CommunicationAnalysisDB:
         self.dump_data()

     def dump_data(self):
+        if not self.res_comm_time and not self.res_comm_bandwidth:
+            print("[WARNING] There is no final communication data generated")
+            return
         output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
         result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
         DBManager.create_tables(result_db, self.COMMUNICATION_TIME_TABLE, self.COMMUNICATION_BANDWIDTH_TABLE)
@@ -66,7 +69,7 @@ class CommunicationAnalysisDB:
         if data[TableConstant.TYPE] == Constant.P2P:
             rank_tuple = Constant.P2P
         else:
-            rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME]))
+            rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME], []))
         res_dict.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], []).append(data)

     def compute_total_info(self):
diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
index df58fcecff82be946fb95378b283db904af0b1d7..dbee80debd7ca4b9ea25636ba249fa7a83753ddb 100644
--- a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
+++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
@@ -24,6 +24,9 @@ class CommMatrixAnalysisDB:
         self.dump_data()

     def dump_data(self):
+        if not self.res_comm_matrix:
+            print("[WARNING] There is no final communication_matrix data generated")
+            return
         output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
         result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
         DBManager.create_tables(result_db, self.COMMUNICATION_MATRIX_TABLE)
@@ -128,6 +131,6 @@ class CommMatrixAnalysisDB:
         if data[TableConstant.STEP] == Constant.P2P:
             rank_tuple = Constant.P2P
         else:
-            rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME]))
+            rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME], []))
         self.comm_matrix_struct.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], {}). \
             setdefault(op_name, []).append(data)
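Note on the `.get(..., [])` change in the two files above: `dict.get` returns `None` for a group name that is missing from `collective_group_dict`, and `tuple(None)` raises a `TypeError`, so the empty-list default keeps aggregation running for unknown groups. A minimal standalone sketch of the failure mode, with made-up data rather than the project's classes:

```python
# Minimal sketch (illustrative data): why a default of [] matters when the
# group name is absent from the collective-group mapping.
collective_group_dict = {"group_a": [0, 1, 2, 3]}

def rank_tuple_for(group_name):
    # Without the default, a missing key yields None and tuple(None) raises
    # "TypeError: 'NoneType' object is not iterable".
    return tuple(collective_group_dict.get(group_name, []))

print(rank_tuple_for("group_a"))  # (0, 1, 2, 3)
print(rank_tuple_for("group_x"))  # () instead of a TypeError
```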
diff --git a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py
index 20a71df3c57437e5f278ebe450c8811b26bbe3ef..f570deee1c9ac53f7bbe65be9660d9e014576d04 100644
--- a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py
+++ b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py
@@ -53,6 +53,7 @@ class StepTraceTimeAnalysis:
     def dump_data(self):
         if not self.step_data_list:
             print("[WARNING] Can't get step time info!")
+            return
         if self.data_type == Constant.TEXT:
             headers = self.get_headers()
             FileManager.create_csv_file(self.collection_path, self.step_data_list, self.CLUSTER_TRACE_TIME_CSV, headers)
@@ -70,19 +71,20 @@ class StepTraceTimeAnalysis:
         for rank_id, profiling_dir_path in self.data_map.items():
             if self.data_type == Constant.TEXT:
                 step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV)
-                if step_time_file:
+                if os.path.exists(step_time_file):
                     self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean)
             else:
                 step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT,
                                               Constant.DB_COMMUNICATION_ANALYZER)
-                if step_time_file and DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE):
+                if (os.path.exists(step_time_file) and
+                        DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE)):
                     conn, cursor = DBManager.create_connect_db(step_time_file)
                     sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE)
                     data = DBManager.fetch_all_data(cursor, sql, is_dict=False)
                     self.step_time_dict[rank_id] = data
                     DBManager.destroy_db_connect(conn, cursor)
             if not self.step_time_dict.get(rank_id):
-                print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time.json.")
+                print(f"[WARNING] Rank {rank_id} does not have valid step_trace_time data in the {self.data_type} file.")

     def analyze_step_time(self):
         for rank_id, data_bean_list in self.step_time_dict.items():
diff --git a/profiler/cluster_analyse/common_func/db_manager.py b/profiler/cluster_analyse/common_func/db_manager.py
index f19bc15dc83821d6d65cf3f39713565da8293989..bdee49be60640362d54a697ca202e9d54a58d4f8 100644
--- a/profiler/cluster_analyse/common_func/db_manager.py
+++ b/profiler/cluster_analyse/common_func/db_manager.py
@@ -35,7 +35,7 @@ class DBManager:
         """
         create and connect database
         """
-        if check_db_path_valid(db_path):
+        if check_db_path_valid(db_path, is_create=True):
             try:
                 conn = sqlite3.connect(db_path)
             except sqlite3.Error as err:
@@ -100,7 +100,7 @@ class DBManager:

     @classmethod
     def check_tables_in_db(cls, db_path: any, *tables: any) -> bool:
-        if check_db_path_valid(db_path, True):
+        if check_db_path_valid(db_path):
             conn, curs = cls.create_connect_db(db_path)
             if not (conn and curs):
                 return False
@@ -114,7 +114,7 @@ class DBManager:
         return False

     @classmethod
-    def create_tables(cls, db_path: any, *tables: any) -> bool:
+    def create_tables(cls, db_path: any, *tables: any):
         conn, curs = cls.create_connect_db(db_path)
         for table_name in tables:
             if not cls.judge_table_exists(curs, table_name):
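Note on the `os.path.exists` changes in step_trace_time_analysis.py above: `os.path.join` always returns a non-empty string, so the old `if step_time_file:` check could never be false and a missing file only surfaced later as a read failure or an empty result. A minimal sketch of the difference, with an illustrative path rather than the project's constants:

```python
# Minimal sketch (illustrative path): truthiness of a joined path says nothing
# about whether the file actually exists on disk.
import os

step_time_file = os.path.join("rank_0_output", "step_trace_time.csv")

print(bool(step_time_file))             # True even if the file was never written
print(os.path.exists(step_time_file))   # True only when the file is really there
```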
diff --git a/profiler/cluster_analyse/communication_group/base_communication_group.py b/profiler/cluster_analyse/communication_group/base_communication_group.py
index 515c77c93acf8cd7019747e2362ae103ba8fa528..a275fefe75d0003bd68793269cbba74f31806bd1 100644
--- a/profiler/cluster_analyse/communication_group/base_communication_group.py
+++ b/profiler/cluster_analyse/communication_group/base_communication_group.py
@@ -43,11 +43,11 @@ class BaseCommunicationGroup:
             else:
                 comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER)
                 matrix_dir = comm_dir
-            if comm_dir and matrix_dir:
+            if os.path.exists(comm_dir) or os.path.exists(matrix_dir):
                 comm_op_dirs.append((rank_id, comm_dir, matrix_dir))
             else:
                 print(
-                    f"[WARNING] Rank {rank_id} does not have a valid communication.json or communication_matrix.json.")
+                    f"[WARNING] Rank {rank_id} does not have valid communication data and communication_matrix data.")

         with Pool() as p:
             self.rank_comm_dir_dict = p.map(self.read_communication_func, comm_op_dirs)
diff --git a/profiler/cluster_analyse/communication_group/communication_db_group.py b/profiler/cluster_analyse/communication_group/communication_db_group.py
index 7dcc8f9c2333892ce774fe10cc88e567dd8e564c..c61411edab2f7ba9a619680d9803e1f5302c3966 100644
--- a/profiler/cluster_analyse/communication_group/communication_db_group.py
+++ b/profiler/cluster_analyse/communication_group/communication_db_group.py
@@ -23,17 +23,18 @@ class CommunicationDBGroup(BaseCommunicationGroup):
         time_data = {}
         bandwidth_data = {}
         matrix_data = {}
-        if DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME,
-                                        Constant.TABLE_COMM_ANALYZER_BANDWIDTH,
-                                        Constant.TABLE_COMM_ANALYZER_MATRIX):
+        if os.path.exists(db_path):
             conn, cursor = DBManager.create_connect_db(db_path)
             time_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_TIME)
             bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH)
             matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX)
-            if self.analysis_mode in ["all", "communication_time"]:
+            if (DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME,
+                                             Constant.TABLE_COMM_ANALYZER_BANDWIDTH)
+                    and self.analysis_mode in ["all", "communication_time"]):
                 time_data = DBManager.fetch_all_data(cursor, time_info_sql)
                 bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql)
-            if self.analysis_mode in ["all", "communication_matrix"]:
+            if (DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_MATRIX)
+                    and self.analysis_mode in ["all", "communication_matrix"]):
                 matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql)
             DBManager.destroy_db_connect(conn, cursor)
         return rank_id, (self.data_group_by_step(time_data), self.data_group_by_step(bandwidth_data),
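Note on the restructured hunk in communication_db_group.py above: checking the time/bandwidth tables separately from the matrix table means a profiling database that carries only one kind of data still contributes what it has, instead of being skipped because a single table is absent. A standalone sketch of the per-table guard with plain `sqlite3` rather than the project's `DBManager`:

```python
# Standalone sketch (plain sqlite3, not the project's DBManager): query a table
# only when it exists, so one missing table does not block reading the others.
import sqlite3

def fetch_if_table_exists(db_path, table_name):
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute("select name from sqlite_master where type='table' and name=?",
                       (table_name,))
        if cursor.fetchone() is None:
            return []  # table missing: return nothing instead of raising OperationalError
        cursor.execute("select * from {0}".format(table_name))  # same query style as the patch
        return cursor.fetchall()
    finally:
        conn.close()
```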
@@ -49,19 +50,21 @@ class CommunicationDBGroup(BaseCommunicationGroup):
     def dump_data(self):
         output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
         result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
-        DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE)
         res = []
-        conn, cursor = DBManager.create_connect_db(result_db)
         for data_type, data_list in self.communication_group.items():
             for data in data_list:
                 rank_set = "(" + ",".join(str(i) for i in data) + ")"
                 data = [data_type, rank_set]
                 res.append(data)
         if res:
+            DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE)
+            conn, cursor = DBManager.create_connect_db(result_db)
             sql = "insert into {} values ({value})".format(self.COMMUNICATION_GROUP_TABLE,
                                                            value="?," * (len(res[0]) - 1) + "?")
             DBManager.executemany_sql(conn, sql, res)
-        DBManager.destroy_db_connect(conn, cursor)
+            DBManager.destroy_db_connect(conn, cursor)
+        else:
+            print("[WARNING] The CommunicationGroup table won't be created because no data has been calculated.")
         comm_data_dict = {
             Constant.COLLECTIVE_GROUP: self.collective_group_dict,
             Constant.COMMUNICATION_TIME_INFO: self.communication_time_info,
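Note on the reworked `dump_data` above: the output table is now created and the connection opened only when there are rows to insert, so a run with no communication groups no longer leaves an empty table behind. A standalone sketch of that write path with plain `sqlite3`; the database file name, table name, and columns are illustrative, not the project's actual schema:

```python
# Standalone sketch (plain sqlite3, illustrative names): create the output table
# and open the connection only when there is something to write.
import sqlite3

def dump_rows(result_db, rows):
    if not rows:
        print("[WARNING] The CommunicationGroup table won't be created because no data has been calculated.")
        return
    conn = sqlite3.connect(result_db)
    try:
        conn.execute("create table if not exists CommunicationGroup (type text, rank_set text)")
        # rows follow the patch's [data_type, rank_set] shape, e.g. ["collective", "(0,1,2,3)"]
        conn.executemany("insert into CommunicationGroup values (?, ?)", rows)
        conn.commit()
    finally:
        conn.close()
```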