diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index b383a704df27d18e0191b2b251efd9de61beee55..0b870bbaafa6483bf2cfde49971d79106c07aa23 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -14,14 +14,13 @@ # limitations under the License. from multiprocessing import Process -from common_func.constant import Constant -from analysis.communication_analysis import CommunicationAnalysis +from analysis.communication.comm_analysis_generator import CommunicationAnalysisGenerator +from analysis.communication_matrix.comm_matrix_generator import CommMatrixAnalysisGenerator from analysis.step_trace_time_analysis import StepTraceTimeAnalysis -from analysis.communication_analysis import CommMatrixAnalysis class AnalysisFacade: - analysis_module = {CommunicationAnalysis, StepTraceTimeAnalysis, CommMatrixAnalysis} + analysis_module = {CommunicationAnalysisGenerator, StepTraceTimeAnalysis, CommMatrixAnalysisGenerator} def __init__(self, params: dict): self.params = params diff --git a/profiler/cluster_analyse/analysis/base_analysis_json.py b/profiler/cluster_analyse/analysis/base_analysis_json.py new file mode 100644 index 0000000000000000000000000000000000000000..be8e42d3d5a8e30d477e9d522530d1e160a3a93e --- /dev/null +++ b/profiler/cluster_analyse/analysis/base_analysis_json.py @@ -0,0 +1,79 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class BaseAnalysisJson:
    """
    Common scaffolding for JSON-format cluster communication analyses.

    Pulls the collection path, per-rank data map and collective-group mapping
    out of the input params and provides the shared group-splitting,
    aggregation and dump helpers used by the concrete subclasses.
    """

    def __init__(self, param: dict):
        self.collection_path = param.get(Constant.COLLECTION_PATH)
        self.data_map = param.get(Constant.DATA_MAP)
        self.communication_ops = []
        self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP)
        self.comm_ops_struct = {}

    @staticmethod
    def compute_ratio(dividend: float, divisor: float):
        """Return dividend / divisor rounded to 4 places; 0 for a (near-)zero divisor."""
        return 0 if abs(divisor) < Constant.EPS else round(dividend / divisor, 4)

    @staticmethod
    def check_add_op(op_name: str):
        """
        Compatible with two profiler versions: decide whether this op's info
        should be merged into the totals.  Only names free of a
        "middle"/"top"/"bottom" marker qualify (a "total" marker is fine).
        """
        return not any(marker in op_name for marker in ("middle", "top", "bottom"))

    @abstractmethod
    def run(self):
        pass

    def dump_data(self):
        """Persist the aggregated structure as JSON; tuple keys are stringified."""
        if not self.comm_ops_struct:
            print("[WARNING] There is no final comm ops data generated")
            return
        output_comm_data = {str(key): value for key, value in self.comm_ops_struct.items()}
        FileManager.create_json_file(self.collection_path, output_comm_data, self.SAVED_JSON)

    def split_op_by_group(self):
        """Bucket ops by rank set (P2P ops share one bucket), then step, op name, rank id."""
        for single_op in self.communication_ops:
            if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P:
                rank_tup = Constant.P2P
            else:
                group_name = single_op.get(Constant.GROUP_NAME)
                rank_tup = tuple(self.collective_group_dict.get(group_name, []))
            rank_id = single_op.get(Constant.RANK_ID, 'N/A')
            step_id = single_op.get(Constant.STEP_ID, 'N/A')
            op_name = single_op.get(Constant.COMM_OP_NAME, 'N/A')
            op_info = single_op.get(Constant.COMM_OP_INFO)
            step_bucket = self.comm_ops_struct.setdefault(rank_tup, {}).setdefault(step_id, {})
            step_bucket.setdefault(op_name, {}).setdefault(rank_id, op_info)

    def combine_ops_total_info(self):
        """Aggregate per-step totals for every rank group via the subclass hook."""
        for group_dict in self.comm_ops_struct.values():
            for communication_ops in group_dict.values():
                self.compute_total_info(communication_ops)
class CommunicationAnalysisGenerator:
    """
    Facade that selects the communication-time analysis backend matching the
    requested output format (sqlite DB vs. JSON text) and delegates to it.
    """

    GROUP_MAP = {
        Constant.DB: CommunicationAnalysisDB,
        Constant.TEXT: CommunicationJsonAnalysis
    }

    def __init__(self, params: dict):
        # Look up the backend class by the requested data type; a missing
        # type raises KeyError, same as the original dispatch.
        backend_cls = self.GROUP_MAP[params[Constant.DATA_TYPE]]
        self.generator = backend_cls(params)

    def run(self):
        # Hand the whole analysis off to the selected backend.
        self.generator.run()
class CommunicationAnalysisDB:
    """
    Aggregate per-rank communication time and bandwidth records and write the
    cluster-level results into the analyzer sqlite database.
    """

    COMMUNICATION_BANDWIDTH_TABLE = "ClusterCommAnalyzerBandwidth"
    COMMUNICATION_TIME_TABLE = "ClusterCommAnalyzerTime"
    TIME_EXTENSION = "time"
    RANK_BAND_TYPE = "{}-{}"

    def __init__(self, params: any):
        self.collection_path = params.get(Constant.COLLECTION_PATH)
        self.communication_time_info = params.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_TIME_INFO)
        self.communication_bandwidth_info = params.get(Constant.COMM_DATA_DICT, {}).get(
            Constant.COMMUNICATION_BANDWIDTH_INFO)
        self.collective_group_dict = params.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP)
        # Records bucketed by rank tuple -> step.
        self.comm_time_struct = {}
        self.comm_bandwidth_struct = {}
        # Flat result rows to be inserted into the DB.
        self.res_comm_time = []
        self.res_comm_bandwidth = []

    def run(self):
        """Entry point: bucket records, compute totals, persist to the DB."""
        if not self.communication_time_info and not self.communication_bandwidth_info:
            return
        self.split_and_add_rank_set(self.communication_time_info, self.comm_time_struct)
        self.split_and_add_rank_set(self.communication_bandwidth_info, self.comm_bandwidth_struct)
        self.compute_total_info()
        self.dump_data()

    def dump_data(self):
        """Insert the aggregated time and bandwidth rows into the analyzer DB."""
        output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
        result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        DBManager.create_tables(result_db, self.COMMUNICATION_TIME_TABLE, self.COMMUNICATION_BANDWIDTH_TABLE)
        res_time, res_bandwidth = [], []
        conn, cursor = DBManager.create_connect_db(result_db)
        for data in self.res_comm_time:
            res_time.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.RANK_ID],
                             data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME],
                             data[TableConstant.START_TIMESTAMP], data[TableConstant.ELAPSED_TIME],
                             data[TableConstant.TRANSIT_TIME], data[TableConstant.WAIT_TIME],
                             data[TableConstant.SYNCHRONIZATION_TIME], data[TableConstant.IDLE_TIME],
                             data[TableConstant.SYNCHRONIZATION_TIME_RATIO], data[TableConstant.WAIT_TIME_RATIO]])
        if res_time:
            # Parameterized "?"-placeholder insert sized to the row width.
            sql = "insert into {} values ({value})".format(self.COMMUNICATION_TIME_TABLE,
                                                           value="?," * (len(res_time[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res_time)
        for data in self.res_comm_bandwidth:
            res_bandwidth.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.RANK_ID],
                                  data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME],
                                  data[TableConstant.TRANSPORT_TYPE], data[TableConstant.TRANSIT_SIZE],
                                  data[TableConstant.TRANSIT_TIME], data[TableConstant.BANDWIDTH],
                                  data[TableConstant.LARGE_PACKET_RATIO], data[TableConstant.PACKAGE_SIZE],
                                  data[TableConstant.COUNT], data[TableConstant.TOTAL_DURATION]])
        if res_bandwidth:
            sql = "insert into {} values ({value})".format(self.COMMUNICATION_BANDWIDTH_TABLE,
                                                           value="?," * (len(res_bandwidth[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res_bandwidth)
        DBManager.destroy_db_connect(conn, cursor)

    def split_and_add_rank_set(self, data_list, res_dict):
        """Bucket records into res_dict by rank tuple (P2P shares one bucket) then step."""
        for data in data_list:
            if data[TableConstant.TYPE] == Constant.P2P:
                rank_tuple = Constant.P2P
            else:
                rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME]))
            res_dict.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], []).append(data)

    def compute_total_info(self):
        """Compute per-step total rows for every rank group (time, then bandwidth)."""
        for rank_tuple, op_dict in self.comm_time_struct.items():
            if rank_tuple != Constant.P2P:
                for step, data_list in op_dict.items():
                    self.compute_rank_set_total_time_info(data_list, rank_tuple)
            else:
                # P2P totals aggregate over the union of every rank id seen
                # in the P2P bucket across all steps.
                rank_set = set()
                for step, data_list in op_dict.items():
                    # BUGFIX: the original called rank_set.add(<generator>), which
                    # stored the generator object itself instead of the rank ids.
                    rank_set.update(data[TableConstant.RANK_ID] for data in data_list)
                for step, data_list in op_dict.items():
                    self.compute_rank_set_total_time_info(data_list, rank_set, True)
        for rank_tuple, op_dict in self.comm_bandwidth_struct.items():
            for step, data_list in op_dict.items():
                if rank_tuple != Constant.P2P:
                    self.compute_rank_set_total_bandwidth_info(data_list, rank_tuple)
                else:
                    self.compute_rank_set_total_bandwidth_info(data_list, rank_tuple, True)

    def compute_rank_set_total_bandwidth_info(self, data_list, rank_tuple, is_p2p=False):
        """
        Tag each record with its rank-set label and append one total row per
        (rank, transport type, package size) combination.
        """
        if not data_list:
            return
        data_dict = {}
        rank_tuple = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P
        for data in data_list:
            data[TableConstant.RANK_SET] = rank_tuple
            rank_band_type = self.RANK_BAND_TYPE.format(data[TableConstant.RANK_ID],
                                                        data[TableConstant.TRANSPORT_TYPE])
            data_dict.setdefault(rank_band_type, []).append(data)
            self.res_comm_bandwidth.append(data)
        for rank_band_type, bandwidth_list in data_dict.items():
            package_set = set()
            for data in bandwidth_list:
                package_set.add(data[TableConstant.PACKAGE_SIZE])
            for package in package_set:
                total_comm_bandwidth_info = dict()
                for data in bandwidth_list:
                    self.compute_bandwidth(total_comm_bandwidth_info, data, package)
                bandwidth = BaseAnalysisJson.compute_ratio(total_comm_bandwidth_info[TableConstant.TRANSIT_SIZE],
                                                           total_comm_bandwidth_info[TableConstant.TRANSIT_TIME])
                total_comm_bandwidth_info[TableConstant.BANDWIDTH] = bandwidth
                total_comm_bandwidth_info[TableConstant.PACKAGE_SIZE] = package
                total_comm_bandwidth_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
                total_comm_bandwidth_info[TableConstant.GROUP_NAME] = ""
                total_comm_bandwidth_info[TableConstant.LARGE_PACKET_RATIO] = 0.0
                self.res_comm_bandwidth.append(total_comm_bandwidth_info)

    def compute_bandwidth(self, res_dict, data_dict, package):
        """
        Accumulate one bandwidth record into res_dict: transit size/time always
        sum; count/duration sum only for the matching package size; every other
        column takes the latest value.
        """
        for key in data_dict.keys():
            if key in [TableConstant.TRANSIT_TIME, TableConstant.TRANSIT_SIZE]:
                if key not in res_dict.keys():
                    res_dict[key] = 0.0
                res_dict[key] += data_dict[key]
            elif key in [TableConstant.COUNT, TableConstant.TOTAL_DURATION]:
                if data_dict[TableConstant.PACKAGE_SIZE] == package:
                    if key not in res_dict.keys():
                        res_dict[key] = 0.0
                    res_dict[key] += data_dict[key]
                else:
                    res_dict[key] = 0.0
            else:
                res_dict[key] = data_dict[key]

    def compute_time(self, res_dict, data_dict, dict_key):
        """Accumulate *time-suffixed columns; any other column takes the latest value."""
        if dict_key.endswith(self.TIME_EXTENSION):
            if dict_key not in res_dict.keys():
                res_dict[dict_key] = 0.0
            res_dict[dict_key] += data_dict[dict_key]
        else:
            res_dict[dict_key] = data_dict[dict_key]

    def compute_rank_set_total_time_info(self, data_list: list, rank_tuple: any, is_p2p: bool = False):
        """
        Tag each record with its rank-set label, append one total-time row per
        rank with synchronization/wait ratios, then keep the raw rows too.
        """
        if not data_list:
            return
        rank_set = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P
        for rank_id in rank_tuple:
            total_comm_time_info = dict()
            for data in data_list:
                if data[TableConstant.RANK_ID] == rank_id:
                    data[TableConstant.RANK_SET] = rank_set
                    data[TableConstant.SYNCHRONIZATION_TIME_RATIO] = 0.0
                    data[TableConstant.WAIT_TIME_RATIO] = 0.0
                    for key, value in data.items():
                        self.compute_time(total_comm_time_info, data, key)
            syn_ratio = BaseAnalysisJson.compute_ratio(total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME],
                                                       total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME] +
                                                       total_comm_time_info[TableConstant.TRANSIT_TIME])
            wait_time_ratio = BaseAnalysisJson.compute_ratio(total_comm_time_info[TableConstant.WAIT_TIME],
                                                             total_comm_time_info[TableConstant.WAIT_TIME] +
                                                             total_comm_time_info[TableConstant.TRANSIT_TIME])
            total_comm_time_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
            total_comm_time_info[TableConstant.GROUP_NAME] = ""
            total_comm_time_info[TableConstant.START_TIMESTAMP] = 0.0
            total_comm_time_info[TableConstant.WAIT_TIME_RATIO] = wait_time_ratio
            total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME_RATIO] = syn_ratio
            self.res_comm_time.append(total_comm_time_info)
        self.res_comm_time.extend(data_list)
b/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py similarity index 43% rename from profiler/cluster_analyse/analysis/communication_analysis.py rename to profiler/cluster_analyse/analysis/communication/communication_analysis_json.py index 88ac073a9cc899ecfb32378a8aca662de2bfe879..9b86eada4b49c3e2a3d845f4b222442ee039139b 100644 --- a/profiler/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py @@ -14,61 +14,12 @@ # limitations under the License. from collections import defaultdict -from abc import abstractmethod +from analysis.base_analysis_json import BaseAnalysisJson from common_func.constant import Constant -from common_func.file_manager import FileManager -class BaseCommAnalysis: - - def __init__(self, param: dict): - self.collection_path = param.get(Constant.COLLECTION_PATH) - self.data_map = param.get(Constant.DATA_MAP) - self.communication_ops = [] - self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP) - self.comm_ops_struct = {} - - @staticmethod - def compute_ratio(dividend: float, divisor: float): - if abs(divisor) < Constant.EPS: - return 0 - else: - return round(dividend / divisor, 4) - - @abstractmethod - def run(self): - pass - - def dump_data(self): - if not self.comm_ops_struct: - print("[WARNING] There is no final comm ops data generated") - return - output_comm_data = {} - for key in self.comm_ops_struct: - output_comm_data[str(key)] = self.comm_ops_struct.get(key) - FileManager.create_json_file(self.collection_path, output_comm_data, self.SAVED_JSON) - - def split_op_by_group(self): - for single_op in self.communication_ops: - if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P: - rank_tup = Constant.P2P - else: - rank_tup = tuple(self.collective_group_dict.get(single_op.get(Constant.GROUP_NAME), [])) - rank_id = single_op.get(Constant.RANK_ID, 'N/A') - step_id = 
single_op.get(Constant.STEP_ID, 'N/A') - op_name = single_op.get(Constant.COMM_OP_NAME, 'N/A') - op_info = single_op.get(Constant.COMM_OP_INFO) - self.comm_ops_struct.setdefault(rank_tup, {}).setdefault(step_id, {}).\ - setdefault(op_name, {}).setdefault(rank_id, op_info) - - def combine_ops_total_info(self): - for rank_tup, group_dict in self.comm_ops_struct.items(): - for step_id, communication_ops in group_dict.items(): - self.compute_total_info(communication_ops) - - -class CommunicationAnalysis(BaseCommAnalysis): +class CommunicationJsonAnalysis(BaseAnalysisJson): SAVED_JSON = "cluster_communication.json" def __init__(self, param: dict): @@ -144,100 +95,3 @@ class CommunicationAnalysis(BaseCommAnalysis): bandwidth_dict[Constant.BANDWIDTH_GB_S] = \ self.compute_ratio(bandwidth_dict.get(Constant.TRANSIT_SIZE_MB, 0), bandwidth_dict.get(Constant.TRANSIT_TIME_MS, 0)) - - -class CommMatrixAnalysis(BaseCommAnalysis): - SAVED_JSON = "cluster_communication_matrix.json" - STAT_LIST = ['middle', 'top', 'bottom', 'total'] - TOTAL = 'total' - - def __init__(self, param: dict): - super().__init__(param) - self.communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS) - - @staticmethod - def combine_link(link_info_dict: dict, single_link_dict: dict): - link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE) - link_info_dict[Constant.OP_NAME] = single_link_dict.get(Constant.OP_NAME, '') - link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0) - link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0) - - def run(self): - if not self.communication_ops: - return - self.split_op_by_group() - self.combine_ops_total_info() - self.dump_data() - - def compute_total_info(self, step_dict: dict): - self.merge_same_links(step_dict) - self.combine_link_info(step_dict) - - def merge_same_links(self, step_dict: dict): - def process_link_key(): - for link_key 
in rank_dict: - if '-' not in link_key: - print(f"[WARNING] {op_name} has an invalid link key {link_key}!") - break - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - if src_rank == dst_rank: - if src_rank not in project_local_global_rank_map: - project_local_global_rank_map[src_rank] = rank_id - elif project_local_global_rank_map.get(src_rank) != rank_id: - print(f"[WARNING] In the same communication group, local ranks projecting to global ranks repeat!") - self.combine_link(link_info[link_key], rank_dict[link_key]) - - def convert_local_to_global_rank(): - tmp_link = {} - for link_key, link_dict in link_info.items(): - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - src_rank = project_local_global_rank_map[src_rank] \ - if src_rank in project_local_global_rank_map else src_rank - dst_rank = project_local_global_rank_map[dst_rank] \ - if dst_rank in project_local_global_rank_map else dst_rank - link_dict[Constant.BANDWIDTH_GB_S] = \ - self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), - link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - tmp_link[f"{src_rank}-{dst_rank}"] = link_dict - return tmp_link - - project_local_global_rank_map = dict() - for op_name, op_dict in step_dict.items(): - link_info = defaultdict(lambda: { - Constant.TRANSPORT_TYPE: '', - Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0, - Constant.OP_NAME: '' - }) - for rank_id, rank_dict in op_dict.items(): - process_link_key() - step_dict[op_name] = convert_local_to_global_rank() - - def combine_link_info(self, step_dict: dict): - total_op_info = defaultdict(lambda: { - Constant.TRANSPORT_TYPE: '', - Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0, - Constant.OP_NAME: '' - }) - for op_name, op_dict in step_dict.items(): - if self.check_add_op(op_name): - for link_key, link_dict in op_dict.items(): - self.combine_link(total_op_info[link_key], link_dict) - for link_key, link_dict in total_op_info.items(): - 
link_dict[Constant.BANDWIDTH_GB_S] = \ - self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), - link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - step_dict[Constant.TOTAL_OP_INFO] = total_op_info - - def check_add_op(self: any, op_name: str): - """ - 兼容2个版本,判断是否需要将此算子信息相加 - """ - for stat_name in self.STAT_LIST: - if stat_name in op_name: - if stat_name != self.TOTAL: - return False - return True diff --git a/profiler/cluster_analyse/analysis/communication_matrix/__init__.py b/profiler/cluster_analyse/analysis/communication_matrix/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py new file mode 100644 index 0000000000000000000000000000000000000000..afc1d40d8fce9cb4acc9c9ee0e467ac851534b06 --- /dev/null +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py @@ -0,0 +1,147 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class CommMatrixAnalysisDB:
    """
    Aggregate per-rank communication-matrix records and write the
    cluster-level link matrix into the analyzer sqlite database.
    """

    COMMUNICATION_MATRIX_TABLE = "ClusterCommAnalyzerMatrix"

    def __init__(self, params: any):
        self.collection_path = params.get(Constant.COLLECTION_PATH)
        self.matrix_info = params.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS)
        self.collective_group_dict = params.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP)
        # Records bucketed by rank tuple -> step -> "<op>@<group>".
        self.comm_matrix_struct = {}
        # Flat result rows to be inserted into the DB.
        self.res_comm_matrix = []

    def run(self):
        """Entry point: bucket records, merge/total the matrix, persist to the DB."""
        if not self.matrix_info:
            return
        self.set_rank_tuple()
        self.combine_total_matrix_info()
        self.dump_data()

    def dump_data(self):
        """Insert the aggregated matrix rows into the analyzer DB."""
        output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
        result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        DBManager.create_tables(result_db, self.COMMUNICATION_MATRIX_TABLE)
        conn, cursor = DBManager.create_connect_db(result_db)
        res = []
        for data in self.res_comm_matrix:
            # Total rows carry no op name; normalize None to "".
            op_name = data.get(TableConstant.OPNAME) if data.get(TableConstant.OPNAME) is not None else ""
            res.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.HCCL_OP_NAME],
                        data[TableConstant.GROUP_NAME], data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK],
                        data[TableConstant.TRANSIT_SIZE], data[TableConstant.TRANSIT_TIME],
                        data[TableConstant.BANDWIDTH], data[TableConstant.TRANSPORT_TYPE], op_name])
        if res:
            # Parameterized "?"-placeholder insert sized to the row width.
            sql = "insert into {} values ({value})".format(self.COMMUNICATION_MATRIX_TABLE,
                                                           value="?," * (len(res[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res)
        DBManager.destroy_db_connect(conn, cursor)

    def combine_total_matrix_info(self):
        """Merge same-link records per step and append total rows per rank group."""
        for rank_tuple, group_dict in self.comm_matrix_struct.items():
            if rank_tuple != Constant.P2P:
                rank_tuple = "(" + ",".join(str(i) for i in rank_tuple) + ")"
            for step, step_dict in group_dict.items():
                self.merge_same_info(step_dict, rank_tuple)
                self.combine_total_info(step_dict)

    def combine_total_info(self, step_dict: dict):
        """Accumulate eligible ops' links into one TOTAL_OP_INFO row per link."""
        link_key_set = set()
        for op_name, matrix_dict in step_dict.items():
            if BaseAnalysisJson.check_add_op(op_name):
                self.res_comm_matrix.extend(matrix_dict.values())
                for key in matrix_dict.keys():
                    link_key_set.add(key)
        for link_key in link_key_set:
            total_matrix_info = dict()
            total_matrix_info[TableConstant.TRANSIT_SIZE] = 0.0
            total_matrix_info[TableConstant.TRANSIT_TIME] = 0.0
            for op_name, matrix_dict in step_dict.items():
                if link_key in matrix_dict.keys():
                    total_matrix_info[TableConstant.RANK_SET] = matrix_dict[link_key][TableConstant.RANK_SET]
                    self.combine_link_info(total_matrix_info, matrix_dict[link_key])
            bandwidth = BaseAnalysisJson.compute_ratio(total_matrix_info[TableConstant.TRANSIT_SIZE],
                                                       total_matrix_info[TableConstant.TRANSIT_TIME])
            total_matrix_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO
            total_matrix_info[TableConstant.GROUP_NAME] = ""
            total_matrix_info[TableConstant.BANDWIDTH] = bandwidth
            self.res_comm_matrix.append(total_matrix_info)

    def combine_link_info(self, link_info, data: dict):
        """Sum transit size/time into link_info; other columns take the latest value."""
        for col in data.keys():
            if col in [TableConstant.TRANSIT_TIME, TableConstant.TRANSIT_SIZE]:
                link_info[col] += data[col]
            else:
                link_info[col] = data[col]

    def merge_same_info(self, step_dict: dict, rank_tuple):
        """
        Merge records that share a src-dst link for each op, learn the
        local->global rank projection from self-links, then rewrite link keys
        in global-rank terms with a recomputed bandwidth.
        """
        def process_matrix():
            for data in op_list:
                if data[TableConstant.SRC_RANK] == data[TableConstant.DST_RANK]:
                    # A self-link reveals the local->global rank projection.
                    if data[TableConstant.SRC_RANK] not in local_global_rank_map:
                        local_global_rank_map[data[TableConstant.SRC_RANK]] = data[TableConstant.RANK_ID]
                    elif local_global_rank_map[data[TableConstant.SRC_RANK]] != data[TableConstant.RANK_ID]:
                        print(f"[WARNING] In the same communication group, local ranks projecting to global ranks "
                              f"repeat!")
                if (link_key.split('-')[0] == data[TableConstant.SRC_RANK] and
                        link_key.split('-')[1] == data[TableConstant.DST_RANK]):
                    self.combine_link_info(matrix_info, data)
            new_matrix_list[link_key] = matrix_info

        def convert_local_to_global_rank():
            res_dict = dict()
            for key, new_matrix in new_matrix_list.items():
                src_rank = new_matrix[TableConstant.SRC_RANK]
                dst_rank = new_matrix[TableConstant.DST_RANK]
                src_rank = local_global_rank_map[src_rank] if src_rank in local_global_rank_map else src_rank
                dst_rank = local_global_rank_map[dst_rank] if dst_rank in local_global_rank_map else dst_rank
                bandwidth = BaseAnalysisJson.compute_ratio(new_matrix[TableConstant.TRANSIT_SIZE],
                                                           new_matrix[TableConstant.TRANSIT_TIME])
                key = f"{src_rank}-{dst_rank}"
                new_matrix[TableConstant.SRC_RANK] = src_rank
                new_matrix[TableConstant.DST_RANK] = dst_rank
                new_matrix[TableConstant.BANDWIDTH] = bandwidth
                res_dict[key] = new_matrix
            return res_dict

        local_global_rank_map = dict()
        for op_name, op_list in step_dict.items():
            new_matrix_list = {}
            link_key_set = set()
            for op_data in op_list:
                link_key_set.add(op_data[TableConstant.SRC_RANK] + "-" + op_data[TableConstant.DST_RANK])
            for link_key in link_key_set:
                matrix_info = dict()
                matrix_info[TableConstant.RANK_SET] = rank_tuple
                matrix_info[TableConstant.TRANSIT_SIZE] = 0.0
                matrix_info[TableConstant.TRANSIT_TIME] = 0.0
                process_matrix()
            step_dict[op_name] = convert_local_to_global_rank()

    def set_rank_tuple(self):
        """Bucket matrix records by rank tuple (P2P shares one bucket), step and op key."""
        for data in self.matrix_info:
            op_name = data[TableConstant.HCCL_OP_NAME] + "@" + data[TableConstant.GROUP_NAME]
            # BUGFIX(review): the original compared the STEP column against
            # Constant.P2P, which a step id can never equal, so P2P records were
            # misrouted; the op TYPE column is what distinguishes P2P from
            # collective records (mirrors
            # CommunicationAnalysisDB.split_and_add_rank_set) — TODO confirm
            # the matrix records carry the TYPE column.
            if data[TableConstant.TYPE] == Constant.P2P:
                rank_tuple = Constant.P2P
            else:
                rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME]))
            self.comm_matrix_struct.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], {}). \
                setdefault(op_name, []).append(data)
class CommMatrixAnalysisJson(BaseAnalysisJson):
    """
    JSON-format cluster communication-matrix analysis: merges per-rank link
    records into global-rank links and appends a per-step total-info summary.
    """

    SAVED_JSON = "cluster_communication_matrix.json"

    def __init__(self, param: dict):
        super().__init__(param)
        self.communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS)

    @staticmethod
    def combine_link(link_info_dict: dict, single_link_dict: dict):
        # Transport type and op name take the latest value; time/size accumulate.
        link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE)
        link_info_dict[Constant.OP_NAME] = single_link_dict.get(Constant.OP_NAME, '')
        link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0)
        link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0)

    def run(self):
        """Entry point: split ops by group, aggregate totals, dump the JSON."""
        if not self.communication_ops:
            return
        self.split_op_by_group()
        self.combine_ops_total_info()
        self.dump_data()

    def compute_total_info(self, step_dict: dict):
        """Per-step hook: merge identical links, then add the totals entry."""
        self.merge_same_links(step_dict)
        self.combine_link_info(step_dict)

    def merge_same_links(self, step_dict: dict):
        """
        Merge each op's identical links across ranks, learning the
        local->global rank projection from self-links, then rewrite every link
        key in global-rank terms with a recomputed bandwidth.
        """
        local_to_global = dict()
        for op_name, op_dict in step_dict.items():
            merged_links = defaultdict(lambda: {
                Constant.TRANSPORT_TYPE: '',
                Constant.TRANSIT_TIME_MS: 0,
                Constant.TRANSIT_SIZE_MB: 0,
                Constant.OP_NAME: ''
            })
            for rank_id, rank_dict in op_dict.items():
                for link_key in rank_dict:
                    if '-' not in link_key:
                        print(f"[WARNING] {op_name} has an invalid link key {link_key}!")
                        break
                    src_rank = link_key.split('-')[0]
                    dst_rank = link_key.split('-')[1]
                    if src_rank == dst_rank:
                        # A self-link reveals which global rank this local rank is.
                        if src_rank not in local_to_global:
                            local_to_global[src_rank] = rank_id
                        elif local_to_global.get(src_rank) != rank_id:
                            print(f"[WARNING] In the same communication group, local ranks projecting to global ranks "
                                  f"repeat!")
                    self.combine_link(merged_links[link_key], rank_dict[link_key])
            remapped = {}
            for link_key, link_dict in merged_links.items():
                src_rank = link_key.split('-')[0]
                dst_rank = link_key.split('-')[1]
                src_rank = local_to_global.get(src_rank, src_rank)
                dst_rank = local_to_global.get(dst_rank, dst_rank)
                link_dict[Constant.BANDWIDTH_GB_S] = \
                    self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0),
                                       link_dict.get(Constant.TRANSIT_TIME_MS, 0))
                remapped[f"{src_rank}-{dst_rank}"] = link_dict
            step_dict[op_name] = remapped

    def combine_link_info(self, step_dict: dict):
        """Accumulate every eligible op's links into one TOTAL_OP_INFO entry."""
        total_op_info = defaultdict(lambda: {
            Constant.TRANSPORT_TYPE: '',
            Constant.TRANSIT_TIME_MS: 0,
            Constant.TRANSIT_SIZE_MB: 0,
            Constant.OP_NAME: ''
        })
        for op_name, op_dict in step_dict.items():
            if not self.check_add_op(op_name):
                continue
            for link_key, link_dict in op_dict.items():
                self.combine_link(total_op_info[link_key], link_dict)
        for link_dict in total_op_info.values():
            link_dict[Constant.BANDWIDTH_GB_S] = \
                self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0),
                                   link_dict.get(Constant.TRANSIT_TIME_MS, 0))
        step_dict[Constant.TOTAL_OP_INFO] = total_op_info
class CommMatrixAnalysisGenerator:
    """
    Facade that selects the communication-matrix analysis backend matching
    the requested output format (sqlite DB vs. JSON text) and delegates to it.
    """

    GROUP_MAP = {
        Constant.DB: CommMatrixAnalysisDB,
        Constant.TEXT: CommMatrixAnalysisJson
    }

    def __init__(self, params: dict):
        # Look up the backend class by the requested data type; a missing
        # type raises KeyError, same as the original dispatch.
        backend_cls = self.GROUP_MAP[params[Constant.DATA_TYPE]]
        self.generator = backend_cls(params)

    def run(self):
        # Hand the whole analysis off to the selected backend.
        self.generator.run()
import os -from collections import defaultdict +from common_func.db_manager import DBManager from common_func.constant import Constant from common_func.file_manager import FileManager from prof_bean.step_trace_time_bean import StepTraceTimeBean @@ -23,6 +23,7 @@ from prof_bean.step_trace_time_bean import StepTraceTimeBean class StepTraceTimeAnalysis: CLUSTER_TRACE_TIME_CSV = "cluster_step_trace_time.csv" + CLUSTER_TRACE_TIME_TABLE = "ClusterStepTraceTime" def __init__(self, param: dict): self.collection_path = param.get(Constant.COLLECTION_PATH) @@ -30,6 +31,7 @@ class StepTraceTimeAnalysis: self.communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP) self.step_time_dict = {} self.step_data_list = [] + self.data_type = param.get(Constant.DATA_TYPE) @staticmethod def get_max_data_row(data_group_list: list): @@ -51,21 +53,44 @@ class StepTraceTimeAnalysis: def dump_data(self): if not self.step_data_list: print("[WARNING] Can't get step time info!") - headers = self.get_headers() - FileManager.create_csv_file(self.collection_path, self.step_data_list, self.CLUSTER_TRACE_TIME_CSV, headers) + if self.data_type == Constant.TEXT: + headers = self.get_headers() + FileManager.create_csv_file(self.collection_path, self.step_data_list, self.CLUSTER_TRACE_TIME_CSV, headers) + else: + output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + DBManager.create_tables(result_db, self.CLUSTER_TRACE_TIME_TABLE) + conn, cursor = DBManager.create_connect_db(result_db) + sql = "insert into {} values ({value})".format(self.CLUSTER_TRACE_TIME_TABLE, + value="?," * (len(self.step_data_list[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, self.step_data_list) + DBManager.destroy_db_connect(conn, cursor) def load_step_trace_time_data(self): for rank_id, profiling_dir_path in self.data_map.items(): - step_time_file = 
os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) - if step_time_file: - self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) + if self.data_type == Constant.TEXT: + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.STEP_TIME_CSV) + if step_time_file: + self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) + else: + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, + Constant.DB_COMMUNICATION_ANALYZER) + if step_time_file and DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE): + conn, cursor = DBManager.create_connect_db(step_time_file) + sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE) + data = DBManager.fetch_all_data(cursor, sql, is_dict=False) + self.step_time_dict[rank_id] = data + DBManager.destroy_db_connect(conn, cursor) if not self.step_time_dict.get(rank_id): print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time.json.") def analyze_step_time(self): for rank_id, data_bean_list in self.step_time_dict.items(): for data_bean in data_bean_list: - self.step_data_list.append([data_bean.step, Constant.RANK, rank_id] + data_bean.row) + if self.data_type == Constant.TEXT: + self.step_data_list.append([data_bean.step, Constant.RANK, rank_id] + data_bean.row) + else: + self.step_data_list.append([data_bean[0], Constant.RANK, rank_id] + list(data_bean[1:])) stage_list = self.communication_group.get(Constant.P2P) if not stage_list: return @@ -80,7 +105,11 @@ class StepTraceTimeAnalysis: step_group_dict.setdefault(key, []).append(data_list[3:]) for key, data_group_list in step_group_dict.items(): - self.step_data_list.append([key[0], Constant.STAGE, key[1]] + self.get_max_data_row(data_group_list)) + if self.data_type == Constant.TEXT: + self.step_data_list.append([key[0], Constant.STAGE, key[1]] + self.get_max_data_row(data_group_list)) + else: + index = "(" 
+ ",".join(str(i) for i in key[1]) + ")" + self.step_data_list.append([key[0], Constant.STAGE, index] + self.get_max_data_row(data_group_list)) def get_headers(self): if self.step_time_dict: diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index e07cac170300650bbf735f7e302b33377dd30a5e..68eae526fb05479bc8b93f3bfc51037df221dc25 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -14,6 +14,7 @@ # limitations under the License. import argparse +import glob import os from cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor @@ -28,6 +29,8 @@ from analysis.analysis_facade import AnalysisFacade class Interface: ASCEND_PT = "ascend_pt" ASCEND_MS = "ascend_ms" + DB_RESULT_INFO = "*.db" + ALL_RESULT_INFO = "*.*" def __init__(self, params: dict): self.collection_path = PathManager.get_realpath(params.get(Constant.COLLECTION_PATH)) @@ -38,6 +41,25 @@ class Interface: self.communication_ops = [] self.matrix_ops = [] + def check_db_or_other_files(self, data_map: dict) -> tuple: + type_db_count = 0 + type_text_count = 0 + for _, folder_path in data_map.items(): + folder_path = os.path.join(folder_path, Constant.SINGLE_OUTPUT) + db_files = glob.glob(os.path.join(folder_path, self.DB_RESULT_INFO)) + all_files = glob.glob(os.path.join(folder_path, self.ALL_RESULT_INFO)) + if all_files and db_files and len(all_files) != len(db_files): + return False, None + if db_files: + type_db_count += 1 + else: + type_text_count += 1 + if type_db_count == len(data_map): + return True, Constant.DB + if type_text_count == len(data_map): + return True, Constant.TEXT + return False, None + def allocate_prof_data(self): ascend_pt_dirs = [] ascend_ms_dirs = [] @@ -51,7 +73,7 @@ class Interface: ms_data_map = MindsporeDataPreprocessor(ascend_ms_dirs).get_data_map() if pt_data_map and ms_data_map: print("[ERROR] Can not analyze pytorch and mindspore meantime.") 
- return[] + return [] return pt_data_map if pt_data_map else ms_data_map def run(self): @@ -62,10 +84,15 @@ class Interface: if not data_map: print("[WARNING] Can not get rank info or profiling data.") return + is_valid, data_type = self.check_db_or_other_files(data_map) + if not is_valid: + print("[WARNING] The current folder contains both DB and other files. Please check.") + return params = { Constant.COLLECTION_PATH: self.collection_path, Constant.DATA_MAP: data_map, - Constant.ANALYSIS_MODE: self.analysis_mode + Constant.ANALYSIS_MODE: self.analysis_mode, + Constant.DATA_TYPE: data_type } comm_data_dict = CommunicationGroupGenerator(params).generate() params[Constant.COMM_DATA_DICT] = comm_data_dict diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index e426a9d22567ae9e70411f709c1c09ce02cbdeca..71caee40db8b58ff263ad5d7311e797684883f3d 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -30,6 +30,7 @@ class Constant(object): MAX_JSON_SIZE = 1024 * 1024 * 1024 * 10 MAX_CSV_SIZE = 1024 * 1024 * 1024 * 5 MAX_PATH_LENGTH = 4096 + MAX_READ_DB_FILE_BYTES = 1024 * 1024 * 1024 * 8 # communication P2P = "p2p" @@ -66,11 +67,12 @@ class Constant(object): COMMUNICATION_GROUP = "communication_group" TRANSPORT_TYPE = "Transport Type" COMM_DATA_DICT = "comm_data_dict" + DATA_TYPE = "data_type" ANALYSIS_MODE = "analysis_mode" # step time - RANK = 'rank' - STAGE = 'stage' + RANK = "rank" + STAGE = "stage" # epsilon EPS = 1e-15 @@ -78,3 +80,17 @@ class Constant(object): # file suffix JSON_SUFFIX = ".json" CSV_SUFFIX = ".csv" + + # result files type + TEXT = "text" + DB = "db" + + # db name + DB_COMMUNICATION_ANALYZER = "analysis.db" + DB_CLUSTER_COMMUNICATION_ANALYZER = "cluster_analysis.db" + + # db tables + TABLE_COMM_ANALYZER_BANDWIDTH = "CommAnalyzerBandwidth" + TABLE_COMM_ANALYZER_TIME = "CommAnalyzerTime" + TABLE_COMM_ANALYZER_MATRIX = 
"CommAnalyzerMatrix" + TABLE_STEP_TRACE = "StepTraceTime" diff --git a/profiler/cluster_analyse/common_func/db_manager.py b/profiler/cluster_analyse/common_func/db_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..5ec6e10a9e5eb496990ebde0f41b13b4c4c00f0a --- /dev/null +++ b/profiler/cluster_analyse/common_func/db_manager.py @@ -0,0 +1,204 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class DBManager:
    """
    class to manage DB operation
    """
    FETCH_SIZE = 10000
    INSERT_SIZE = 10000
    MAX_ROW_COUNT = 100000000

    @staticmethod
    def create_connect_db(db_path: str) -> tuple:
        """
        create and connect database; returns (EmptyClass, EmptyClass) on failure
        so callers can truth-test the pair.
        """
        if check_db_path_valid(db_path):
            try:
                conn = sqlite3.connect(db_path)
            except sqlite3.Error as err:
                print(f"[ERROR] {err}")
                return EmptyClass("empty conn"), EmptyClass("empty curs")
            try:
                if isinstance(conn, sqlite3.Connection):
                    curs = conn.cursor()
                    os.chmod(db_path, Constant.FILE_AUTHORITY)
                    return conn, curs
            except sqlite3.Error as err:
                print(f"[ERROR] {err}")
                return EmptyClass("empty conn"), EmptyClass("empty curs")
        return EmptyClass("empty conn"), EmptyClass("empty curs")

    @staticmethod
    def destroy_db_connect(conn: any, curs: any) -> None:
        """
        destroy db connection (safe to call with EmptyClass placeholders)
        """
        try:
            if isinstance(curs, sqlite3.Cursor):
                curs.close()
        except sqlite3.Error as err:
            print(f"[ERROR] {err}")
        try:
            if isinstance(conn, sqlite3.Connection):
                conn.close()
        except sqlite3.Error as err:
            print(f"[ERROR] {err}")

    @staticmethod
    def judge_table_exists(curs: any, table_name: str) -> any:
        """
        judge table exists; returns the sqlite_master count (0/1) or False on error
        """
        if not isinstance(curs, sqlite3.Cursor):
            return False
        try:
            curs.execute("select count(*) from sqlite_master where type='table' and name=?",
                         (table_name,))
            return curs.fetchone()[0]
        except sqlite3.Error as err:
            print("[ERROR] {}".format(err))
            return False

    @staticmethod
    def sql_generate_table(table_map: str):
        """Build the "(col type, ...)" DDL fragment for a TablesConfig entry."""
        header_with_type_begin = "("
        header_with_type_end = ")"
        header_with_type_list = []
        if table_map in TablesConfig.DATA:
            items = TablesConfig.DATA[table_map]
            for item in items:
                if item[0] == "index":
                    # "index" is an SQL keyword, so the column name must be quoted.
                    header_with_type_list.append('"' + item[0] + '" ' + item[1].split(",")[0])
                else:
                    header_with_type_list.append(item[0] + ' ' + item[1].split(",")[0])
            header_with_type_begin += ",".join(header_with_type_list)
            header_with_type_begin += header_with_type_end
        return header_with_type_begin

    @classmethod
    def check_tables_in_db(cls, db_path: any, *tables: any) -> bool:
        """Return True when every named table exists in db_path."""
        if check_db_path_valid(db_path, True):
            conn, curs = cls.create_connect_db(db_path)
            if not (conn and curs):
                return False
            res = True
            for table in tables:
                if not cls.judge_table_exists(curs, table):
                    res = False
                    break
            cls.destroy_db_connect(conn, curs)
            return res

    @classmethod
    def create_tables(cls, db_path: any, *tables: any) -> bool:
        """Create any missing tables in db_path; True when every DDL succeeds.

        Bug fix: the original never returned the bool promised by its
        annotation and leaked the connection/cursor it opened.
        """
        conn, curs = cls.create_connect_db(db_path)
        success = True
        try:
            for table_name in tables:
                if not cls.judge_table_exists(curs, table_name):
                    table_map = "{0}Map".format(table_name)
                    header_with_type = cls.sql_generate_table(table_map)
                    sql = "CREATE TABLE IF NOT EXISTS " + table_name + header_with_type
                    if not cls.execute_sql(conn, sql):
                        success = False
        finally:
            cls.destroy_db_connect(conn, curs)
        return success

    @staticmethod
    def execute_sql(conn: any, sql: str, params: any = None) -> bool:
        """
        execute sql
        """
        try:
            if isinstance(conn, sqlite3.Connection):
                if params:
                    conn.cursor().execute(sql, params)
                else:
                    conn.cursor().execute(sql)
                conn.commit()
                return True
        except sqlite3.Error as err:
            print(f"[ERROR] {err}")
            return False
        print("[ERROR] conn is invalid param")
        return False

    @staticmethod
    def executemany_sql(conn: any, sql: str, params: any) -> bool:
        """
        execute many sql once
        """
        try:
            if isinstance(conn, sqlite3.Connection):
                conn.cursor().executemany(sql, params)
                conn.commit()
                return True
        except sqlite3.Error as err:
            print(f"[ERROR] {err}")
            return False
        print("[ERROR] conn is invalid param")
        return False

    @classmethod
    def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None,
                       is_dict: bool = True) -> list:
        """
        fetch FETCH_SIZE rows from db each time to get all data
        """
        if not isinstance(curs, sqlite3.Cursor):
            return []
        data = []
        try:
            if param:
                res = curs.execute(sql, param)
            else:
                res = curs.execute(sql)
        except sqlite3.Error as err:
            print(f"[ERROR] {err}")
            curs.row_factory = None
            return []
        try:
            description = res.description
            while True:
                res = curs.fetchmany(cls.FETCH_SIZE)
                if is_dict:
                    data += CustomizedDictFactory.generate_dict_from_db(res, description)
                else:
                    data += res
                if len(data) > cls.MAX_ROW_COUNT:
                    # typo fix: "WARRING" -> "WARNING"
                    print("[WARNING] The records count in the table exceeds the limit!")
                if len(res) < cls.FETCH_SIZE:
                    break
            return data
        except sqlite3.Error as err:
            print(f"[ERROR] {err}")
            return []
        finally:
            curs.row_factory = None


class CustomizedDictFactory:
    """Converts raw row tuples into dicts keyed by the cursor's column names."""

    @staticmethod
    def generate_dict_from_db(data_result: any, description: any) -> any:
        description_set = [i[0] for i in description]
        res = []
        for data in data_result:
            data_dict = dict(zip(description_set, data))
            res.append(data_dict)
        return res
class EmptyClass:
    """Always-falsy placeholder returned instead of unusable DB connections/cursors."""

    def __init__(self: any, info: str = "") -> None:
        self._info = info

    @property
    def info(self: any) -> str:
        """Description of what this placeholder stands for."""
        return self._info

    @classmethod
    def __bool__(cls: any) -> bool:
        # Falsy by construction so callers can write `if not conn: ...`.
        return False

    @classmethod
    def __str__(cls: any) -> str:
        return ""

    @staticmethod
    def is_empty() -> bool:
        return True
class TableConstant:
    # Column-name constants for rows fetched from the per-rank analyzer DBs.
    # Values must match the column names in the single-rank analysis.db schema.

    RANK_SET = "rank_set"
    STEP = "step"
    RANK_ID = "rank_id"
    TYPE = "type"
    HCCL_OP_NAME = "hccl_op_name"
    GROUP_NAME = "group_name"
    START_TIMESTAMP = "start_timestamp"
    # NOTE(review): value is "elapse_time" (sic) while the cluster table uses
    # "elapsed_time" — presumably matches the source DB column; verify before renaming.
    ELAPSED_TIME = "elapse_time"
    TRANSIT_TIME = "transit_time"
    WAIT_TIME = "wait_time"
    SYNCHRONIZATION_TIME = "synchronization_time"
    IDLE_TIME = "idle_time"
    SYNCHRONIZATION_TIME_RATIO = "synchronization_time_ratio"
    WAIT_TIME_RATIO = "wait_time_ratio"
    BAND_TYPE = "band_type"
    TRANSIT_SIZE = "transit_size"
    BANDWIDTH = "bandwidth"
    LARGE_PACKET_RATIO = "large_packet_ratio"
    PACKAGE_SIZE = "package_size"
    COUNT = "count"
    TOTAL_DURATION = "total_duration"
    SRC_RANK = "src_rank"
    DST_RANK = "dst_rank"
    TRANSPORT_TYPE = "transport_type"
    OPNAME = "op_name"
class TablesConfig:
    # Schema definitions for the cluster_analysis.db output tables.
    # Key: "<TableName>Map" as looked up by DBManager.sql_generate_table.
    # Each entry: (column_name, "SQL_TYPE, null"); only the part before the
    # comma is used in the generated DDL.
    DATA = {
        "ClusterCommAnalyzerTimeMap": [
            ("rank_set", "TEXT, null"),
            ("step", "TEXT, null"),
            ("rank_id", "INTEGER, null"),
            ("hccl_op_name", "TEXT, null"),
            ("group_name", "TEXT, null"),
            ("start_timestamp", "NUMERIC, null"),
            ("elapsed_time", "NUMERIC, null"),
            ("transit_time", "NUMERIC, null"),
            ("wait_time", "NUMERIC, null"),
            ("synchronization_time", "NUMERIC, null"),
            ("idle_time", "NUMERIC, null"),
            ("synchronization_time_ratio", "NUMERIC, null"),
            ("wait_time_ratio", "NUMERIC, null")
        ],
        "CommunicationGroupMap": [
            ("type", "TEXT, null"),
            ("rank_set", "TEXT, null")
        ],
        "ClusterCommAnalyzerBandwidthMap": [
            ("rank_set", "TEXT, null"),
            ("step", "TEXT, null"),
            ("rank_id", "INTEGER, null"),
            ("hccl_op_name", "TEXT, null"),
            ("group_name", "TEXT, null"),
            ("band_type", "TEXT, null"),
            ("transit_size", "NUMERIC, null"),
            ("transit_time", "NUMERIC, null"),
            ("bandwidth", "NUMERIC, null"),
            ("large_packet_ratio", "NUMERIC, null"),
            ("package_size", "NUMERIC, null"),
            ("count", "NUMERIC, null"),
            ("total_duration", "NUMERIC, null")
        ],
        "ClusterCommAnalyzerMatrixMap": [
            ("rank_set", "TEXT, null"),
            ("step", "TEXT, null"),
            ("hccl_op_name", "TEXT, null"),
            ("group_name", "TEXT, null"),
            ("src_rank", "TEXT, null"),
            ("dst_rank", "TEXT, null"),
            ("transit_size", "NUMERIC, null"),
            ("transit_time", "NUMERIC, null"),
            ("bandwidth", "NUMERIC, null"),
            ("transport_type", "TEXT, null"),
            ("op_name", "TEXT, null")
        ],
        "ClusterStepTraceTimeMap": [
            ("step", "TEXT, null"),
            ("type", "TEXT, null"),
            # "index" is an SQL keyword; DBManager.sql_generate_table quotes it.
            ("index", "TEXT, null"),
            ("computing", "NUMERIC, null"),
            ("communication_not_overlapped", "NUMERIC, null"),
            ("overlapped", "NUMERIC, null"),
            ("communication", "NUMERIC, null"),
            ("free", "NUMERIC, null"),
            ("stage", "NUMERIC, null"),
            ("bubble", "NUMERIC, null"),
            ("communication_not_overlapped_and_exclude_receive", "NUMERIC, null")
        ]
    }
class BaseCommunicationGroup:
    """Shared workflow for deriving collective and p2p communication groups.

    Subclasses provide the storage-specific steps: read_communication_func,
    analyze_communication_data and dump_data.
    """

    def __init__(self, params: dict):
        self.collection_path = params.get(Constant.COLLECTION_PATH)
        self.data_map = params.get(Constant.DATA_MAP)
        self.data_type = params.get(Constant.DATA_TYPE)
        self.analysis_mode = params.get(Constant.ANALYSIS_MODE)
        self.rank_comm_dir_dict = {}
        self.p2p_link = []
        self.collective_group_dict = defaultdict(set)
        self.p2p_comm_group = []
        self.communication_group = {}

    def load_communication_data(self):
        """Collect each rank's communication file paths and parse them in parallel."""
        comm_op_dirs = []
        for rank_id, profiling_dir_path in self.data_map.items():
            if self.data_type == Constant.TEXT:
                comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON)
                matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON)
            else:
                # DB mode: both time and matrix data live in the same analysis.db.
                comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER)
                matrix_dir = comm_dir
            if comm_dir and matrix_dir:
                comm_op_dirs.append((rank_id, comm_dir, matrix_dir))
            else:
                print(
                    f"[WARNING] Rank {rank_id} does not have a valid communication.json or communication_matrix.json.")
        with Pool() as p:
            self.rank_comm_dir_dict = p.map(self.read_communication_func, comm_op_dirs)

    def set_p2p_groups(self):
        """Merge pairwise p2p links into connected components (one per p2p group)."""
        self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x))
        while self.p2p_link:
            merged = deepcopy(self.p2p_link[0])
            consumed = [self.p2p_link[0]]
            for candidate in self.p2p_link[1:]:
                if UnionFind.is_connected(candidate, merged):
                    merged = merged.union(candidate)
                    consumed.append(candidate)
            self.p2p_comm_group.append(merged)
            self.p2p_link = [link for link in self.p2p_link if link not in consumed]

    def generate_collective_communication_group(self):
        """Expose every collective group's rank set as a plain list."""
        self.communication_group[Constant.COLLECTIVE] = \
            [list(group) for group_name, group in self.collective_group_dict.items()]

    def generate_p2p_communication_group(self):
        """Infer pipeline stages by unioning collective groups that overlap."""
        stage_group = {}
        for group_name, rank_set in self.collective_group_dict.items():
            if not self.whether_valid_comm_group(rank_set):
                continue
            unioned_set = set()
            remove_key = []
            # Merge this rank set with every existing stage it touches.
            for first_rank, stage in stage_group.items():
                if UnionFind.is_connected(rank_set, stage):
                    unioned_set = UnionFind.union(rank_set, stage, unioned_set)
                    remove_key.append(first_rank)
            if unioned_set:
                for key in remove_key:
                    del stage_group[key]
                stage_group[min(unioned_set)] = unioned_set
            else:
                stage_group[min(rank_set)] = rank_set
        first_rank_sort_list = sorted(first_rank for first_rank in stage_group)
        self.communication_group[Constant.P2P] = \
            [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list]

    def whether_valid_comm_group(self, rank_set: set):
        """
        while distinguish which communication group should be used to infer stage info,
        these group should be ignored:
        1. group can not include more than 1 rank in every single p2p group
        """
        for p2p_rank_set in self.p2p_comm_group:
            if len(rank_set.intersection(p2p_rank_set)) > 1:
                return False
        return True

    @abstractmethod
    def read_communication_func(self, params: tuple):
        pass

    @abstractmethod
    def analyze_communication_data(self):
        pass

    @abstractmethod
    def dump_data(self):
        pass

    def generate(self):
        """Template method: load, analyze, group, then dump; returns dump_data's result."""
        self.load_communication_data()
        self.analyze_communication_data()
        self.set_p2p_groups()
        self.generate_collective_communication_group()
        self.generate_p2p_communication_group()
        return self.dump_data()


class UnionFind(object):
    """Disjoint Set Union"""

    @classmethod
    def union(cls, p: set, q: set, o: set):
        """make p and q the same set"""
        return p | q | o

    @classmethod
    def is_connected(cls, p: set, q: set):
        """
        check whether set p and set q are connected
        """
        return True if p & q else False
class CommunicationDBGroup(BaseCommunicationGroup):
    """Builds cluster communication groups from each rank's analysis.db file."""

    COMMUNICATION_GROUP_TABLE = "CommunicationGroup"

    def __init__(self, params: dict):
        super().__init__(params)
        self.communication_bandwidth_info = []
        self.communication_time_info = []
        self.matrix_info = []

    @staticmethod
    def data_group_by_step(data: any) -> any:
        """Bucket fetched DB row dicts by their step column."""
        res = {}
        for item in data:
            res.setdefault(item[TableConstant.STEP], []).append(item)
        return res

    def read_communication_func(self, params: tuple):
        """Read one rank's analyzer DB.

        Returns (rank_id, time_data, bandwidth_data, matrix_data), each dict
        keyed by step.
        Bug fix: the short-params guard used to return a 3-tuple, which broke
        the 4-way unpacking in analyze_communication_data.
        """
        if len(params) < 3:
            return -1, {}, {}, {}
        rank_id = params[0]
        db_path = params[1]
        time_data = {}
        bandwidth_data = {}
        matrix_data = {}
        if DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME,
                                        Constant.TABLE_COMM_ANALYZER_BANDWIDTH,
                                        Constant.TABLE_COMM_ANALYZER_MATRIX):
            conn, cursor = DBManager.create_connect_db(db_path)
            time_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_TIME)
            bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH)
            matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX)
            if self.analysis_mode in ["all", "communication_time"]:
                time_data = DBManager.fetch_all_data(cursor, time_info_sql)
                bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql)
            if self.analysis_mode in ["all", "communication_matrix"]:
                matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql)
            DBManager.destroy_db_connect(conn, cursor)
        return (rank_id, self.data_group_by_step(time_data),
                self.data_group_by_step(bandwidth_data),
                self.data_group_by_step(matrix_data))

    def analyze_communication_data(self):
        """Split fetched rows into time/bandwidth/matrix lists and collect p2p links."""
        for rank_id, time_data, bandwidth_data, matrix_data in self.rank_comm_dir_dict:
            for step, data_list in time_data.items():
                for data in data_list:
                    self.compute_collective_group(data, rank_id, self.communication_time_info)
                # Bug fix: a step present in the time table may be missing from
                # the bandwidth table; don't raise KeyError in that case.
                for data in bandwidth_data.get(step, []):
                    self.compute_collective_group(data, rank_id, self.communication_bandwidth_info)
            for step, data_list in matrix_data.items():
                self.add_p2p_and_rank(rank_id, step, matrix_data)
                for data in data_list:
                    self.compute_collective_group(data, rank_id, self.matrix_info)

    def compute_collective_group(self, data, rank_id, res_list):
        """Record collective-group membership, tag the row with its rank, keep it."""
        if data[TableConstant.TYPE] == Constant.COLLECTIVE:
            self.collective_group_dict[data[TableConstant.GROUP_NAME]].add(rank_id)
        data[TableConstant.RANK_ID] = rank_id
        res_list.append(data)

    def add_p2p_and_rank(self, rank_id: int, step: str, data_dict: dict):
        """Collect distinct src↔dst rank pairs from one step's p2p matrix rows."""
        data_list = data_dict[step]
        if not data_list:
            print(f"[WARNING] rank {rank_id} {step} don't have communication matrix ops data")
            return
        for data in data_list:
            if data[TableConstant.TYPE] != Constant.COLLECTIVE and data[TableConstant.TYPE] != Constant.P2P:
                print("[WARNING] Unknown communication operators type!")
                continue
            if data[TableConstant.TYPE] == Constant.P2P:
                if data[TableConstant.SRC_RANK] != data[TableConstant.DST_RANK]:
                    rank_set = {data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK]}
                    if rank_set not in self.p2p_link:
                        self.p2p_link.append(rank_set)

    def dump_data(self):
        """Persist the CommunicationGroup table and return the aggregated dict."""
        output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT)
        result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER)
        DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE)
        res = []
        conn, cursor = DBManager.create_connect_db(result_db)
        for data_type, data_list in self.communication_group.items():
            for ranks in data_list:
                # Store the rank set as a "(r1,r2,...)" string.
                rank_set = "(" + ",".join(str(i) for i in ranks) + ")"
                res.append([data_type, rank_set])
        if res:
            sql = "insert into {} values ({value})".format(self.COMMUNICATION_GROUP_TABLE,
                                                           value="?," * (len(res[0]) - 1) + "?")
            DBManager.executemany_sql(conn, sql, res)
        DBManager.destroy_db_connect(conn, cursor)
        comm_data_dict = {
            Constant.COLLECTIVE_GROUP: self.collective_group_dict,
            Constant.COMMUNICATION_TIME_INFO: self.communication_time_info,
            Constant.COMMUNICATION_BANDWIDTH_INFO: self.communication_bandwidth_info,
            Constant.MATRIX_OPS: self.matrix_info,
            Constant.COMMUNICATION_GROUP: self.communication_group
        }
        return comm_data_dict
from common_func.constant import Constant
from communication_group.communication_db_group import CommunicationDBGroup
from communication_group.communication_json_group import CommunicationJsonGroup


class CommunicationGroupGenerator:
    """Dispatch facade for communication-group extraction.

    Selects the processor matching the profiling export format declared in
    ``params`` (DB vs. text/JSON) and delegates all work to it.
    """

    # Maps the declared data type to the concrete processor class.
    GROUP_MAP = {
        Constant.DB: CommunicationDBGroup,
        Constant.TEXT: CommunicationJsonGroup
    }

    def __init__(self, params: dict):
        # A KeyError here means params carried an unsupported
        # Constant.DATA_TYPE value — surfaced unchanged to the caller.
        processor_cls = self.GROUP_MAP[params.get(Constant.DATA_TYPE)]
        self.processor = processor_cls(params)

    def generate(self):
        """Run the selected processor and return its comm-group data dict."""
        return self.processor.generate()
import os

from common_func.constant import Constant
from common_func.file_manager import FileManager
from communication_group.base_communication_group import BaseCommunicationGroup


class CommunicationJsonGroup(BaseCommunicationGroup):
    """Builds communication-group information from per-rank JSON profiling output.

    Reads each rank's ``communication.json`` / ``communication_matrix.json``
    (see :meth:`read_communication_func`), flattens the per-step op dicts into
    ``communication_ops`` / ``matrix_ops`` records, collects collective group
    membership and p2p links, and dumps the resulting group file.
    """

    COMMUNICATION_GROUP_JSON = "communication_group.json"

    def __init__(self, params: dict):
        super().__init__(params)
        # Flattened op records; one dict per (rank, step, op).
        self.communication_ops = []
        self.matrix_ops = []

    def dump_data(self):
        """Write communication_group.json and return the collected data dict."""
        FileManager.create_json_file(self.collection_path, self.communication_group, self.COMMUNICATION_GROUP_JSON)
        comm_data_dict = {
            Constant.COLLECTIVE_GROUP: self.collective_group_dict,
            Constant.COMMUNICATION_OPS: self.communication_ops,
            Constant.MATRIX_OPS: self.matrix_ops,
            Constant.COMMUNICATION_GROUP: self.communication_group
        }
        return comm_data_dict

    def analyze_communication_data(self):
        """Flatten every rank's parsed JSON into op records and p2p links."""
        for rank_id, rank_id_comm_dict, rank_id_matrix_dict in self.rank_comm_dir_dict:
            for step_id, step_id_dict in rank_id_comm_dict.items():
                if not isinstance(step_id_dict, dict):
                    print(f"[WARNING] rank{rank_id}'s communication.json has a wrong data struct.")
                    continue
                self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE))
                for comm_op_type, comm_op_dict in step_id_dict.items():
                    self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict)

            for step_id, step_id_dict in rank_id_matrix_dict.items():
                if not isinstance(step_id_dict, dict):
                    print(f"[WARNING] rank{rank_id}'s communication_matrix.json has a wrong data struct.")
                    continue
                self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict)
                self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE))

    def read_communication_func(self: any, params: tuple):
        """Read one rank's JSON files.

        ``params`` is ``(rank_id, comm_json_path, matrix_json_path)``; returns
        ``(rank_id, comm_data, matrix_data)``, or ``(-1, {}, {})`` on a short
        tuple. Which files are read depends on ``self.analysis_mode``.
        """
        if len(params) < 3:
            return -1, {}, {}
        rank_id, comm_json_path, matrix_json_path = params[0], params[1], params[2]
        comm_data = {}
        matrix_data = {}
        if os.path.exists(comm_json_path) and self.analysis_mode in ["all", "communication_time"]:
            comm_data = FileManager.read_json_file(comm_json_path)
        if os.path.exists(matrix_json_path) and self.analysis_mode in ["all", "communication_matrix"]:
            matrix_data = FileManager.read_json_file(matrix_json_path)
        return rank_id, comm_data, matrix_data

    def set_p2p_link(self, rank_id: int, step_id: str, rank_id_matrix_dict: dict):
        """Record matrix ops for one step and collect its p2p rank links."""
        ops = rank_id_matrix_dict.get(step_id, {})
        self.add_matrix_ops(rank_id, step_id, ops)
        if not ops:
            print(f"[WARNING] rank{rank_id} {step_id} do not have communication matrix ops data.")
            return
        p2p_ops = ops.get(Constant.P2P, {})
        for op_name, link_dict in p2p_ops.items():
            self.append_p2p_link(op_name, link_dict)

    def append_p2p_link(self, op_name, link_dict):
        """Accumulate unique {src, dst} rank pairs from 'src-dst' link keys."""
        for link in link_dict:
            if '-' not in link:
                print(f"[WARNING] {op_name} has an invalid link key {link}!")
                # NOTE(review): bails out of the whole op on one malformed key
                # (matches legacy behavior) — confirm this is intentional.
                break
            parts = link.split('-')
            src_rank = int(parts[0])
            dst_rank = int(parts[1])
            if src_rank != dst_rank:
                rank_set = {src_rank, dst_rank}
                if rank_set in self.p2p_link:
                    continue
                self.p2p_link.append(rank_set)

    def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict):
        """Register rank_id under each collective op's group ('name@group')."""
        # comm_op_dict may be None when a step has no COLLECTIVE section;
        # iterate an empty dict instead of crashing with TypeError.
        for comm_op in comm_op_dict or {}:
            if comm_op.startswith('Total'):
                continue
            group_name = comm_op.split('@')[-1]
            self.collective_group_dict[group_name].add(rank_id)

    def add_communication_ops(self, rank_id: str, step_id: str, comm_op_type: str, comm_op_dict: dict):
        """Append one flattened record per non-'Total' communication op."""
        for comm_op in comm_op_dict:
            if comm_op.startswith('Total'):
                continue
            group_name = comm_op.split('@')[-1]
            self.communication_ops.append({
                Constant.RANK_ID: rank_id,
                Constant.STEP_ID: step_id,
                Constant.COMM_OP_TYPE: comm_op_type,
                Constant.COMM_OP_NAME: comm_op,
                Constant.GROUP_NAME: group_name,
                Constant.COMM_OP_INFO: comm_op_dict.get(comm_op)
            })

    def add_matrix_ops(self, rank_id: int, step_id: str, step_id_dict: dict):
        """Append flattened matrix records, skipping unknown op types."""
        for comm_op_type, comm_dict in step_id_dict.items():
            if comm_op_type != Constant.COLLECTIVE and comm_op_type != Constant.P2P:
                print("[WARNING] Unknown communication operators type!")
                continue
            for op_name, op_link_info in comm_dict.items():
                if op_name.startswith('Total'):
                    continue
                group_name = op_name.split('@')[-1]
                self.matrix_ops.append({
                    Constant.RANK_ID: rank_id,
                    Constant.STEP_ID: step_id,
                    Constant.COMM_OP_TYPE: comm_op_type,
                    Constant.COMM_OP_NAME: op_name,
                    Constant.GROUP_NAME: group_name,
                    Constant.COMM_OP_INFO: op_link_info
                })