From 4cb40794693bc792710c79d4ebd241598f059623 Mon Sep 17 00:00:00 2001 From: sunboquan Date: Tue, 6 Feb 2024 18:51:29 +0800 Subject: [PATCH] 2 interface for ascend insight --- .../analysis/analysis_facade.py | 7 +++-- .../analysis/communication_analysis.py | 28 ++++++++++++++++--- profiler/cluster_analyse/cluster_analysis.py | 11 ++++++-- .../cluster_analyse/common_func/constant.py | 2 ++ .../communication_group_generator.py | 22 +++++++++------ 5 files changed, 52 insertions(+), 18 deletions(-) diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index 34228f97a2..b383a704df 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -14,6 +14,7 @@ # limitations under the License. from multiprocessing import Process +from common_func.constant import Constant from analysis.communication_analysis import CommunicationAnalysis from analysis.step_trace_time_analysis import StepTraceTimeAnalysis from analysis.communication_analysis import CommMatrixAnalysis @@ -22,14 +23,14 @@ from analysis.communication_analysis import CommMatrixAnalysis class AnalysisFacade: analysis_module = {CommunicationAnalysis, StepTraceTimeAnalysis, CommMatrixAnalysis} - def __init__(self, param: dict): - self.param = param + def __init__(self, params: dict): + self.params = params def cluster_analyze(self): # 多个profiler用多进程处理 process_list = [] for analysis in self.analysis_module: - process = Process(target=analysis(self.param).run) + process = Process(target=analysis(self.params).run) process.start() process_list.append(process) diff --git a/profiler/cluster_analyse/analysis/communication_analysis.py b/profiler/cluster_analyse/analysis/communication_analysis.py index a3c51d46a9..88ac073a9c 100644 --- a/profiler/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/cluster_analyse/analysis/communication_analysis.py @@ -82,6 +82,8 @@ class CommunicationAnalysis(BaseCommAnalysis): total_dict[size][1] += size_info[1] def run(self): + if not self.communication_ops: + return self.split_op_by_group() self.combine_ops_total_info() self.dump_data() @@ -146,6 +148,8 @@ class CommunicationAnalysis(BaseCommAnalysis): class CommMatrixAnalysis(BaseCommAnalysis): SAVED_JSON = "cluster_communication_matrix.json" + STAT_LIST = ['middle', 'top', 'bottom', 'total'] + TOTAL = 'total' def __init__(self, param: dict): super().__init__(param) @@ -154,10 +158,13 @@ class CommMatrixAnalysis(BaseCommAnalysis): @staticmethod def combine_link(link_info_dict: dict, single_link_dict: dict): link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE) + link_info_dict[Constant.OP_NAME] = single_link_dict.get(Constant.OP_NAME, '') link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0) link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0) def run(self): + if not self.communication_ops: + return self.split_op_by_group() self.combine_ops_total_info() self.dump_data() @@ -201,7 +208,8 @@ class CommMatrixAnalysis(BaseCommAnalysis): link_info = defaultdict(lambda: { Constant.TRANSPORT_TYPE: '', Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0 + Constant.TRANSIT_SIZE_MB: 0, + Constant.OP_NAME: '' }) for rank_id, rank_dict in op_dict.items(): process_link_key() @@ -211,13 +219,25 @@ class CommMatrixAnalysis(BaseCommAnalysis): total_op_info = defaultdict(lambda: { Constant.TRANSPORT_TYPE: '', Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0 + Constant.TRANSIT_SIZE_MB: 0, + Constant.OP_NAME: '' }) for op_name, op_dict in step_dict.items(): - for link_key, link_dict in op_dict.items(): - self.combine_link(total_op_info[link_key], link_dict) + if self.check_add_op(op_name): + for link_key, link_dict in op_dict.items(): + self.combine_link(total_op_info[link_key], link_dict) for link_key, link_dict in total_op_info.items(): link_dict[Constant.BANDWIDTH_GB_S] = \ self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), link_dict.get(Constant.TRANSIT_TIME_MS, 0)) step_dict[Constant.TOTAL_OP_INFO] = total_op_info + + def check_add_op(self: any, op_name: str): + """ + 兼容2个版本,判断是否需要将此算子信息相加 + """ + for stat_name in self.STAT_LIST: + if stat_name in op_name: + if stat_name != self.TOTAL: + return False + return True diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index 50eeedd731..e07cac1703 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -31,6 +31,7 @@ class Interface: def __init__(self, params: dict): self.collection_path = PathManager.get_realpath(params.get(Constant.COLLECTION_PATH)) + self.analysis_mode = params.get(Constant.ANALYSIS_MODE) self.data_map = {} self.communication_group = {} self.collective_group_dict = {} @@ -61,20 +62,24 @@ class Interface: if not data_map: print("[WARNING] Can not get rank info or profiling data.") return - comm_data_dict = CommunicationGroupGenerator(self.collection_path, data_map).generate() params = { Constant.COLLECTION_PATH: self.collection_path, Constant.DATA_MAP: data_map, - Constant.COMM_DATA_DICT: comm_data_dict + Constant.ANALYSIS_MODE: self.analysis_mode } + comm_data_dict = CommunicationGroupGenerator(params).generate() + params[Constant.COMM_DATA_DICT] = comm_data_dict AnalysisFacade(params).cluster_analyze() if __name__ == "__main__": parser = argparse.ArgumentParser(description="cluster analysis module") parser.add_argument('-d', '--collection_path', type=str, required=True, help="profiling data path") + parser.add_argument('-m', '--mode', choices=['all', 'communication_time', 'communication_matrix'], + default='all', help="different analysis mode") args_parsed = parser.parse_args() parameter = { - Constant.COLLECTION_PATH: args_parsed.collection_path + Constant.COLLECTION_PATH: args_parsed.collection_path, + Constant.ANALYSIS_MODE: args_parsed.mode } Interface(parameter).run() diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index 5ca830edef..e426a9d225 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -53,6 +53,7 @@ class Constant(object): TRANSIT_SIZE_MB = "Transit Size(MB)" SIZE_DISTRIBUTION = "Size Distribution" WAIT_TIME_MS = "Wait Time(ms)" + OP_NAME = "Op Name" BANDWIDTH_GB_S = "Bandwidth(GB/s)" COMMUNICATION = "communication.json" @@ -65,6 +66,7 @@ class Constant(object): COMMUNICATION_GROUP = "communication_group" TRANSPORT_TYPE = "Transport Type" COMM_DATA_DICT = "comm_data_dict" + ANALYSIS_MODE = "analysis_mode" # step time RANK = 'rank' diff --git a/profiler/cluster_analyse/communication_group/communication_group_generator.py b/profiler/cluster_analyse/communication_group/communication_group_generator.py index bab983de5b..4963bf9539 100644 --- a/profiler/cluster_analyse/communication_group/communication_group_generator.py +++ b/profiler/cluster_analyse/communication_group/communication_group_generator.py @@ -24,9 +24,10 @@ from common_func.file_manager import FileManager class CommunicationGroupGenerator: COMMUNICATION_GROUP_JSON = "communication_group.json" - def __init__(self, collection_path: str, data_map: dict): - self.collection_path = collection_path - self.data_map = data_map + def __init__(self, params: dict): + self.collection_path = params.get(Constant.COLLECTION_PATH) + self.data_map = params.get(Constant.DATA_MAP) + self.analysis_mode = params.get(Constant.ANALYSIS_MODE) self.communication_group = {} self.collective_group_dict = defaultdict(set) self.p2p_group_dict = defaultdict(list) @@ -57,13 +58,18 @@ class CommunicationGroupGenerator: if not isinstance(step_id_dict, dict): print(f"[WARNING] rank{rank_id}'s communication.json has a wrong data struct.") continue - self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) for comm_op_type, comm_op_dict in step_id_dict.items(): self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) - @staticmethod - def read_comm_json_func(params: tuple): + for step_id, step_id_dict in rank_id_matrix_dict.items(): + if not isinstance(step_id_dict, dict): + print(f"[WARNING] rank{rank_id}'s communication_matrix.json has a wrong data struct.") + continue + self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) + self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + + def read_comm_json_func(self: any, params: tuple): if len(params) < 3: return -1, {}, {} rank_id = params[0] @@ -71,9 +77,9 @@ class CommunicationGroupGenerator: matrix_json_path = params[2] comm_data = {} matrix_data = {} - if os.path.exists(comm_json_path): + if os.path.exists(comm_json_path) and self.analysis_mode in ['all', 'communication_time']: comm_data = FileManager.read_json_file(comm_json_path) - if os.path.exists(matrix_json_path): + if os.path.exists(matrix_json_path) and self.analysis_mode in ['all', 'communication_matrix']: matrix_data = FileManager.read_json_file(matrix_json_path) return rank_id, comm_data, matrix_data -- Gitee