From b9075eb09c6dcd1c4793f6670894c9385b43078d Mon Sep 17 00:00:00 2001 From: z30043230 Date: Fri, 15 Aug 2025 14:50:36 +0800 Subject: [PATCH] =?UTF-8?q?p2p=E7=AE=97=E5=AD=90=E9=85=8D=E5=AF=B9?= =?UTF-8?q?=E9=83=A8=E5=88=86=E9=80=BB=E8=BE=91=E4=BF=AE=E6=94=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../recipes/p2p_pairing/p2p_pairing.py | 31 +++++++------------ 1 file changed, 11 insertions(+), 20 deletions(-) diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py b/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py index 998aa291f..4103d9203 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/p2p_pairing/p2p_pairing.py @@ -70,14 +70,11 @@ class P2PPairing(BaseRecipeAnalysis): if ret is None: logger.error("Failed to connect to the database. Please check the database configurations") return - if self.COL_NAME_P2P_CONNECTION_ID in ret: - logger.error(f"`{self.COL_NAME_P2P_CONNECTION_ID}` already exists in the {self.TARGET_TABLE_NAME}. " - f"Exiting to prevent result overwrite.") - return - DBManager.execute_sql( - conn, - f"ALTER TABLE {self.TARGET_TABLE_NAME} ADD COLUMN {self.COL_NAME_P2P_CONNECTION_ID} TEXT" - ) + if self.COL_NAME_P2P_CONNECTION_ID not in ret: + DBManager.execute_sql( + conn, + f"ALTER TABLE {self.TARGET_TABLE_NAME} ADD COLUMN {self.COL_NAME_P2P_CONNECTION_ID} TEXT" + ) DBManager.execute_sql( conn, f"UPDATE {self.TARGET_TABLE_NAME} SET {self.COL_NAME_P2P_CONNECTION_ID} = NULL" @@ -126,10 +123,10 @@ class P2PPairing(BaseRecipeAnalysis): """ df = df[df[P2PPairingExport.TASK_TYPE].isin(self.VALID_DST_RANK_TASK_TYPE)] - def check_dst_rank_unique(group): - return group[P2PPairingExport.DST_RANK].nunique() == 1 + def check_src_dst_rank_unique(group): + return group[P2PPairingExport.DST_RANK].nunique() == 1 and group[P2PPairingExport.SRC_RANK].nunique() == 1 - unique_dst_rank: pd.DataFrame = (df.groupby(P2PPairingExport.OP_NAME).apply(check_dst_rank_unique)) + unique_src_dst_rank: pd.DataFrame = (df.groupby(P2PPairingExport.OP_NAME).apply(check_src_dst_rank_unique)) def get_dst_rank_value(group): if group[P2PPairingExport.DST_RANK].nunique() == 1: @@ -140,23 +137,17 @@ class P2PPairing(BaseRecipeAnalysis): apply(get_dst_rank_value)) df = df.copy() - df[self.COL_NAME_IS_UNIQUE_VALUE] = df[P2PPairingExport.OP_NAME].map(unique_dst_rank) + df[self.COL_NAME_IS_UNIQUE_VALUE] = df[P2PPairingExport.OP_NAME].map(unique_src_dst_rank) df[self.COL_NAME_OP_DST_RANK] = df[P2PPairingExport.OP_NAME].map(dst_rank_value) df[self.COL_NAME_OP_DST_RANK] = df[self.COL_NAME_OP_DST_RANK].fillna(Constant.INVALID_RANK_NUM) df[self.COL_NAME_OP_DST_RANK] = df[self.COL_NAME_OP_DST_RANK].astype(df[P2PPairingExport.DST_RANK].dtype) - check_dst_rank_unique_false: pd.DataFrame = df[~df[self.COL_NAME_IS_UNIQUE_VALUE]] - if not check_dst_rank_unique_false.empty: + check_src_dst_rank_unique_false: pd.DataFrame = df[~df[self.COL_NAME_IS_UNIQUE_VALUE]] + if not check_src_dst_rank_unique_false.empty: logger.warning(f"There are communication op entries with multiple destination ranks! " f"Please check the corresponding profiler database file.") df = df[df[self.COL_NAME_IS_UNIQUE_VALUE]] - - src_rank_unique_values: int = df[P2PPairingExport.SRC_RANK].nunique() - if src_rank_unique_values != 1: - logger.error(f"There are communication op entries with multiple source ranks! " - f"Please check the corresponding profiler database file.") - return None return df.reset_index() def filter_data_by_group_name(self, df: pd.DataFrame): -- Gitee