# NOTE(review): SOURCE was a git format-patch e-mail (author w00800385,
# 2024-03-07) whose newlines were mangled away.  Below is the reconstructed
# post-patch Python for the three files the patch touches, with concrete
# defects fixed.  Only the methods added by the patch are shown; each class's
# __init__ / imports live in the original files:
#   profiler/cluster_analyse/analysis/communication/communication_analysis_db.py
#   profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
#   profiler/cluster_analyse/communication_group/communication_db_group.py
# Decoded patch subject: "att supports DB-format data input and processing,
# fix codecheck warnings, review-comment changes".


# ---------------------------------------------------------------------------
# profiler/cluster_analyse/analysis/communication/communication_analysis_db.py
# ---------------------------------------------------------------------------
class CommunicationAnalysisDB:
    """Aggregate per-rank communication time / bandwidth rows and write the
    cluster-level analyzer database.  (Patched methods only.)"""

    def run(self):
        """Entry point: bucket rows by rank-set and step, compute per-set
        totals, then dump everything to the result DB."""
        if not self.communication_time_info and not self.communication_bandwidth_info:
            return
        self.split_and_add_rank_set(self.communication_time_info, self.comm_time_struct)
        self.split_and_add_rank_set(self.communication_bandwidth_info, self.comm_bandwidth_struct)
        self.compute_total_info()
        self.dump_data()

    def dump_data(self):
        """Write accumulated time and bandwidth rows into the cluster
        communication analyzer DB (one executemany per table)."""
        output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
        result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        DBManager.create_tables(result_db, self.COMMUNICATION_TIME_TABLE,
                                self.COMMUNICATION_BANDWIDTH_TABLE)
        conn, cursor = DBManager.create_connect_db(result_db)
        # Column order must match the table schemas created above.
        time_cols = [TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID,
                     TableConstant.HCCL_OP_NAME, TableConstant.GROUP_NAME,
                     TableConstant.START_TIMESTAMP, TableConstant.ELAPSED_TIME,
                     TableConstant.TRANSIT_TIME, TableConstant.WAIT_TIME,
                     TableConstant.SYNCHRONIZATION_TIME, TableConstant.IDLE_TIME,
                     TableConstant.SYNCHRONIZATION_TIME_RATIO, TableConstant.WAIT_TIME_RATIO]
        res_time = [[row[col] for col in time_cols] for row in self.res_comm_time]
        if res_time:
            sql = "insert into {} values ({value})".format(
                self.COMMUNICATION_TIME_TABLE, value="?," * (len(res_time[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res_time)
        bandwidth_cols = [TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID,
                          TableConstant.HCCL_OP_NAME, TableConstant.GROUP_NAME,
                          TableConstant.TRANSPORT_TYPE, TableConstant.TRANSIT_SIZE,
                          TableConstant.TRANSIT_TIME, TableConstant.BANDWIDTH,
                          TableConstant.LARGE_PACKET_RATIO, TableConstant.PACKAGE_SIZE,
                          TableConstant.COUNT, TableConstant.TOTAL_DURATION]
        res_bandwidth = [[row[col] for col in bandwidth_cols] for row in self.res_comm_bandwidth]
        if res_bandwidth:
            sql = "insert into {} values ({value})".format(
                self.COMMUNICATION_BANDWIDTH_TABLE, value="?," * (len(res_bandwidth[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res_bandwidth)
        DBManager.destroy_db_connect(conn, cursor)

    def split_and_add_rank_set(self, data_list, res_dict):
        """Bucket rows as res_dict[rank_tuple][step] -> [row, ...].

        P2P rows all share the Constant.P2P pseudo-key; collective rows are
        keyed by the tuple of ranks in their communication group.
        """
        for data in data_list:
            if data[TableConstant.TYPE] == Constant.P2P:
                rank_tuple = Constant.P2P
            else:
                # NOTE(review): raises TypeError if the group is unknown
                # (collective_group_dict.get returns None) -- confirm upstream
                # guarantees every collective group was registered.
                rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME]))
            res_dict.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], []).append(data)

    def compute_total_info(self):
        """Compute per-rank-set totals for time rows and bandwidth rows."""
        for rank_tuple, op_dict in self.comm_time_struct.items():
            if rank_tuple != Constant.P2P:
                for step, data_list in op_dict.items():
                    self.compute_rank_set_total_time_info(data_list, rank_tuple)
            else:
                # FIX: the original did rank_set.add(<generator>), which stored
                # a generator object in the set instead of the rank ids.
                # update() flattens the generator into individual RANK_IDs.
                rank_set = set()
                for step, data_list in op_dict.items():
                    rank_set.update(data[TableConstant.RANK_ID] for data in data_list)
                for step, data_list in op_dict.items():
                    self.compute_rank_set_total_time_info(data_list, rank_set, True)
        for rank_tuple, op_dict in self.comm_bandwidth_struct.items():
            for step, data_list in op_dict.items():
                self.compute_rank_set_total_bandwidth_info(
                    data_list, rank_tuple, rank_tuple == Constant.P2P)

    def compute_rank_set_total_bandwidth_info(self, data_list, rank_tuple, is_p2p=False):
        """Tag each row with its rank-set, then append one synthetic
        "Total Op Info" row per (rank, transport type, package size)."""
        if not data_list:
            return
        data_dict = {}
        rank_tuple = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P
        for data in data_list:
            data[TableConstant.RANK_SET] = rank_tuple
            rank_band_type = self.RANK_BAND_TYPE.format(data[TableConstant.RANK_ID],
                                                        data[TableConstant.TRANSPORT_TYPE])
            data_dict.setdefault(rank_band_type, []).append(data)
            self.res_comm_bandwidth.append(data)
        for rank_band_type, bandwidth_list in data_dict.items():
            package_set = {data[TableConstant.PACKAGE_SIZE] for data in bandwidth_list}
            for package in package_set:
                total_comm_bandwidth_info = dict()
                for data in bandwidth_list:
                    self.compute_bandwidth(total_comm_bandwidth_info, data, package)
                bandwidth = BaseAnalysisJson.compute_ratio(
                    total_comm_bandwidth_info.get(TableConstant.TRANSIT_SIZE),
                    total_comm_bandwidth_info.get(TableConstant.TRANSIT_TIME))
                total_comm_bandwidth_info[TableConstant.BANDWIDTH] = bandwidth
                total_comm_bandwidth_info[TableConstant.PACKAGE_SIZE] = package
                total_comm_bandwidth_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
                total_comm_bandwidth_info[TableConstant.GROUP_NAME] = ""
                total_comm_bandwidth_info[TableConstant.LARGE_PACKET_RATIO] = 0.0
                self.res_comm_bandwidth.append(total_comm_bandwidth_info)

    def compute_bandwidth(self, res_dict, data_dict, package):
        """Fold one bandwidth row into res_dict: transit size/time always
        accumulate; count/duration accumulate only for the matching package
        size; every other column is copied through (last writer wins)."""
        for key, value in data_dict.items():
            if key in (TableConstant.TRANSIT_TIME, TableConstant.TRANSIT_SIZE):
                res_dict[key] = res_dict.get(key, 0.0) + value
            elif key in (TableConstant.COUNT, TableConstant.TOTAL_DURATION):
                if data_dict[TableConstant.PACKAGE_SIZE] == package:
                    res_dict[key] = res_dict.get(key, 0.0) + value
            else:
                res_dict[key] = value

    def compute_time(self, res_dict, data_dict, dict_key):
        """Accumulate *_time columns; copy every other column through."""
        if dict_key.endswith(self.TIME_EXTENSION):
            res_dict[dict_key] = res_dict.get(dict_key, 0.0) + data_dict[dict_key]
        else:
            res_dict[dict_key] = data_dict[dict_key]

    def compute_rank_set_total_time_info(self, data_list: list, rank_tuple: any, is_p2p: bool = False):
        """Per rank in the set, sum its time rows into one "Total Op Info" row
        (with sync/wait ratios), then append the raw rows as well."""
        if not data_list:
            return
        rank_set = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P
        for rank_id in rank_tuple:
            total_comm_time_info = dict()
            for data in data_list:
                if data[TableConstant.RANK_ID] == rank_id:
                    data[TableConstant.RANK_SET] = rank_set
                    data[TableConstant.SYNCHRONIZATION_TIME_RATIO] = 0.0
                    data[TableConstant.WAIT_TIME_RATIO] = 0.0
                    for key, value in data.items():
                        self.compute_time(total_comm_time_info, data, key)
            syn_ratio = BaseAnalysisJson.compute_ratio(
                total_comm_time_info.get(TableConstant.SYNCHRONIZATION_TIME),
                total_comm_time_info.get(TableConstant.SYNCHRONIZATION_TIME) +
                total_comm_time_info.get(TableConstant.TRANSIT_TIME))
            wait_time_ratio = BaseAnalysisJson.compute_ratio(
                total_comm_time_info.get(TableConstant.WAIT_TIME),
                total_comm_time_info.get(TableConstant.WAIT_TIME) +
                total_comm_time_info.get(TableConstant.TRANSIT_TIME))
            total_comm_time_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
            total_comm_time_info[TableConstant.GROUP_NAME] = ""
            total_comm_time_info[TableConstant.START_TIMESTAMP] = 0.0
            total_comm_time_info[TableConstant.WAIT_TIME_RATIO] = wait_time_ratio
            total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME_RATIO] = syn_ratio
            self.res_comm_time.append(total_comm_time_info)
        self.res_comm_time.extend(data_list)


# ---------------------------------------------------------------------------
# profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
# ---------------------------------------------------------------------------
class CommMatrixAnalysisDB:
    """Build the cluster communication matrix from per-rank DB rows.
    (Patched methods only.)"""

    def run(self):
        """Entry point: group rows, merge per-link totals, dump to DB."""
        if not self.matrix_info:
            return
        self.set_rank_tuple()
        self.combine_total_matrix_info()
        self.dump_data()

    def dump_data(self):
        """Write merged matrix rows into the cluster analyzer DB."""
        output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
        result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        DBManager.create_tables(result_db, self.COMMUNICATION_MATRIX_TABLE)
        conn, cursor = DBManager.create_connect_db(result_db)
        res = []
        for data in self.res_comm_matrix:
            # Synthetic total rows have no OPNAME; store "" instead of NULL.
            op_name = data.get(TableConstant.OPNAME)
            op_name = "" if op_name is None else op_name
            res.append([data[TableConstant.RANK_SET], data[TableConstant.STEP],
                        data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME],
                        data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK],
                        data[TableConstant.TRANSIT_SIZE], data[TableConstant.TRANSIT_TIME],
                        data[TableConstant.BANDWIDTH], data[TableConstant.TRANSPORT_TYPE],
                        op_name])
        if res:
            sql = "insert into {} values ({value})".format(
                self.COMMUNICATION_MATRIX_TABLE, value="?," * (len(res[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res)
        DBManager.destroy_db_connect(conn, cursor)

    def combine_total_matrix_info(self):
        """For every rank set and step: merge duplicate links, then add the
        cross-op total row per link."""
        for rank_tuple, group_dict in self.comm_matrix_struct.items():
            if rank_tuple != Constant.P2P:
                rank_tuple = "(" + ",".join(str(i) for i in rank_tuple) + ")"
            for step, step_dict in group_dict.items():
                self.merge_same_info(step_dict, rank_tuple)
                self.combine_total_info(step_dict)

    def combine_total_info(self, step_dict: dict):
        """Emit all per-op rows, plus one "Total Op Info" row per link key
        aggregated over the ops that check_add_op() admits."""
        link_key_set = set()
        for op_name, matrix_dict in step_dict.items():
            self.res_comm_matrix.extend(matrix_dict.values())
            if BaseAnalysisJson.check_add_op(op_name):
                link_key_set.update(matrix_dict.keys())
        for link_key in link_key_set:
            total_matrix_info = dict()
            total_matrix_info[TableConstant.TRANSIT_SIZE] = 0.0
            total_matrix_info[TableConstant.TRANSIT_TIME] = 0.0
            for op_name, matrix_dict in step_dict.items():
                if link_key in matrix_dict and BaseAnalysisJson.check_add_op(op_name):
                    total_matrix_info[TableConstant.RANK_SET] = \
                        matrix_dict[link_key][TableConstant.RANK_SET]
                    self.combine_link_info(total_matrix_info, matrix_dict[link_key])
            bandwidth = BaseAnalysisJson.compute_ratio(
                total_matrix_info[TableConstant.TRANSIT_SIZE],
                total_matrix_info[TableConstant.TRANSIT_TIME])
            total_matrix_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
            total_matrix_info[TableConstant.GROUP_NAME] = ""
            total_matrix_info[TableConstant.BANDWIDTH] = bandwidth
            self.res_comm_matrix.append(total_matrix_info)

    def combine_link_info(self, link_info, data: dict):
        """Accumulate transit size/time; copy every other column through."""
        for col, value in data.items():
            if col in (TableConstant.TRANSIT_TIME, TableConstant.TRANSIT_SIZE):
                link_info[col] += value
            else:
                link_info[col] = value

    def merge_same_info(self, step_dict: dict, rank_tuple):
        """Merge rows sharing the same src-dst link per op, learn the
        local->global rank projection from self-links, then re-key every
        merged link by global ranks."""

        def process_matrix():
            # Fold all rows of op_list that match link_key into matrix_info;
            # self-links (src == dst) also teach us the local->global mapping.
            for data in op_list:
                if data[TableConstant.SRC_RANK] == data[TableConstant.DST_RANK]:
                    if data[TableConstant.SRC_RANK] not in local_global_rank_map:
                        local_global_rank_map[data[TableConstant.SRC_RANK]] = \
                            data[TableConstant.RANK_ID]
                    elif local_global_rank_map[data[TableConstant.SRC_RANK]] != \
                            data[TableConstant.RANK_ID]:
                        print(f"[WARNING] In the same communication group, local ranks projecting to global ranks "
                              f"repeat!")
                if (link_key.split('-')[0] == data[TableConstant.SRC_RANK] and
                        link_key.split('-')[1] == data[TableConstant.DST_RANK]):
                    self.combine_link_info(matrix_info, data)
            new_matrix_list[link_key] = matrix_info

        def convert_local_to_global_rank():
            # Re-key merged links by global ranks and compute bandwidth.
            res_dict = dict()
            for key, new_matrix in new_matrix_list.items():
                src_rank = new_matrix[TableConstant.SRC_RANK]
                dst_rank = new_matrix[TableConstant.DST_RANK]
                src_rank = local_global_rank_map.get(src_rank, src_rank)
                dst_rank = local_global_rank_map.get(dst_rank, dst_rank)
                bandwidth = BaseAnalysisJson.compute_ratio(
                    new_matrix[TableConstant.TRANSIT_SIZE],
                    new_matrix[TableConstant.TRANSIT_TIME])
                key = f"{src_rank}-{dst_rank}"
                new_matrix[TableConstant.SRC_RANK] = src_rank
                new_matrix[TableConstant.DST_RANK] = dst_rank
                new_matrix[TableConstant.BANDWIDTH] = bandwidth
                res_dict[key] = new_matrix
            return res_dict

        local_global_rank_map = dict()
        for op_name, op_list in step_dict.items():
            new_matrix_list = {}
            link_key_set = {op_data[TableConstant.SRC_RANK] + "-" + op_data[TableConstant.DST_RANK]
                            for op_data in op_list}
            for link_key in link_key_set:
                matrix_info = dict()
                matrix_info[TableConstant.RANK_SET] = rank_tuple
                matrix_info[TableConstant.TRANSIT_SIZE] = 0.0
                matrix_info[TableConstant.TRANSIT_TIME] = 0.0
                process_matrix()
            step_dict[op_name] = convert_local_to_global_rank()

    def set_rank_tuple(self):
        """Bucket matrix rows as comm_matrix_struct[rank_tuple][step][op] -> rows."""
        for data in self.matrix_info:
            op_name = data[TableConstant.HCCL_OP_NAME] + "@" + data[TableConstant.GROUP_NAME]
            # FIX: the original compared the STEP column to Constant.P2P; STEP
            # holds step identifiers (rows are grouped by step elsewhere),
            # while the operator kind lives in TYPE -- exactly how every
            # sibling method (split_and_add_rank_set, add_p2p_and_rank,
            # compute_collective_group) classifies rows.
            if data[TableConstant.TYPE] == Constant.P2P:
                rank_tuple = Constant.P2P
            else:
                rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME]))
            self.comm_matrix_struct.setdefault(rank_tuple, {}) \
                .setdefault(data[TableConstant.STEP], {}) \
                .setdefault(op_name, []).append(data)


# ---------------------------------------------------------------------------
# profiler/cluster_analyse/communication_group/communication_db_group.py
# ---------------------------------------------------------------------------
class CommunicationDBGroup(BaseCommunicationGroup):
    """Read per-rank analyzer DBs and derive communication groups.
    (Patched methods only.)"""

    def read_communication_func(self, params: tuple):
        """Read one rank's analyzer DB.

        params: (rank_id, db_path, ...).  Returns (rank_id, (time, bandwidth,
        matrix)) with each element grouped by step; (-1, empty) on bad params.
        """
        if len(params) < 3:
            return -1, ({}, {}, {})
        rank_id = params[0]
        db_path = params[1]
        time_data = {}
        bandwidth_data = {}
        matrix_data = {}
        if DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME,
                                        Constant.TABLE_COMM_ANALYZER_BANDWIDTH,
                                        Constant.TABLE_COMM_ANALYZER_MATRIX):
            conn, cursor = DBManager.create_connect_db(db_path)
            time_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_TIME)
            bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH)
            matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX)
            if self.analysis_mode in ["all", "communication_time"]:
                time_data = DBManager.fetch_all_data(cursor, time_info_sql)
                bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql)
            if self.analysis_mode in ["all", "communication_matrix"]:
                matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql)
            DBManager.destroy_db_connect(conn, cursor)
        return rank_id, (self.data_group_by_step(time_data),
                         self.data_group_by_step(bandwidth_data),
                         self.data_group_by_step(matrix_data))

    @staticmethod
    def data_group_by_step(data: any) -> any:
        """Group rows into {step: [row, ...]}."""
        res = {}
        for item in data:
            res.setdefault(item[TableConstant.STEP], []).append(item)
        return res

    def dump_data(self):
        """Persist discovered communication groups and return the combined
        data dict consumed by the analysis stage."""
        output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
        result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE)
        res = []
        conn, cursor = DBManager.create_connect_db(result_db)
        for data_type, data_list in self.communication_group.items():
            for ranks in data_list:
                rank_set = "(" + ",".join(str(i) for i in ranks) + ")"
                res.append([data_type, rank_set])
        if res:
            sql = "insert into {} values ({value})".format(
                self.COMMUNICATION_GROUP_TABLE, value="?," * (len(res[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res)
        DBManager.destroy_db_connect(conn, cursor)
        comm_data_dict = {
            Constant.COLLECTIVE_GROUP: self.collective_group_dict,
            Constant.COMMUNICATION_TIME_INFO: self.communication_time_info,
            Constant.COMMUNICATION_BANDWIDTH_INFO: self.communication_bandwidth_info,
            Constant.MATRIX_OPS: self.matrix_info,
            Constant.COMMUNICATION_GROUP: self.communication_group
        }
        return comm_data_dict

    def analyze_communication_data(self):
        """Fan per-rank rows out into the collective/matrix accumulators."""
        # NOTE(review): rank_comm_dir_dict is iterated as (rank_id, tuple)
        # pairs -- assumes it is a list of result tuples from
        # read_communication_func, not a plain dict; confirm in base class.
        for rank_id, data_tuple in self.rank_comm_dir_dict:
            time_data, bandwidth_data, matrix_data = data_tuple[0], data_tuple[1], data_tuple[2]
            for step, data_list in time_data.items():
                for data in data_list:
                    self.compute_collective_group(data, rank_id, self.communication_time_info)
                # FIX: bandwidth_data[step] raised KeyError when the bandwidth
                # table lacked a step present in the time table; skip instead.
                for data in bandwidth_data.get(step, []):
                    self.compute_collective_group(data, rank_id, self.communication_bandwidth_info)
            for step, data_list in matrix_data.items():
                self.add_p2p_and_rank(rank_id, step, matrix_data)
                for data in data_list:
                    self.compute_collective_group(data, rank_id, self.matrix_info)

    def compute_collective_group(self, data, rank_id, res_list):
        """Register the rank in its collective group (if any), stamp the row
        with the rank id, and collect it."""
        if data[TableConstant.TYPE] == Constant.COLLECTIVE:
            # NOTE(review): assumes collective_group_dict maps group -> set
            # (e.g. defaultdict(set)) in the base class -- confirm.
            self.collective_group_dict[data[TableConstant.GROUP_NAME]].add(rank_id)
        data[TableConstant.RANK_ID] = rank_id
        res_list.append(data)

    def add_p2p_and_rank(self, rank_id: int, step: str, data_dict: dict):
        """Record distinct cross-rank P2P links for one rank/step."""
        data_list = data_dict[step]
        if not data_list:
            print(f"[WARNING] rank {rank_id} {step} don't have communication matrix ops data")
            return
        for data in data_list:
            if data[TableConstant.TYPE] != Constant.COLLECTIVE and \
                    data[TableConstant.TYPE] != Constant.P2P:
                print("[WARNING] Unknown communication operators type!")
                continue
            if data[TableConstant.TYPE] == Constant.P2P:
                if data[TableConstant.SRC_RANK] != data[TableConstant.DST_RANK]:
                    rank_set = {data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK]}
                    if rank_set not in self.p2p_link:
                        self.p2p_link.append(rank_set)