From f6a7ae33520154716403cdbd8aaeb7f47e0f4757 Mon Sep 17 00:00:00 2001 From: sunboquan Date: Mon, 25 Sep 2023 20:18:25 +0800 Subject: [PATCH] optimize stage info --- .../analysis/communication_analysis.py | 2 +- .../communication_group_generator.py | 67 +++++++++++++++---- 2 files changed, 56 insertions(+), 13 deletions(-) diff --git a/profiler/cluster_analyse/analysis/communication_analysis.py b/profiler/cluster_analyse/analysis/communication_analysis.py index 0fa61f4ca5..1d03687298 100644 --- a/profiler/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/cluster_analyse/analysis/communication_analysis.py @@ -253,4 +253,4 @@ class CommMatrixAnalysis(BaseCommAnalysis): link_dict[Constant.BANDWIDTH_GB_S] = \ self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - + step_dict[Constant.TOTAL_OP_INFO] = total_op_info diff --git a/profiler/cluster_analyse/communication_group/communication_group_generator.py b/profiler/cluster_analyse/communication_group/communication_group_generator.py index 6dabcac3b6..a367e624fb 100644 --- a/profiler/cluster_analyse/communication_group/communication_group_generator.py +++ b/profiler/cluster_analyse/communication_group/communication_group_generator.py @@ -14,6 +14,7 @@ # limitations under the License. import os +from copy import deepcopy from common_func.constant import Constant from common_func.file_manager import FileManager from collections import defaultdict @@ -29,11 +30,15 @@ class CommunicationGroupGenerator: self.collective_group_dict = defaultdict(set) self.p2p_group_dict = defaultdict(list) self.rank_comm_dir_dict = {} + self.rank_matrix_dir_dict = {} self.communication_ops = [] + self.p2p_comm_group = [] + self.p2p_link = [] def generate(self): self.load_communication_json() self.analyze_communication_ops() + self.set_p2p_groups() self.generate_collective_communication_group() self.generate_p2p_communication_group() FileManager.create_json_file(self.collection_path, self.communication_group, self.COMMUNICATION_GROUP_JSON) @@ -45,30 +50,40 @@ class CommunicationGroupGenerator: if not isinstance(step_id_dict, dict): print(f"rank{rank_id}'s communication.json has a wrong data struct.") continue + self.set_p2p_link(rank_id, step_id) + self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) for comm_op_type, comm_op_dict in step_id_dict.items(): - if comm_op_type == Constant.COLLECTIVE: - self.get_collective_ops_name(rank_id, comm_op_dict) - elif comm_op_type == Constant.P2P: - pass - else: - print(f"rank{rank_id}'s communication.json has no p2p or collective.") - continue self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) def load_communication_json(self): for rank_id, profiling_dir_path in self.data_map.items(): comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON) - if comm_dir: + matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON) + if comm_dir and matrix_dir: self.rank_comm_dir_dict[rank_id] = FileManager.read_json_file(comm_dir) - if not self.rank_comm_dir_dict.get(rank_id): - print(f"rank {rank_id} does not have a valid communication.json.") + self.rank_matrix_dir_dict[rank_id] = FileManager.read_json_file(matrix_dir) + else: + print(f"rank {rank_id} does not have a valid communication.json or communication_matrix.json.") def generate_collective_communication_group(self): - self.communication_group[Constant.COLLECTIVE] = [list(group) for group in self.collective_group_dict.values()] + self.communication_group[Constant.COLLECTIVE] = \ + [list(group) for group_name, group in self.collective_group_dict.items()] + + def whether_valid_comm_group(self, rank_set: set): + """ + while distinguish which communication group should be used to infer stage info, these group should be ignored: + 1. group can not include more than 1 rank in every single p2p group + """ + for p2p_rank_set in self.p2p_comm_group: + if len(rank_set.intersection(p2p_rank_set)) > 1: + return False + return True def generate_p2p_communication_group(self): stage_group = {} - for rank_set in self.collective_group_dict.values(): + for group_name, rank_set in self.collective_group_dict.items(): + if not self.whether_valid_comm_group(rank_set): + continue unioned_set = set() remove_key = [] for first_rank, stage in stage_group.items(): @@ -85,6 +100,34 @@ class CommunicationGroupGenerator: self.communication_group[Constant.P2P] = \ [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list] + def set_p2p_groups(self): + self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x)) + while self.p2p_link: + union_set = deepcopy(self.p2p_link[0]) + rm_list = [self.p2p_link[0]] + for idx, link_rank_set_x in enumerate(self.p2p_link[1:]): + if UnionFind.is_connected(link_rank_set_x, union_set): + union_set = union_set.union(link_rank_set_x) + rm_list.append(link_rank_set_x) + self.p2p_comm_group.append(union_set) + self.p2p_link = [element for element in self.p2p_link if element not in rm_list] + + def set_p2p_link(self, rank_id: int, step_id: str): + ops = self.rank_matrix_dir_dict.get(rank_id, {}).get(step_id, {}) + if not ops: + print(f"[WARNING] rank{rank_id} {step_id} do not have communication matrix ops data.") + return + p2p_ops = ops.get(Constant.P2P, {}) + for op_name, link_dict in p2p_ops.items(): + for link in link_dict: + src_rank = int(link.split('-')[0]) + dst_rank = int(link.split('-')[1]) + if src_rank != dst_rank: + rank_set = set([src_rank, dst_rank]) + if rank_set in self.p2p_link: + continue + self.p2p_link.append(rank_set) + def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict): for comm_op in comm_op_dict: if comm_op.startswith('Total'): -- Gitee