diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py
index 0d14af7693abbf433af14431ff4958e5a24a3cde..46080d88d946e51801326a90ff8dfdd8726a2c8b 100644
--- a/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py
+++ b/profiler/msprof_analyze/cluster_analyse/analysis/base_analysis.py
@@ -31,6 +31,7 @@ class BaseAnalysis:
         self.data_map = param.get(Constant.DATA_MAP)
         self.data_type = param.get(Constant.DATA_TYPE)
         self.communication_ops = []
+        self.p2p_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.P2P_GROUP)
         self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP)
         self.comm_ops_struct = {}
         self.adapter = DataTransferAdapter()
@@ -86,7 +87,7 @@ class BaseAnalysis:
     def split_op_by_group(self):
         for single_op in self.communication_ops:
             if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P:
-                rank_tup = Constant.P2P
+                rank_tup = tuple(self.p2p_group_dict.get(single_op.get(Constant.GROUP_NAME), []))
             else:
                 rank_tup = tuple(self.collective_group_dict.get(single_op.get(Constant.GROUP_NAME), []))
             rank_id = single_op.get(Constant.RANK_ID, 'N/A')
diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py
index d4df5466c3845f7a2db7f3b5b439059155bc4d47..62b4f4ec57d806ee4dfc144d17675403bf89df9c 100644
--- a/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py
+++ b/profiler/msprof_analyze/cluster_analyse/analysis/comm_matrix_analysis.py
@@ -141,15 +141,18 @@ class CommMatrixAnalysis(BaseAnalysis):
             Constant.OP_NAME: ''
         }
         total_op_info = defaultdict(lambda: copy.deepcopy(default_value))
+        total_group_op_info = defaultdict(lambda: copy.deepcopy(total_op_info))
         for op_name, op_dict in step_dict.items():
+            group_name = op_name.split("@")[-1]
             if self.check_add_op(op_name):
                 for link_key, link_dict in op_dict.items():
-                    self.combine_link(total_op_info[link_key], link_dict)
-        for _, link_dict in total_op_info.items():
-            link_dict[Constant.BANDWIDTH_GB_S] = \
-                self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0),
-                                   link_dict.get(Constant.TRANSIT_TIME_MS, 0))
-        step_dict[Constant.TOTAL_OP_INFO] = total_op_info
+                    self.combine_link(total_group_op_info[group_name][link_key], link_dict)
+        for group_name, total_op_info in total_group_op_info.items():
+            for _, link_dict in total_op_info.items():
+                link_dict[Constant.BANDWIDTH_GB_S] = \
+                    self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0),
+                                       link_dict.get(Constant.TRANSIT_TIME_MS, 0))
+            step_dict[f"{Constant.TOTAL_OP_INFO}@{group_name}"] = total_op_info
 
     def get_parallel_group_info(self):
         parallel_group_info = {}
diff --git a/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py b/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py
index e8ca793f525b0279053bc9848f99f21016ea6295..6c8fac98f826a8a20556fbf9e7437cbda5b43bca 100644
--- a/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py
+++ b/profiler/msprof_analyze/cluster_analyse/analysis/communication_analysis.py
@@ -80,17 +80,20 @@ class CommunicationAnalysis(BaseAnalysis):
             Constant.COMMUNICATION_BANDWIDTH_INFO: {}
         }
         total_rank_dict = defaultdict(lambda: copy.deepcopy(default_value))
-        for _, rank_dict in comm_ops.items():
+        total_group_rank_dict = defaultdict(lambda: copy.deepcopy(total_rank_dict))
+        for op_name, rank_dict in comm_ops.items():
+            group_name = op_name.split("@")[-1]
             for rank_id, communication_op_info in rank_dict.items():
                 for com_info, com_info_dict in communication_op_info.items():
                     if com_info == Constant.COMMUNICATION_TIME_INFO:
-                        self.combine_time_info(com_info_dict, total_rank_dict[rank_id][com_info])
+                        self.combine_time_info(com_info_dict, total_group_rank_dict[group_name][rank_id][com_info])
                     if com_info == Constant.COMMUNICATION_BANDWIDTH_INFO:
-                        self.combine_bandwidth_info(com_info_dict, total_rank_dict[rank_id][com_info])
-        for rank_id in total_rank_dict:
-            self.compute_time_ratio(total_rank_dict[rank_id][Constant.COMMUNICATION_TIME_INFO])
-            self.compute_bandwidth_ratio(total_rank_dict[rank_id][Constant.COMMUNICATION_BANDWIDTH_INFO])
-        comm_ops[Constant.TOTAL_OP_INFO] = total_rank_dict
+                        self.combine_bandwidth_info(com_info_dict, total_group_rank_dict[group_name][rank_id][com_info])
+        for group_name, total_rank_dict in total_group_rank_dict.items():
+            for rank_id in total_rank_dict:
+                self.compute_time_ratio(total_rank_dict[rank_id][Constant.COMMUNICATION_TIME_INFO])
+                self.compute_bandwidth_ratio(total_rank_dict[rank_id][Constant.COMMUNICATION_BANDWIDTH_INFO])
+            comm_ops[f"{Constant.TOTAL_OP_INFO}@{group_name}"] = total_rank_dict
 
     def combine_time_info(self, com_info_dict: dict, total_time_info_dict: dict):
         ratio_list = [Constant.WAIT_TIME_RATIO, Constant.SYNCHRONIZATION_TIME_RATIO]
diff --git a/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py b/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py
index 31cf1ad9ff184d840ec16b9b636fbc3555728980..63861028ed0334ab0cbec76f33207e4377c15c91 100644
--- a/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py
+++ b/profiler/msprof_analyze/cluster_analyse/communication_group/base_communication_group.py
@@ -118,6 +118,7 @@ class BaseCommunicationGroup:
 
     def collect_comm_data(self):
         comm_data_dict = {
+            Constant.P2P_GROUP: self.p2p_group_dict,
             Constant.COLLECTIVE_GROUP: self.collective_group_dict,
             Constant.COMMUNICATION_OPS: self.communication_ops,
             Constant.MATRIX_OPS: self.matrix_ops,
diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py
index 8b91626fe50dfa9ebfa62321545c92bb4dde39d0..b77d1a7f245d87d929af6c481c43d77b63841e7f 100644
--- a/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py
+++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_matrix_sum/communication_matrix_sum.py
@@ -153,8 +153,8 @@ class CommMatrixSum(BaseRecipeAnalysis):
             grouped_df.at[index, 'bandwidth'] = row["transit_size"] / row["transit_time"] if row["transit_time"] else 0
         filtered_df = grouped_df[grouped_df["is_mapped"]].drop(columns="is_mapped")
         total_op_info = filtered_df[filtered_df['hccl_op_name'].str.contains('total', na=False)].groupby(
-            [self.RANK_SET, 'step', "src_rank", "dst_rank"]).agg(
-            {"group_name": "first", 'transport_type': 'first', 'op_name': 'first', "transit_size": "sum",
+            [TableConstant.GROUP_NAME, 'step', "src_rank", "dst_rank"]).agg(
+            {'transport_type': 'first', 'op_name': 'first', "transit_size": "sum",
              "transit_time": "sum"}
         )
         total_op_info = total_op_info.reset_index()
diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py
index fd6182a923a28766230148cdb12280031e1f2890..d2320ba0823aec6300d1e603a6490e9d67de2c01 100644
--- a/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py
+++ b/profiler/msprof_analyze/cluster_analyse/recipes/communication_time_sum/communication_time_sum.py
@@ -120,8 +120,7 @@ class CommunicationTimeSum(BaseRecipeAnalysis):
         "synchronization_time", "idle_time" and other time fields; the new summary rows are inserted into communication_time
         """
         merged_df = pd.merge(communication_time, rank_set_df, on=TableConstant.GROUP_NAME, how='left')
-        summed_df = merged_df.groupby([TableConstant.STEP, TableConstant.RANK_ID, TableConstant.RANK_SET]).agg({
-            TableConstant.GROUP_NAME: "first",
+        summed_df = merged_df.groupby([TableConstant.STEP, TableConstant.RANK_ID, TableConstant.GROUP_NAME]).agg({
             TableConstant.ELAPSED_TIME: "sum",
             TableConstant.TRANSIT_TIME: "sum",
             TableConstant.WAIT_TIME: "sum",
@@ -170,23 +169,22 @@ class CommunicationTimeSum(BaseRecipeAnalysis):
         sum_transit_size = 'sum_transit_size'
         sum_transit_time = 'sum_transit_time'
         sum_transit = merged_df.groupby(
-            [TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID, TableConstant.TRANSPORT_TYPE]).apply(
+            [TableConstant.GROUP_NAME, TableConstant.STEP, TableConstant.RANK_ID, TableConstant.TRANSPORT_TYPE]).apply(
             self._get_sum_distinct_op).reset_index().rename(columns={
             TableConstant.TRANSIT_SIZE: sum_transit_size,
             TableConstant.TRANSIT_TIME: sum_transit_time
         })
         joined_df = pd.merge(merged_df, sum_transit,
-                             on=[TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID,
+                             on=[TableConstant.GROUP_NAME, TableConstant.STEP, TableConstant.RANK_ID,
                                  TableConstant.TRANSPORT_TYPE])
         # Aggregate by 'rank_set', 'step', 'rank_id', 'transport_type' and 'package_size'
         agg_result = joined_df.groupby(
-            [TableConstant.RANK_SET, TableConstant.STEP, TableConstant.RANK_ID, TableConstant.TRANSPORT_TYPE,
+            [TableConstant.GROUP_NAME, TableConstant.STEP, TableConstant.RANK_ID, TableConstant.TRANSPORT_TYPE,
              TableConstant.PACKAGE_SIZE]
         ).agg({
             TableConstant.COUNT: 'sum',
             TableConstant.TOTAL_DURATION: 'sum',
             TableConstant.HCCL_OP_NAME: 'first',
-            TableConstant.GROUP_NAME: 'first',
             sum_transit_size: 'first',
             sum_transit_time: 'first'
         }).reset_index()
diff --git a/profiler/msprof_analyze/prof_common/constant.py b/profiler/msprof_analyze/prof_common/constant.py
index 129b03f22e36275dda2121abf6daf2743ec87cf7..f63f4669c4094992de424ac95a9dec37a8254a67 100644
--- a/profiler/msprof_analyze/prof_common/constant.py
+++ b/profiler/msprof_analyze/prof_common/constant.py
@@ -92,6 +92,7 @@ class Constant(object):
 
     # params
     DATA_MAP = "data_map"
+    P2P_GROUP = "p2p_group"
     COLLECTIVE_GROUP = "collective_group"
     COMMUNICATION_OPS = "communication_ops"
     MATRIX_OPS = "matrix_ops"
diff --git a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py
index ce1c877838a6fdaf1149c2b22fcd6c64409470a1..8f3e1f8954c9291323b534716a0311660a9a55b8 100644
--- a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py
+++ b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db.py
@@ -110,10 +110,10 @@ class TestClusterAnalysePytorchDb(TestCase):
         communication_matrix_json = FileManager.read_json_file(self.COMMUNICATION_MATRIX_PATH)
         self.assertEqual(select_count(self.db_path, query_count), len(communication_matrix_json.get('(4, 5, 6, 7)')
-                                                                      .get('step').get('Total Op Info')),
+                                                                      .get('step').get('Total Op Info@15244899533746605158')),
                          "Cluster communication matrix db vs text count wrong.")
         text_cluster_communication_matrix = (communication_matrix_json.get('(4, 5, 6, 7)').get('step')
-                                             .get('Total Op Info').get('7-4'))
+                                             .get('Total Op Info@15244899533746605158').get('7-4'))
         self.assertEqual(text_cluster_communication_matrix.get('Transport Type'),
                          db_cluster_communication_analyzer_matrix.transport_type,
                          "Cluster communication matrix db vs text 'Transport Type' property wrong.")
@@ -132,19 +132,23 @@ class TestClusterAnalysePytorchDb(TestCase):
         """
         Test case to compare the cluster bandWidth from text file and database.
         """
         query = ("SELECT * FROM ClusterCommAnalyzerBandwidth WHERE hccl_op_name = 'Total Op Info' and rank_id = 7 "
-                 "and step = 'step' and band_type = 'HCCS' and package_size = '3.372891' and rank_set = '(4, 5, 6, 7)'")
+                 "and group_name = '15244899533746605158' and step = 'step' and band_type = 'HCCS' and "
+                 "package_size = '3.372891' and rank_set = '(4, 5, 6, 7)'")
         db_cluster_communication_analyzer_band_width = select_by_query(self.db_path, query,
                                                                        ClusterCommunicationAnalyzerBandwidthDb)
         query_count = ("SELECT count(*) FROM ClusterCommAnalyzerBandwidth WHERE hccl_op_name = 'Total Op Info' and "
-                       "rank_set = '(4, 5, 6, 7)' and rank_id = 7 and band_type = 'HCCS'")
+                       "group_name = '15244899533746605158' and rank_set = '(4, 5, 6, 7)' "
+                       "and rank_id = 7 and band_type = 'HCCS'")
         communication_json = FileManager.read_json_file(self.COMMUNICATION_PATH)
         self.assertEqual(select_count(self.db_path, query_count), len(communication_json.get('(4, 5, 6, 7)')
-                                                                      .get('step').get('Total Op Info').get('7').get('Communication Bandwidth Info')
+                                                                      .get('step').get('Total Op Info@15244899533746605158')
+                                                                      .get('7').get('Communication Bandwidth Info')
                                                                       .get('HCCS').get('Size Distribution')),
                          "Cluster communication bandWidth db vs text count wrong.")
         text_cluster_communication_band_width = (communication_json.get('(4, 5, 6, 7)').get('step')
-                                                 .get('Total Op Info').get('7').get('Communication Bandwidth Info')
+                                                 .get('Total Op Info@15244899533746605158')
+                                                 .get('7').get('Communication Bandwidth Info')
                                                  .get('HCCS'))
         self.assertEqual(round(text_cluster_communication_band_width.get('Transit Time(ms)')),
                          round(db_cluster_communication_analyzer_band_width.transit_time),
@@ -168,18 +172,19 @@ class TestClusterAnalysePytorchDb(TestCase):
         Test case to compare the cluster time from text file and database.
         """
         query = ("SELECT * FROM ClusterCommAnalyzerTime WHERE hccl_op_name = 'Total Op Info' and rank_id = 0 "
-                 "and step = 'step' and rank_set = '(0, 1, 2, 3)'")
+                 "and group_name = '6902614901354803568' and step = 'step' and rank_set = '(0, 1, 2, 3)'")
         db_cluster_communication_analyzer_time = select_by_query(self.db_path, query,
                                                                   ClusterCommunicationAnalyzerTime)
-        query_count = ("SELECT count(*) FROM ClusterCommAnalyzerTime WHERE hccl_op_name = 'Total Op Info' and "
-                       "rank_set = '(0, 1, 2, 3)'")
+        query_count = ("SELECT count(*) FROM ClusterCommAnalyzerTime WHERE hccl_op_name = 'Total Op Info' "
+                       "and group_name = '6902614901354803568' and rank_set = '(0, 1, 2, 3)'")
         communication_json = FileManager.read_json_file(self.COMMUNICATION_PATH)
         self.assertEqual(select_count(self.db_path, query_count), len(communication_json.get('(0, 1, 2, 3)')
-                                                                      .get('step').get('Total Op Info')),
+                                                                      .get('step').get('Total Op Info@6902614901354803568')),
                          "Cluster communication time db vs text count wrong.")
         text_cluster_communication_analyzer_time = (communication_json.get('(0, 1, 2, 3)').get('step')
-                                                    .get('Total Op Info').get('0').get('Communication Time Info'))
+                                                    .get('Total Op Info@6902614901354803568')
+                                                    .get('0').get('Communication Time Info'))
         self.assertEqual(round(text_cluster_communication_analyzer_time.get('Elapse Time(ms)')),
                          round(db_cluster_communication_analyzer_time.elapsed_time),
                          "Cluster communication time db vs text 'Elapse Time(ms)' property wrong.")
diff --git a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db_simplification.py b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db_simplification.py
index a0aad4c01162783146df68913261431ee15106ed..3cf2b2d99bd6a788bf4a1d9077651923c530845a 100644
--- a/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db_simplification.py
+++ b/profiler/msprof_analyze/test/st/cluster_analyse/test_cluster_analyse_pytorch_db_simplification.py
@@ -71,23 +71,24 @@ class TestClusterAnalysePytorchDbSimplification(TestCase):
         self.assertEqual(data["rank_set"].tolist(), ["(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15)"])
 
     def test_comm_matrix_data(self):
-        query = "SELECT * FROM ClusterCommunicationMatrix WHERE hccl_op_name = 'Total Op Info' "
+        query = "SELECT * FROM ClusterCommunicationMatrix WHERE hccl_op_name like 'Total Op Info%'"
         data = pd.read_sql(query, self.conn)
-        self.assertEqual(len(data), 232)
+        self.assertEqual(len(data), 312)
         query = "SELECT transport_type, transit_size, transit_time, bandwidth FROM ClusterCommunicationMatrix WHERE " \
-                "hccl_op_name='Total Op Info' and group_name='1046397798680881114' and src_rank=12 and dst_rank=4"
+                "hccl_op_name = 'Total Op Info@1046397798680881114' and group_name='1046397798680881114' " \
+                "and src_rank=12 and dst_rank=4"
         data = pd.read_sql(query, self.conn)
-        self.assertEqual(data.iloc[0].tolist(), ['RDMA', 59341.69862400028, 17684.277734, 3.3556190146182354])
+        self.assertEqual(data.iloc[0].tolist(), ['RDMA', 58681.19654400028, 17642.966488, 3.326039109347782])
 
     def test_comm_time_data(self):
-        query = "select rank_id, count(0) cnt from ClusterCommunicationTime where hccl_op_name = " \
-                "'Total Op Info' group by rank_id"
+        query = "select rank_id, count(0) cnt from ClusterCommunicationTime where hccl_op_name LIKE " \
+                "'Total Op Info%' group by rank_id"
         data = pd.read_sql(query, self.conn)
         self.assertEqual(len(data), 16)
-        self.assertEqual(data["cnt"].tolist(), [4 for _ in range(16)])
+        self.assertEqual(data["cnt"].tolist(), [6 for _ in range(16)])
 
     def test_comm_bandwidth_data(self):
-        query = "select * from ClusterCommunicationBandwidth where hccl_op_name = 'Total Op Info' and " \
+        query = "select * from ClusterCommunicationBandwidth where hccl_op_name like 'Total Op Info%' and " \
                 "group_name='12703750860003234865' order by count"
         data = pd.read_sql(query, self.conn)
         self.assertEqual(len(data), 2)
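For reference, the per-group aggregation introduced in `compute_total_info` above can be summarized with a minimal, self-contained sketch: totals accumulate in a two-level `defaultdict` keyed by the group name parsed from the `op_name@group_name` suffix, and one `Total Op Info@<group_name>` entry is written back per group. The data shape, field names, and sample values below are simplified stand-ins rather than the real communication-op structures; only the keying scheme mirrors the patch.

```python
import copy
from collections import defaultdict

TOTAL_OP_INFO = "Total Op Info"


def summarize_by_group(ops: dict) -> None:
    """Append one 'Total Op Info@<group_name>' entry per communication group.

    `ops` maps "op_name@group_name" -> {rank_id -> {"time": float}}; this is a
    simplified stand-in for the real per-op communication structures.
    """
    default_value = {"time": 0.0}
    total_rank_dict = defaultdict(lambda: copy.deepcopy(default_value))
    # Two-level accumulator: group_name -> rank_id -> summed fields.
    total_group_rank_dict = defaultdict(lambda: copy.deepcopy(total_rank_dict))
    for op_name, rank_dict in ops.items():
        group_name = op_name.split("@")[-1]  # group id is the suffix after '@'
        for rank_id, info in rank_dict.items():
            total_group_rank_dict[group_name][rank_id]["time"] += info["time"]
    # One summary entry per group instead of a single global "Total Op Info".
    for group_name, rank_totals in total_group_rank_dict.items():
        ops[f"{TOTAL_OP_INFO}@{group_name}"] = dict(rank_totals)


ops = {
    "hcom_allReduce__001@6902614901354803568": {0: {"time": 1.5}, 1: {"time": 2.0}},
    "hcom_allReduce__002@6902614901354803568": {0: {"time": 0.5}, 1: {"time": 1.0}},
    "hcom_send__001@15244899533746605158": {4: {"time": 3.0}},
}
summarize_by_group(ops)
print(ops["Total Op Info@6902614901354803568"])  # {0: {'time': 2.0}, 1: {'time': 3.0}}
```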
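The recipe-side changes follow the same idea in pandas: `TableConstant.GROUP_NAME` replaces `RANK_SET` as the groupby key, so each communication group keeps its own summary row. A small sketch of that pattern with made-up columns and values; only the choice of `group_name` as a grouping key reflects the diff.

```python
import pandas as pd

# Hypothetical per-op rows; in the real recipes these come from the analysis DB.
df = pd.DataFrame({
    "group_name": ["g1", "g1", "g2"],
    "step": ["step", "step", "step"],
    "rank_id": [0, 0, 0],
    "elapsed_time": [1.0, 2.0, 4.0],
    "transit_time": [0.5, 0.5, 1.0],
})

# Summing per (step, rank_id, group_name) yields one summary row per group,
# instead of collapsing every group of a rank set into a single row.
summed = df.groupby(["step", "rank_id", "group_name"]).agg(
    {"elapsed_time": "sum", "transit_time": "sum"}
).reset_index()

print(summed)
# g1 sums to elapsed_time=3.0, transit_time=1.0; g2 keeps its own row (4.0, 1.0).
```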