From e04c2f778195849328b94087ece61fb73b619520 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Thu, 22 Feb 2024 15:06:38 +0800 Subject: [PATCH 1/7] db manager --- profiler/cluster_analyse/common_func/DBManager.py | 0 profiler/cluster_analyse/common_func/empty_class.py | 0 .../prof_bean/communication_bandwidth_info_bean.py | 0 profiler/cluster_analyse/prof_bean/communication_matrix_bean.py | 0 .../cluster_analyse/prof_bean/communication_time_info_bean.py | 0 5 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 profiler/cluster_analyse/common_func/DBManager.py create mode 100644 profiler/cluster_analyse/common_func/empty_class.py create mode 100644 profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py create mode 100644 profiler/cluster_analyse/prof_bean/communication_matrix_bean.py create mode 100644 profiler/cluster_analyse/prof_bean/communication_time_info_bean.py diff --git a/profiler/cluster_analyse/common_func/DBManager.py b/profiler/cluster_analyse/common_func/DBManager.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/cluster_analyse/common_func/empty_class.py b/profiler/cluster_analyse/common_func/empty_class.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py b/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py new file mode 100644 index 0000000000..e69de29bb2 -- Gitee From fc053fc710e05ffaaa6e6e6882555b91cfcba3b8 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Thu, 22 Feb 2024 15:06:38 +0800 Subject: [PATCH 2/7] 
db manager --- .../cluster_analyse/common_func/DBManager.py | 137 ++++++++++++++++++ .../cluster_analyse/common_func/constant.py | 10 ++ .../common_func/empty_class.py | 46 ++++++ .../communication_bandwidth_info_bean.py | 30 ++++ .../prof_bean/communication_matrix_bean.py | 28 ++++ .../prof_bean/communication_time_info_bean.py | 28 ++++ 6 files changed, 279 insertions(+) create mode 100644 profiler/cluster_analyse/common_func/DBManager.py create mode 100644 profiler/cluster_analyse/common_func/empty_class.py create mode 100644 profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py create mode 100644 profiler/cluster_analyse/prof_bean/communication_matrix_bean.py create mode 100644 profiler/cluster_analyse/prof_bean/communication_time_info_bean.py diff --git a/profiler/cluster_analyse/common_func/DBManager.py b/profiler/cluster_analyse/common_func/DBManager.py new file mode 100644 index 0000000000..b7801eff75 --- /dev/null +++ b/profiler/cluster_analyse/common_func/DBManager.py @@ -0,0 +1,137 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import sqlite3 +from collections import namedtuple +from dataclasses import fields + +from common_func.constant import Constant +from common_func.empty_class import EmptyClass + + +class DBManager: + """ + class to manage DB operation + """ + FETCH_SIZE = 10000 + INSERT_SIZE = 10000 + MAX_ROW_COUNT = 100000000 + + @staticmethod + def create_connect_db(db_path: str) -> tuple: + """ + create and connect database + """ + if check_db_path_valid(db_path): + try: + conn = sqlite3.connect(db_path) + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return EmptyClass("empty conn"), EmptyClass("empty curs") + try: + if isinstance(conn, sqlite3.Connection): + curs = conn.cursor() + os.chmod(db_path, Constant.FILE_AUTHORITY) + return conn, curs + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return EmptyClass("empty conn"), EmptyClass("empty curs") + return EmptyClass("empty conn"), EmptyClass("empty curs") + + @staticmethod + def destroy_db_connect(conn: any, curs: any) -> None: + """ + destroy db connection + """ + try: + if isinstance(curs, sqlite3.Cursor): + curs.close() + except sqlite3.Error as err: + print(f"[ERROR] {err}") + try: + if isinstance(conn, sqlite3.Connection): + conn.close() + except sqlite3.Error as err: + print(f"[ERROR] {err}") + + @staticmethod + def execute_sql(conn: any, sql: str, params: any = None) -> bool: + """ + execute sql + """ + try: + if isinstance(conn, sqlite3.Connection): + if params: + conn.cursor().execute(sql, params) + else: + conn.cursor().execute(sql) + conn.commit() + return True + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return False + print("[ERROR] conn is invalid param") + return False + + @classmethod + def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None, dto_class: any = None) -> list: + """ + fetch 10000 num of data from db each time to get all data + """ + if not isinstance(curs, sqlite3.Cursor): + return [] + data = [] + try: + if param: + curs.execute(sql, 
param) + else: + curs.execute(sql) + except sqlite3.Error as err: + print(f"[ERROR] {err}") + curs.row_factory = None + return [] + try: + while True: + res = curs.fetchmany(cls.FETCH_SIZE) + if dto_class: + field_names = [item.name for item in fields(dto_class)] + dto = namedtuple(dto_class.__name__, field_names) + data.extend([dto(*item) for item in res]) + else: + data.extend(res) + if len(data) > cls.MAX_ROW_COUNT: + print("[WARRING] The records count in the table exceeds the limit!") + if len(res) < cls.FETCH_SIZE: + break + return data + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return [] + finally: + curs.row_factory = None + + +def check_db_path_valid(path: str, max_size: int = Constant.MAX_READ_DB_FILE_BYTES) -> bool: + if not os.path.exists(path): + print(f'[ERROR] The db file path: {path} does not exist. Please check the path') + return False + if os.path.islink(path): + print(f'[ERROR] The db file path: {path} is link. Please check the path') + return False + if os.path.getsize(path) > max_size: + print(f'[ERROR] The db file: {path} is too large to read. 
Please check the file') + return False + return True diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index 5ca830edef..1af29a97b2 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -30,6 +30,7 @@ class Constant(object): MAX_JSON_SIZE = 1024 * 1024 * 1024 * 10 MAX_CSV_SIZE = 1024 * 1024 * 1024 * 5 MAX_PATH_LENGTH = 4096 + MAX_READ_DB_FILE_BYTES = 1024 * 1024 * 1024 * 8 # communication P2P = "p2p" @@ -76,3 +77,12 @@ class Constant(object): # file suffix JSON_SUFFIX = ".json" CSV_SUFFIX = ".csv" + + # db name + DB_COMMUNICATION_ANALYZER = 'communication_analyzer.db' + DB_CLUSTER_COMMUNICATION_ANALYZER = 'cluster_communication_analyzer.db' + + # db tables + TABLE_COMM_ANALYZER_BANDWIDTH = 'CommAnalyzerBandwidth' + TABLE_COMM_ANALYZER_TIME = 'CommAnalyzerTime' + TABLE_COMM_ANALYZER_MATRIX = 'CommAnalyzerMatrix' diff --git a/profiler/cluster_analyse/common_func/empty_class.py b/profiler/cluster_analyse/common_func/empty_class.py new file mode 100644 index 0000000000..0141079a00 --- /dev/null +++ b/profiler/cluster_analyse/common_func/empty_class.py @@ -0,0 +1,46 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class EmptyClass: + """ + Empty class + """ + + def __init__(self: any, info: str = "") -> None: + self._info = info + + @classmethod + def __bool__(cls: any) -> bool: + return False + + @classmethod + def __str__(cls: any) -> str: + return "" + + @property + def info(self: any) -> str: + """ + get info + :return: _info + """ + return self._info + + @staticmethod + def is_empty() -> bool: + """ + return this is an empty class + """ + return True diff --git a/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py new file mode 100644 index 0000000000..cc4a5fecbf --- /dev/null +++ b/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py @@ -0,0 +1,30 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +class CommunicationBandwidthInfo: + """ + The class represents a Communication Bandwidth Info object + """ + hccl_op_name: str + group_name: str + transport_type: str + transit_size: float + transit_time: float + bandwidth: float + large_packet_ratio: float + package_size: float + count: float + total_duration: float diff --git a/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py b/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py new file mode 100644 index 0000000000..95b174e859 --- /dev/null +++ b/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class CommunicationMatrix: + """ + The class represents a Communication Matrix object + """ + hccl_op_name: str + group_name: str + src_rank: str + dst_rank: str + transport_type: str + transit_size: float + transit_time: float + bandwidth: float diff --git a/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py new file mode 100644 index 0000000000..b1a8e81eae --- /dev/null +++ b/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class CommunicationTimeInfo: + """ + The class represents a Communication Time Info object + """ + hccl_op_name: str + group_name: str + start_timestamp: float + elapse_time: float + transit_time: float + wait_time: float + synchronization_time: float + idle_time: float -- Gitee From 88c7247e055c72c67b304584892e5f67447d137c Mon Sep 17 00:00:00 2001 From: cheng <3218750885@qq.com> Date: Sun, 25 Feb 2024 20:59:23 +0800 Subject: [PATCH 3/7] add db manage --- .../cluster_analyse/analysis/base_analysis.py | 36 ++++ .../cluster_analyse/common_func/DBManager.py | 59 ++++- .../cluster_analyse/common_func/constant.py | 4 +- .../common_func/tables_config.py | 35 +++ .../base_communication_group.py | 133 ++++++++++++ .../communication_db_group.py | 99 +++++++++ .../communication_group_generator.py | 202 +----------------- .../communication_json_group.py | 128 +++++++++++ .../communication_bandwidth_info_bean.py | 2 + .../prof_bean/communication_matrix_bean.py | 2 + .../prof_bean/communication_time_info_bean.py | 2 + 11 files changed, 501 insertions(+), 201 deletions(-) create mode 100644 profiler/cluster_analyse/analysis/base_analysis.py create mode 100644 profiler/cluster_analyse/common_func/tables_config.py create mode 100644 profiler/cluster_analyse/communication_group/base_communication_group.py create mode 100644 profiler/cluster_analyse/communication_group/communication_db_group.py create mode 100644 
profiler/cluster_analyse/communication_group/communication_json_group.py diff --git a/profiler/cluster_analyse/analysis/base_analysis.py b/profiler/cluster_analyse/analysis/base_analysis.py new file mode 100644 index 0000000000..8f0f6eaea3 --- /dev/null +++ b/profiler/cluster_analyse/analysis/base_analysis.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import abstractmethod +from common_func.constant import Constant + + +class BaseAnalysis: + + def __init__(self, param: dict): + self.collection_path = param.get(Constant.COLLECTION_PATH) + self.data_map = param.get(Constant.DATA_MAP) + self.communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP) + + @staticmethod + def compute_ratio(dividend: float, divisor: float): + if abs(divisor) < Constant.EPS: + return 0 + else: + return round(dividend / divisor, 4) + + @abstractmethod + def run(self): + pass diff --git a/profiler/cluster_analyse/common_func/DBManager.py b/profiler/cluster_analyse/common_func/DBManager.py index b7801eff75..24a66b992b 100644 --- a/profiler/cluster_analyse/common_func/DBManager.py +++ b/profiler/cluster_analyse/common_func/DBManager.py @@ -20,6 +20,7 @@ from dataclasses import fields from common_func.constant import Constant from common_func.empty_class import EmptyClass +from common_func.tables_config import TablesConfig class DBManager: @@ -67,6 
+68,57 @@ class DBManager: except sqlite3.Error as err: print(f"[ERROR] {err}") + @staticmethod + def judge_table_exists(curs: any, table_name: str) -> any: + """ + judge table exists + """ + if not isinstance(curs, sqlite3.Cursor): + return False + try: + curs.execute("select count(*) from sqlite_master where type='table' and name=?", table_name) + return curs.fetchone()[0] + except sqlite3.Error as err: + print("[ERROR] {}".format(err)) + return False + + @staticmethod + def sql_generate_table(table_map: str): + header_with_type_begin = "(" + header_with_type_end = ")" + header_with_type_list = [] + if table_map in TablesConfig.DATA: + items = TablesConfig.DATA[table_map] + for item in items: + header_with_type_list.append(item[0] + " " + item[1].split(",")[0]) + header_with_type_begin += ",".join(header_with_type_list) + header_with_type_begin += header_with_type_end + return header_with_type_begin + + @classmethod + def check_tables_in_db(cls, db_path: any, *tables: any) -> bool: + check_db_path_valid(db_path) + conn, curs = cls.create_connect_db(db_path) + if not (conn and curs): + return False + res = True + for table in tables: + if not cls.judge_table_exists(curs, table): + res = False + break + cls.destroy_db_connect(conn, curs) + return res + + @classmethod + def create_tables(cls, db_path: any, *tables: any) -> bool: + conn, curs = cls.create_connect_db(db_path) + for table_name in tables: + if not cls.judge_table_exists(curs, table_name): + table_map = "{0}Map".format(table_name) + header_with_type = cls.sql_generate_table(table_map) + sql = "CREATE TABLE IF NOT EXISTS " + table_name + header_with_type + cls.execute_sql(conn, sql) + @staticmethod def execute_sql(conn: any, sql: str, params: any = None) -> bool: """ @@ -124,14 +176,11 @@ class DBManager: curs.row_factory = None -def check_db_path_valid(path: str, max_size: int = Constant.MAX_READ_DB_FILE_BYTES) -> bool: - if not os.path.exists(path): - print(f'[ERROR] The db file path: {path} does not 
exist. Please check the path') - return False +def check_db_path_valid(path: str, is_create: bool = False, max_size: int = Constant.MAX_READ_DB_FILE_BYTES) -> bool: if os.path.islink(path): print(f'[ERROR] The db file path: {path} is link. Please check the path') return False - if os.path.getsize(path) > max_size: + if not is_create and not os.path.exists(path) and os.path.getsize(path) > max_size: print(f'[ERROR] The db file: {path} is too large to read. Please check the file') return False return True diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index c21814d85a..0599ad34f6 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -84,8 +84,8 @@ class Constant(object): DB = "db" # db name - DB_COMMUNICATION_ANALYZER = 'communication_analyzer.db' - DB_CLUSTER_COMMUNICATION_ANALYZER = 'cluster_communication_analyzer.db' + DB_COMMUNICATION_ANALYZER = 'analysis.db' + DB_CLUSTER_COMMUNICATION_ANALYZER = 'cluster_analysis.db' # db tables TABLE_COMM_ANALYZER_BANDWIDTH = 'CommAnalyzerBandwidth' diff --git a/profiler/cluster_analyse/common_func/tables_config.py b/profiler/cluster_analyse/common_func/tables_config.py new file mode 100644 index 0000000000..9a9108c6a2 --- /dev/null +++ b/profiler/cluster_analyse/common_func/tables_config.py @@ -0,0 +1,35 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + + +class TablesConfig: + DATA = { + "CommAnalyzerTimeMap": [ + ("step", "TEXT, null"), + ("type", "TEXT, null"), + ("hccl_op_name", "TEXT, null"), + ("group_name", "TEXT, null"), + ("start_timestamp", "NUMERIC, null"), + ("elapsed_time", "NUMERIC, null"), + ("transit_time", "NUMERIC, null"), + ("wait_time", "NUMERIC, null"), + ("synchronization_time", "NUMERIC, null"), + ("idle_time", "NUMERIC, null") + ], + "CommunicationGroupMap": [ + ("type", "TEXT, null"), + ("rank_set", "TEXT, null") + ] + } diff --git a/profiler/cluster_analyse/communication_group/base_communication_group.py b/profiler/cluster_analyse/communication_group/base_communication_group.py new file mode 100644 index 0000000000..e29486d677 --- /dev/null +++ b/profiler/cluster_analyse/communication_group/base_communication_group.py @@ -0,0 +1,133 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from collections import defaultdict +from copy import deepcopy +from multiprocessing import Pool + +from common_func.constant import Constant + + +class BaseCommunicationGroup: + def __init__(self, collection_path: str, data_map: dict, data_type: str): + self.collection_path = collection_path + self.data_map = data_map + self.data_type = data_type + self.rank_comm_dir_dict = {} + self.p2p_link = [] + self.collective_group_dict = defaultdict(set) + self.p2p_comm_group = [] + self.communication_group = {} + + def load_communication_data(self): + comm_op_dirs = [] + for rank_id, profiling_dir_path in self.data_map.items(): + if self.data_type == Constant.TEXT: + comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON) + matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON) + else: + comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER) + matrix_dir = comm_dir + if comm_dir and matrix_dir: + comm_op_dirs.append((rank_id, comm_dir, matrix_dir)) + else: + print( + f"[WARNING] Rank {rank_id} does not have a valid communication.json or communication_matrix.json.") + with Pool() as p: + self.rank_comm_dir_dict = p.map(self.read_communication_func, comm_op_dirs) + + def set_p2p_groups(self): + self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x)) + while self.p2p_link: + union_set = deepcopy(self.p2p_link[0]) + rm_list = [self.p2p_link[0]] + for idx, link_rank_set_x in enumerate(self.p2p_link[1:]): + if UnionFind.is_connected(link_rank_set_x, union_set): + union_set = union_set.union(link_rank_set_x) + rm_list.append(link_rank_set_x) + self.p2p_comm_group.append(union_set) + self.p2p_link = [element for element in self.p2p_link if element not in rm_list] + + def generate_collective_communication_group(self): + self.communication_group[Constant.COLLECTIVE] = \ + [list(group) for group_name, group in self.collective_group_dict.items()] 
+ + def generate_p2p_communication_group(self): + stage_group = {} + for group_name, rank_set in self.collective_group_dict.items(): + if not self.whether_valid_comm_group(rank_set): + continue + unioned_set = set() + remove_key = [] + for first_rank, stage in stage_group.items(): + if UnionFind.is_connected(rank_set, stage): + unioned_set = UnionFind.union(rank_set, stage, unioned_set) + remove_key.append(first_rank) + if unioned_set: + for key in remove_key: + del stage_group[key] + stage_group[min(unioned_set)] = unioned_set + else: + stage_group[min(rank_set)] = rank_set + first_rank_sort_list = sorted([first_rank for first_rank in stage_group]) + self.communication_group[Constant.P2P] = \ + [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list] + + def whether_valid_comm_group(self, rank_set: set): + """ + while distinguish which communication group should be used to infer stage info, these group should be ignored: + 1. group can not include more than 1 rank in every single p2p group + """ + for p2p_rank_set in self.p2p_comm_group: + if len(rank_set.intersection(p2p_rank_set)) > 1: + return False + return True + + def read_communication_func(self, params: tuple): + pass + + def analyze_communication_data(self): + pass + + def dump_data(self): + pass + + def generate(self): + self.load_communication_data() + self.analyze_communication_data() + self.set_p2p_groups() + self.generate_collective_communication_group() + self.generate_p2p_communication_group() + return self.dump_data() + + +class UnionFind(object): + """Disjoint Set Union""" + + @classmethod + def union(cls, p: set, q: set, o: set): + """make p and q the same set""" + return p | q | o + + @classmethod + def is_connected(cls, p: set, q: set): + """ + check whether set p and set q are connected + """ + if p & q: + return True + else: + return False diff --git a/profiler/cluster_analyse/communication_group/communication_db_group.py 
b/profiler/cluster_analyse/communication_group/communication_db_group.py new file mode 100644 index 0000000000..f3754251b8 --- /dev/null +++ b/profiler/cluster_analyse/communication_group/communication_db_group.py @@ -0,0 +1,99 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +from common_func.DBManager import DBManager +from common_func.constant import Constant +from communication_group.base_communication_group import BaseCommunicationGroup +from prof_bean.communication_bandwidth_info_bean import CommunicationBandwidthInfo +from prof_bean.communication_matrix_bean import CommunicationMatrix +from prof_bean.communication_time_info_bean import CommunicationTimeInfo + + +class CommunicationDBGroup(BaseCommunicationGroup): + COMMUNICATION_GROUP_TABLE = "CommunicationGroup" + + def __init__(self, collection_path: str, data_map: dict, data_type: str): + super().__init__(collection_path, data_map, data_type) + + def read_communication_func(self, params: tuple): + if len(params) < 3: + return -1, {}, {} + rank_id = params[0] + db_path = params[1] + time_data = {} + bandwidth_data = {} + matrix_data = {} + if DBManager.check_tables_in_db(db_path, (Constant.TABLE_COMM_ANALYZER_TIME, + Constant.TABLE_COMM_ANALYZER_BANDWIDTH, + Constant.TABLE_COMM_ANALYZER_MATRIX)): + conn, cursor = DBManager.create_connect_db(db_path) + time_info_sql = "select * from 
{0}".format(Constant.TABLE_COMM_ANALYZER_TIME) + bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH) + matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX) + time_data = DBManager.fetch_all_data(cursor, time_info_sql, dto_class=CommunicationTimeInfo) + bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql, dto_class=CommunicationBandwidthInfo) + matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql, dto_class=CommunicationMatrix) + DBManager.destroy_db_connect(conn, cursor) + return (rank_id, self.data_group_by_step(time_data), self.data_group_by_step(bandwidth_data), + self.data_group_by_step(matrix_data)) + + @staticmethod + def data_group_by_step(data: any) -> any: + res = {} + for item in data: + res.setdefault(item.step, []).append(item) + return res + + def dump_data(self): + output_path = os.path.join(self.collection_path, Constant.CLUSTER_ANALYSIS_OUTPUT) + result_db = os.path.join(output_path, Constant.DB_CLUSTER_COMMUNICATION_ANALYZER) + DBManager.create_tables(result_db, self.COMMUNICATION_GROUP_TABLE) + res = [] + conn, cursor = DBManager.create_connect_db(result_db) + for data_type, data_list in self.communication_group.items(): + for data in data_list: + data = [data_type, data] + res.append(data) + if res: + sql = "insert into {} values ({value})".format(self.COMMUNICATION_GROUP_TABLE, + value="?," * (len(res[0]) - 1) + "?") + DBManager.execute_sql(conn, sql, res) + DBManager.destroy_db_connect(conn, cursor) + + def analyze_communication_data(self): + for rank_id, time_data, bandwidth_data, matrix_data in self.rank_comm_dir_dict: + for step, data_list in time_data.items(): + self.add_p2p_and_rank(rank_id, step, matrix_data) + for data in data_list: + if data.type == Constant.COLLECTIVE: + self.collective_group_dict[data.group_name].add(rank_id) + setattr(data, "rank_id", rank_id) + for data in bandwidth_data[step]: + setattr(data, "rank_id", rank_id) + + def 
add_p2p_and_rank(self, rank_id: int, step: str, data_dict: dict): + data_list = data_dict[step] + for data in data_list: + if data.type != Constant.COLLECTIVE and data.type != Constant.P2P: + print(f"[WARNING] Unknown communication operators type!") + continue + setattr(data, "rank_id", rank_id) + if data.type == Constant.P2P: + if data.src_rank != data.dst_rank: + rank_set = {data.src_rank, data.dst_rank} + if rank_set not in self.p2p_link: + self.p2p_link.append(rank_set) diff --git a/profiler/cluster_analyse/communication_group/communication_group_generator.py b/profiler/cluster_analyse/communication_group/communication_group_generator.py index 176977da38..1cec39c11c 100644 --- a/profiler/cluster_analyse/communication_group/communication_group_generator.py +++ b/profiler/cluster_analyse/communication_group/communication_group_generator.py @@ -13,206 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os -from copy import deepcopy -from multiprocessing import Pool -from collections import defaultdict from common_func.constant import Constant -from common_func.file_manager import FileManager +from communication_group.communication_db_group import CommunicationDBGroup +from communication_group.communication_json_group import CommunicationJsonGroup class CommunicationGroupGenerator: COMMUNICATION_GROUP_JSON = "communication_group.json" + GROUP_MAP = { + Constant.DB: CommunicationDBGroup, + Constant.TEXT: CommunicationJsonGroup + } def __init__(self, collection_path: str, data_map: dict, data_type: str): - self.collection_path = collection_path - self.data_map = data_map - self.communication_group = {} - self.collective_group_dict = defaultdict(set) - self.p2p_group_dict = defaultdict(list) - self.rank_comm_dir_dict = {} - self.communication_ops = [] - self.p2p_comm_group = [] - self.p2p_link = [] - self.matrix_ops = [] - self.data_type = data_type + self.processor = 
self.GROUP_MAP[data_type](collection_path, data_map, data_type) def generate(self): - self.load_communication_json() - self.analyze_communication_ops() - self.set_p2p_groups() - self.generate_collective_communication_group() - self.generate_p2p_communication_group() - FileManager.create_json_file(self.collection_path, self.communication_group, self.COMMUNICATION_GROUP_JSON) - comm_data_dict = { - Constant.COLLECTIVE_GROUP: self.collective_group_dict, - Constant.COMMUNICATION_OPS: self.communication_ops, - Constant.MATRIX_OPS: self.matrix_ops, - Constant.COMMUNICATION_GROUP: self.communication_group - } - return comm_data_dict - - def analyze_communication_ops(self): - for rank_id, rank_id_comm_dict, rank_id_matrix_dict in self.rank_comm_dir_dict: - for step_id, step_id_dict in rank_id_comm_dict.items(): - if not isinstance(step_id_dict, dict): - print(f"[WARNING] rank{rank_id}'s communication.json has a wrong data struct.") - continue - self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) - self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) - for comm_op_type, comm_op_dict in step_id_dict.items(): - self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) - - @staticmethod - def read_comm_json_func(params: tuple): - if len(params) < 3: - return -1, {}, {} - rank_id = params[0] - comm_json_path = params[1] - matrix_json_path = params[2] - comm_data = {} - matrix_data = {} - if os.path.exists(comm_json_path): - comm_data = FileManager.read_json_file(comm_json_path) - if os.path.exists(matrix_json_path): - matrix_data = FileManager.read_json_file(matrix_json_path) - return rank_id, comm_data, matrix_data - - def load_communication_json(self): - comm_op_dirs = [] - for rank_id, profiling_dir_path in self.data_map.items(): - comm_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_JSON) - matrix_dir = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.COMM_MATRIX_JSON) - if comm_dir 
and matrix_dir: - comm_op_dirs.append((rank_id, comm_dir, matrix_dir)) - else: - print(f"[WARNING] Rank {rank_id} does not have a valid communication.json or communication_matrix.json.") - with Pool() as p: - self.rank_comm_dir_dict = p.map(self.read_comm_json_func, comm_op_dirs) - - def generate_collective_communication_group(self): - self.communication_group[Constant.COLLECTIVE] = \ - [list(group) for group_name, group in self.collective_group_dict.items()] - - def whether_valid_comm_group(self, rank_set: set): - """ - while distinguish which communication group should be used to infer stage info, these group should be ignored: - 1. group can not include more than 1 rank in every single p2p group - """ - for p2p_rank_set in self.p2p_comm_group: - if len(rank_set.intersection(p2p_rank_set)) > 1: - return False - return True - - def generate_p2p_communication_group(self): - stage_group = {} - for group_name, rank_set in self.collective_group_dict.items(): - if not self.whether_valid_comm_group(rank_set): - continue - unioned_set = set() - remove_key = [] - for first_rank, stage in stage_group.items(): - if UnionFind.is_connected(rank_set, stage): - unioned_set = UnionFind.union(rank_set, stage, unioned_set) - remove_key.append(first_rank) - if unioned_set: - for key in remove_key: - del stage_group[key] - stage_group[min(unioned_set)] = unioned_set - else: - stage_group[min(rank_set)] = rank_set - first_rank_sort_list = sorted([first_rank for first_rank in stage_group]) - self.communication_group[Constant.P2P] = \ - [list(stage_group.get(first_rank, {})) for first_rank in first_rank_sort_list] - - def set_p2p_groups(self): - self.p2p_link = sorted(self.p2p_link, key=lambda x: min(x)) - while self.p2p_link: - union_set = deepcopy(self.p2p_link[0]) - rm_list = [self.p2p_link[0]] - for idx, link_rank_set_x in enumerate(self.p2p_link[1:]): - if UnionFind.is_connected(link_rank_set_x, union_set): - union_set = union_set.union(link_rank_set_x) - 
rm_list.append(link_rank_set_x) - self.p2p_comm_group.append(union_set) - self.p2p_link = [element for element in self.p2p_link if element not in rm_list] - - def set_p2p_link(self, rank_id: int, step_id: str, rank_id_matrix_dict: dict): - ops = rank_id_matrix_dict.get(step_id, {}) - self.add_matrix_ops(rank_id, step_id, ops) - if not ops: - print(f"[WARNING] rank{rank_id} {step_id} do not have communication matrix ops data.") - return - p2p_ops = ops.get(Constant.P2P, {}) - for op_name, link_dict in p2p_ops.items(): - self.append_p2p_link(op_name, link_dict) - - def append_p2p_link(self, op_name, link_dict): - for link in link_dict: - if '-' not in link: - print(f"[WARNING] {op_name} has an invalid link key {link}!") - break - src_rank = int(link.split('-')[0]) - dst_rank = int(link.split('-')[1]) - if src_rank != dst_rank: - rank_set = set([src_rank, dst_rank]) - if rank_set in self.p2p_link: - continue - self.p2p_link.append(rank_set) - - def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict): - for comm_op in comm_op_dict: - if comm_op.startswith('Total'): - continue - group_name = comm_op.split('@')[-1] - self.collective_group_dict[group_name].add(rank_id) - - def add_communication_ops(self, rank_id: str, step_id: str, comm_op_type: str, comm_op_dict: dict): - for comm_op in comm_op_dict: - if comm_op.startswith('Total'): - continue - group_name = comm_op.split('@')[-1] - self.communication_ops.append({ - Constant.RANK_ID: rank_id, - Constant.STEP_ID: step_id, - Constant.COMM_OP_TYPE: comm_op_type, - Constant.COMM_OP_NAME: comm_op, - Constant.GROUP_NAME: group_name, - Constant.COMM_OP_INFO: comm_op_dict.get(comm_op) - }) - - def add_matrix_ops(self, rank_id: int, step_id: str, step_id_dict: dict): - for comm_op_type, comm_dict in step_id_dict.items(): - if comm_op_type != Constant.COLLECTIVE and comm_op_type != Constant.P2P: - print(f"[WARNING] Unknown communication operators type!") - continue - for op_name, op_link_info in comm_dict.items(): - 
if op_name.startswith('Total'): - continue - group_name = op_name.split('@')[-1] - self.matrix_ops.append({ - Constant.RANK_ID: rank_id, - Constant.STEP_ID: step_id, - Constant.COMM_OP_TYPE: comm_op_type, - Constant.COMM_OP_NAME: op_name, - Constant.GROUP_NAME: group_name, - Constant.COMM_OP_INFO: op_link_info - }) - - -class UnionFind(object): - """Disjoint Set Union""" - @classmethod - def union(cls, p: set, q: set, o: set): - """make p and q the same set""" - return p | q | o - - @classmethod - def is_connected(cls, p: set, q: set): - """ - check whether set p and set q are connected - """ - if p & q: - return True - else: - return False + return self.processor.generate() diff --git a/profiler/cluster_analyse/communication_group/communication_json_group.py b/profiler/cluster_analyse/communication_group/communication_json_group.py new file mode 100644 index 0000000000..7a6d3df712 --- /dev/null +++ b/profiler/cluster_analyse/communication_group/communication_json_group.py @@ -0,0 +1,128 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +from collections import defaultdict + +from common_func.constant import Constant +from common_func.file_manager import FileManager +from communication_group.base_communication_group import BaseCommunicationGroup, UnionFind + + +class CommunicationJsonGroup(BaseCommunicationGroup): + COMMUNICATION_GROUP_JSON = "communication_group.json" + + def __init__(self, collection_path: str, data_map: dict, data_type: str): + super().__init__(collection_path, data_map, data_type) + self.p2p_group_dict = defaultdict(list) + self.communication_ops = [] + self.matrix_ops = [] + + def dump_data(self): + FileManager.create_json_file(self.collection_path, self.communication_group, self.COMMUNICATION_GROUP_JSON) + comm_data_dict = { + Constant.COLLECTIVE_GROUP: self.collective_group_dict, + Constant.COMMUNICATION_OPS: self.communication_ops, + Constant.MATRIX_OPS: self.matrix_ops, + Constant.COMMUNICATION_GROUP: self.communication_group + } + return comm_data_dict + + def analyze_communication_data(self): + for rank_id, rank_id_comm_dict, rank_id_matrix_dict in self.rank_comm_dir_dict: + for step_id, step_id_dict in rank_id_comm_dict.items(): + if not isinstance(step_id_dict, dict): + print(f"[WARNING] rank{rank_id}'s communication.json has a wrong data struct.") + continue + self.set_p2p_link(rank_id, step_id, rank_id_matrix_dict) + self.get_collective_ops_name(rank_id, step_id_dict.get(Constant.COLLECTIVE)) + for comm_op_type, comm_op_dict in step_id_dict.items(): + self.add_communication_ops(rank_id, step_id, comm_op_type, comm_op_dict) + + def read_communication_func(self, params: tuple): + if len(params) < 3: + return -1, {}, {} + rank_id = params[0] + comm_json_path = params[1] + matrix_json_path = params[2] + comm_data = {} + matrix_data = {} + if os.path.exists(comm_json_path): + comm_data = FileManager.read_json_file(comm_json_path) + if os.path.exists(matrix_json_path): + matrix_data = FileManager.read_json_file(matrix_json_path) + return rank_id, comm_data, 
matrix_data + + def set_p2p_link(self, rank_id: int, step_id: str, rank_id_matrix_dict: dict): + ops = rank_id_matrix_dict.get(step_id, {}) + self.add_matrix_ops(rank_id, step_id, ops) + if not ops: + print(f"[WARNING] rank{rank_id} {step_id} do not have communication matrix ops data.") + return + p2p_ops = ops.get(Constant.P2P, {}) + for op_name, link_dict in p2p_ops.items(): + self.append_p2p_link(op_name, link_dict) + + def append_p2p_link(self, op_name, link_dict): + for link in link_dict: + if '-' not in link: + print(f"[WARNING] {op_name} has an invalid link key {link}!") + break + src_rank = int(link.split('-')[0]) + dst_rank = int(link.split('-')[1]) + if src_rank != dst_rank: + rank_set = set([src_rank, dst_rank]) + if rank_set in self.p2p_link: + continue + self.p2p_link.append(rank_set) + + def get_collective_ops_name(self, rank_id: int, comm_op_dict: dict): + for comm_op in comm_op_dict: + if comm_op.startswith('Total'): + continue + group_name = comm_op.split('@')[-1] + self.collective_group_dict[group_name].add(rank_id) + + def add_communication_ops(self, rank_id: str, step_id: str, comm_op_type: str, comm_op_dict: dict): + for comm_op in comm_op_dict: + if comm_op.startswith('Total'): + continue + group_name = comm_op.split('@')[-1] + self.communication_ops.append({ + Constant.RANK_ID: rank_id, + Constant.STEP_ID: step_id, + Constant.COMM_OP_TYPE: comm_op_type, + Constant.COMM_OP_NAME: comm_op, + Constant.GROUP_NAME: group_name, + Constant.COMM_OP_INFO: comm_op_dict.get(comm_op) + }) + + def add_matrix_ops(self, rank_id: int, step_id: str, step_id_dict: dict): + for comm_op_type, comm_dict in step_id_dict.items(): + if comm_op_type != Constant.COLLECTIVE and comm_op_type != Constant.P2P: + print(f"[WARNING] Unknown communication operators type!") + continue + for op_name, op_link_info in comm_dict.items(): + if op_name.startswith('Total'): + continue + group_name = op_name.split('@')[-1] + self.matrix_ops.append({ + Constant.RANK_ID: rank_id, + 
Constant.STEP_ID: step_id, + Constant.COMM_OP_TYPE: comm_op_type, + Constant.COMM_OP_NAME: op_name, + Constant.GROUP_NAME: group_name, + Constant.COMM_OP_INFO: op_link_info + }) diff --git a/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py index cc4a5fecbf..3c4b18d3af 100644 --- a/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py +++ b/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py @@ -18,6 +18,8 @@ class CommunicationBandwidthInfo: """ The class represents a Communication Bandwidth Info object """ + step: str + type: str hccl_op_name: str group_name: str transport_type: str diff --git a/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py b/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py index 95b174e859..97794cb69b 100644 --- a/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py +++ b/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py @@ -18,6 +18,8 @@ class CommunicationMatrix: """ The class represents a Communication Matrix object """ + step: str + type: str hccl_op_name: str group_name: str src_rank: str diff --git a/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py index b1a8e81eae..81f514acdd 100644 --- a/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py +++ b/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py @@ -18,6 +18,8 @@ class CommunicationTimeInfo: """ The class represents a Communication Time Info object """ + step: str + type: str hccl_op_name: str group_name: str start_timestamp: float -- Gitee From 74eaddd7c7c330e8d47aa0dd1cc92fc1ff793b67 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Mon, 26 Feb 2024 14:21:58 +0800 Subject: [PATCH 4/7] add db processor --- .../analysis/analysis_facade.py | 4 +- 
.../cluster_analyse/analysis/base_analysis.py | 32 ++++- .../analysis/communication/__init__.py | 0 .../communication/communication_analysis.py | 32 +++++ .../communication_analysis_db.py | 24 ++++ .../communication_analysis_json.py} | 131 +----------------- .../analysis/communication_matrix/__init__.py | 0 .../comm_matrix_analysis.py | 23 +++ .../comm_matrix_analysis_db.py | 25 ++++ .../comm_matrix_analysis_json.py | 98 +++++++++++++ .../analysis/step_trace/__init__.py | 0 profiler/cluster_analyse/cluster_analysis.py | 3 +- .../cluster_analyse/common_func/DBManager.py | 16 +++ .../communication_db_group.py | 22 ++- .../communication_group_generator.py | 2 +- .../communication_json_group.py | 4 +- 16 files changed, 277 insertions(+), 139 deletions(-) create mode 100644 profiler/cluster_analyse/analysis/communication/__init__.py create mode 100644 profiler/cluster_analyse/analysis/communication/communication_analysis.py create mode 100644 profiler/cluster_analyse/analysis/communication/communication_analysis_db.py rename profiler/cluster_analyse/analysis/{communication_analysis.py => communication/communication_analysis_json.py} (45%) create mode 100644 profiler/cluster_analyse/analysis/communication_matrix/__init__.py create mode 100644 profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py create mode 100644 profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py create mode 100644 profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py create mode 100644 profiler/cluster_analyse/analysis/step_trace/__init__.py diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index 34228f97a2..d521e4c00b 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -14,9 +14,9 @@ # limitations under the License. 
from multiprocessing import Process -from analysis.communication_analysis import CommunicationAnalysis +from analysis.communication.communication_analysis import CommunicationAnalysis +from analysis.communication_matrix.comm_matrix_analysis import CommMatrixAnalysis from analysis.step_trace_time_analysis import StepTraceTimeAnalysis -from analysis.communication_analysis import CommMatrixAnalysis class AnalysisFacade: diff --git a/profiler/cluster_analyse/analysis/base_analysis.py b/profiler/cluster_analyse/analysis/base_analysis.py index 8f0f6eaea3..9248ae328d 100644 --- a/profiler/cluster_analyse/analysis/base_analysis.py +++ b/profiler/cluster_analyse/analysis/base_analysis.py @@ -15,6 +15,7 @@ from abc import abstractmethod from common_func.constant import Constant +from common_func.file_manager import FileManager class BaseAnalysis: @@ -22,7 +23,9 @@ class BaseAnalysis: def __init__(self, param: dict): self.collection_path = param.get(Constant.COLLECTION_PATH) self.data_map = param.get(Constant.DATA_MAP) - self.communication_group = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COMMUNICATION_GROUP) + self.communication_ops = [] + self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP) + self.comm_ops_struct = {} @staticmethod def compute_ratio(dividend: float, divisor: float): @@ -34,3 +37,30 @@ class BaseAnalysis: @abstractmethod def run(self): pass + + def dump_data(self): + if not self.comm_ops_struct: + print("[WARNING] There is no final comm ops data generated") + return + output_comm_data = {} + for key in self.comm_ops_struct: + output_comm_data[str(key)] = self.comm_ops_struct.get(key) + FileManager.create_json_file(self.collection_path, output_comm_data, self.SAVED_JSON) + + def split_op_by_group(self): + for single_op in self.communication_ops: + if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P: + rank_tup = Constant.P2P + else: + rank_tup = 
tuple(self.collective_group_dict.get(single_op.get(Constant.GROUP_NAME), [])) + rank_id = single_op.get(Constant.RANK_ID, 'N/A') + step_id = single_op.get(Constant.STEP_ID, 'N/A') + op_name = single_op.get(Constant.COMM_OP_NAME, 'N/A') + op_info = single_op.get(Constant.COMM_OP_INFO) + self.comm_ops_struct.setdefault(rank_tup, {}).setdefault(step_id, {}).\ + setdefault(op_name, {}).setdefault(rank_id, op_info) + + def combine_ops_total_info(self): + for rank_tup, group_dict in self.comm_ops_struct.items(): + for step_id, communication_ops in group_dict.items(): + self.compute_total_info(communication_ops) diff --git a/profiler/cluster_analyse/analysis/communication/__init__.py b/profiler/cluster_analyse/analysis/communication/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis.py b/profiler/cluster_analyse/analysis/communication/communication_analysis.py new file mode 100644 index 0000000000..a8a61785b2 --- /dev/null +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis.py @@ -0,0 +1,32 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from analysis.communication.communication_analysis_db import CommunicationAnalysisDB +from analysis.communication.communication_analysis_json import CommunicationAnalysisJson +from common_func.constant import Constant + + +class CommunicationAnalysis: + + GROUP_MAP = { + Constant.DB: CommunicationAnalysisDB, + Constant.TEXT: CommunicationAnalysisJson + } + + def __init__(self, param: dict): + self.generator = self.GROUP_MAP[param[Constant.DATA_TYPE]](param) + + def run(self): + self.generator.run() diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py new file mode 100644 index 0000000000..17e91f2122 --- /dev/null +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from analysis.base_analysis import BaseAnalysis + + +class CommunicationAnalysisDB(BaseAnalysis): + def run(self): + pass + + def __init__(self, param: any): + super().__init__(param) diff --git a/profiler/cluster_analyse/analysis/communication_analysis.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py similarity index 45% rename from profiler/cluster_analyse/analysis/communication_analysis.py rename to profiler/cluster_analyse/analysis/communication/communication_analysis_json.py index a3c51d46a9..e22e80950c 100644 --- a/profiler/cluster_analyse/analysis/communication_analysis.py +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py @@ -14,61 +14,12 @@ # limitations under the License. from collections import defaultdict -from abc import abstractmethod +from analysis.base_analysis import BaseAnalysis from common_func.constant import Constant -from common_func.file_manager import FileManager -class BaseCommAnalysis: - - def __init__(self, param: dict): - self.collection_path = param.get(Constant.COLLECTION_PATH) - self.data_map = param.get(Constant.DATA_MAP) - self.communication_ops = [] - self.collective_group_dict = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.COLLECTIVE_GROUP) - self.comm_ops_struct = {} - - @staticmethod - def compute_ratio(dividend: float, divisor: float): - if abs(divisor) < Constant.EPS: - return 0 - else: - return round(dividend / divisor, 4) - - @abstractmethod - def run(self): - pass - - def dump_data(self): - if not self.comm_ops_struct: - print("[WARNING] There is no final comm ops data generated") - return - output_comm_data = {} - for key in self.comm_ops_struct: - output_comm_data[str(key)] = self.comm_ops_struct.get(key) - FileManager.create_json_file(self.collection_path, output_comm_data, self.SAVED_JSON) - - def split_op_by_group(self): - for single_op in self.communication_ops: - if single_op.get(Constant.COMM_OP_TYPE) == Constant.P2P: - rank_tup = 
Constant.P2P - else: - rank_tup = tuple(self.collective_group_dict.get(single_op.get(Constant.GROUP_NAME), [])) - rank_id = single_op.get(Constant.RANK_ID, 'N/A') - step_id = single_op.get(Constant.STEP_ID, 'N/A') - op_name = single_op.get(Constant.COMM_OP_NAME, 'N/A') - op_info = single_op.get(Constant.COMM_OP_INFO) - self.comm_ops_struct.setdefault(rank_tup, {}).setdefault(step_id, {}).\ - setdefault(op_name, {}).setdefault(rank_id, op_info) - - def combine_ops_total_info(self): - for rank_tup, group_dict in self.comm_ops_struct.items(): - for step_id, communication_ops in group_dict.items(): - self.compute_total_info(communication_ops) - - -class CommunicationAnalysis(BaseCommAnalysis): +class CommunicationAnalysisJson(BaseAnalysis): SAVED_JSON = "cluster_communication.json" def __init__(self, param: dict): @@ -143,81 +94,3 @@ class CommunicationAnalysis(BaseCommAnalysis): self.compute_ratio(bandwidth_dict.get(Constant.TRANSIT_SIZE_MB, 0), bandwidth_dict.get(Constant.TRANSIT_TIME_MS, 0)) - -class CommMatrixAnalysis(BaseCommAnalysis): - SAVED_JSON = "cluster_communication_matrix.json" - - def __init__(self, param: dict): - super().__init__(param) - self.communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS) - - @staticmethod - def combine_link(link_info_dict: dict, single_link_dict: dict): - link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE) - link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0) - link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0) - - def run(self): - self.split_op_by_group() - self.combine_ops_total_info() - self.dump_data() - - def compute_total_info(self, step_dict: dict): - self.merge_same_links(step_dict) - self.combine_link_info(step_dict) - - def merge_same_links(self, step_dict: dict): - def process_link_key(): - for link_key in rank_dict: - if '-' not in link_key: - print(f"[WARNING] 
{op_name} has an invalid link key {link_key}!") - break - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - if src_rank == dst_rank: - if src_rank not in project_local_global_rank_map: - project_local_global_rank_map[src_rank] = rank_id - elif project_local_global_rank_map.get(src_rank) != rank_id: - print(f"[WARNING] In the same communication group, local ranks projecting to global ranks repeat!") - self.combine_link(link_info[link_key], rank_dict[link_key]) - - def convert_local_to_global_rank(): - tmp_link = {} - for link_key, link_dict in link_info.items(): - src_rank = link_key.split('-')[0] - dst_rank = link_key.split('-')[1] - src_rank = project_local_global_rank_map[src_rank] \ - if src_rank in project_local_global_rank_map else src_rank - dst_rank = project_local_global_rank_map[dst_rank] \ - if dst_rank in project_local_global_rank_map else dst_rank - link_dict[Constant.BANDWIDTH_GB_S] = \ - self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), - link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - tmp_link[f"{src_rank}-{dst_rank}"] = link_dict - return tmp_link - - project_local_global_rank_map = dict() - for op_name, op_dict in step_dict.items(): - link_info = defaultdict(lambda: { - Constant.TRANSPORT_TYPE: '', - Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0 - }) - for rank_id, rank_dict in op_dict.items(): - process_link_key() - step_dict[op_name] = convert_local_to_global_rank() - - def combine_link_info(self, step_dict: dict): - total_op_info = defaultdict(lambda: { - Constant.TRANSPORT_TYPE: '', - Constant.TRANSIT_TIME_MS: 0, - Constant.TRANSIT_SIZE_MB: 0 - }) - for op_name, op_dict in step_dict.items(): - for link_key, link_dict in op_dict.items(): - self.combine_link(total_op_info[link_key], link_dict) - for link_key, link_dict in total_op_info.items(): - link_dict[Constant.BANDWIDTH_GB_S] = \ - self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), - link_dict.get(Constant.TRANSIT_TIME_MS, 0)) - 
step_dict[Constant.TOTAL_OP_INFO] = total_op_info
diff --git a/profiler/cluster_analyse/analysis/communication_matrix/__init__.py b/profiler/cluster_analyse/analysis/communication_matrix/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py
new file mode 100644
index 0000000000..1244feae8a
--- /dev/null
+++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py
@@ -0,0 +1,32 @@
+# Copyright (c) 2023, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from analysis.communication_matrix.comm_matrix_analysis_db import CommMatrixAnalysisDB
+from analysis.communication_matrix.comm_matrix_analysis_json import CommMatrixAnalysisJson
+from common_func.constant import Constant
+
+
+class CommMatrixAnalysis:
+
+    GROUP_MAP = {
+        Constant.DB: CommMatrixAnalysisDB,
+        Constant.TEXT: CommMatrixAnalysisJson
+    }
+
+    def __init__(self, param: dict):
+        self.generator = self.GROUP_MAP[param[Constant.DATA_TYPE]](param)
+
+    def run(self):
+        self.generator.run()
diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
new file mode 100644
index 0000000000..ef076202cc
--- /dev/null
+++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2023, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from analysis.base_analysis import BaseAnalysis + + +class CommMatrixAnalysisDB(BaseAnalysis): + + def run(self): + pass + + def __init__(self, param: any): + super().__init__(param) diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py new file mode 100644 index 0000000000..b374775ec2 --- /dev/null +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py @@ -0,0 +1,98 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from collections import defaultdict + +from analysis.base_analysis import BaseAnalysis +from common_func.constant import Constant + + +class CommMatrixAnalysisJson(BaseAnalysis): + SAVED_JSON = "cluster_communication_matrix.json" + + def __init__(self, param: dict): + super().__init__(param) + self.communication_ops = param.get(Constant.COMM_DATA_DICT, {}).get(Constant.MATRIX_OPS) + + @staticmethod + def combine_link(link_info_dict: dict, single_link_dict: dict): + link_info_dict[Constant.TRANSPORT_TYPE] = single_link_dict.get(Constant.TRANSPORT_TYPE) + link_info_dict[Constant.TRANSIT_TIME_MS] += single_link_dict.get(Constant.TRANSIT_TIME_MS, 0) + link_info_dict[Constant.TRANSIT_SIZE_MB] += single_link_dict.get(Constant.TRANSIT_SIZE_MB, 0) + + def run(self): + self.split_op_by_group() + self.combine_ops_total_info() + self.dump_data() + + def compute_total_info(self, step_dict: dict): + self.merge_same_links(step_dict) + self.combine_link_info(step_dict) + + def merge_same_links(self, step_dict: dict): + def process_link_key(): + for link_key in rank_dict: + if '-' not in link_key: + print(f"[WARNING] {op_name} has an invalid link key {link_key}!") + break + src_rank = link_key.split('-')[0] + dst_rank = link_key.split('-')[1] + if src_rank == dst_rank: + if src_rank not in project_local_global_rank_map: + project_local_global_rank_map[src_rank] = rank_id + elif project_local_global_rank_map.get(src_rank) != rank_id: + print(f"[WARNING] In the same communication group, local ranks projecting to global ranks " + f"repeat!") + self.combine_link(link_info[link_key], rank_dict[link_key]) + + def convert_local_to_global_rank(): + tmp_link = {} + for link_key, link_dict in link_info.items(): + src_rank = link_key.split('-')[0] + dst_rank = link_key.split('-')[1] + src_rank = project_local_global_rank_map[src_rank] \ + if src_rank in project_local_global_rank_map else src_rank + dst_rank = project_local_global_rank_map[dst_rank] \ + if dst_rank in 
project_local_global_rank_map else dst_rank + link_dict[Constant.BANDWIDTH_GB_S] = \ + self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), + link_dict.get(Constant.TRANSIT_TIME_MS, 0)) + tmp_link[f"{src_rank}-{dst_rank}"] = link_dict + return tmp_link + + project_local_global_rank_map = dict() + for op_name, op_dict in step_dict.items(): + link_info = defaultdict(lambda: { + Constant.TRANSPORT_TYPE: '', + Constant.TRANSIT_TIME_MS: 0, + Constant.TRANSIT_SIZE_MB: 0 + }) + for rank_id, rank_dict in op_dict.items(): + process_link_key() + step_dict[op_name] = convert_local_to_global_rank() + + def combine_link_info(self, step_dict: dict): + total_op_info = defaultdict(lambda: { + Constant.TRANSPORT_TYPE: '', + Constant.TRANSIT_TIME_MS: 0, + Constant.TRANSIT_SIZE_MB: 0 + }) + for op_name, op_dict in step_dict.items(): + for link_key, link_dict in op_dict.items(): + self.combine_link(total_op_info[link_key], link_dict) + for link_key, link_dict in total_op_info.items(): + link_dict[Constant.BANDWIDTH_GB_S] = \ + self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), + link_dict.get(Constant.TRANSIT_TIME_MS, 0)) + step_dict[Constant.TOTAL_OP_INFO] = total_op_info diff --git a/profiler/cluster_analyse/analysis/step_trace/__init__.py b/profiler/cluster_analyse/analysis/step_trace/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index 604cc4ff2a..db932a9ed9 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -90,7 +90,8 @@ class Interface: params = { Constant.COLLECTION_PATH: self.collection_path, Constant.DATA_MAP: data_map, - Constant.COMM_DATA_DICT: comm_data_dict + Constant.COMM_DATA_DICT: comm_data_dict, + Constant.DATA_TYPE: data_type } AnalysisFacade(params).cluster_analyze() diff --git a/profiler/cluster_analyse/common_func/DBManager.py 
b/profiler/cluster_analyse/common_func/DBManager.py index 24a66b992b..f8343f4f44 100644 --- a/profiler/cluster_analyse/common_func/DBManager.py +++ b/profiler/cluster_analyse/common_func/DBManager.py @@ -138,6 +138,22 @@ class DBManager: print("[ERROR] conn is invalid param") return False + @staticmethod + def executemany_sql(conn: any, sql: str, params: any) -> bool: + """ + execute many sql once + """ + try: + if isinstance(conn, sqlite3.Connection): + conn.cursor().executemany(sql, params) + conn.commit() + return True + except sqlite3.Error as err: + print(f"[ERROR] {err}") + return False + print("[ERROR] conn is invalid param") + return False + @classmethod def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None, dto_class: any = None) -> list: """ diff --git a/profiler/cluster_analyse/communication_group/communication_db_group.py b/profiler/cluster_analyse/communication_group/communication_db_group.py index f3754251b8..93fbabdea5 100644 --- a/profiler/cluster_analyse/communication_group/communication_db_group.py +++ b/profiler/cluster_analyse/communication_group/communication_db_group.py @@ -28,6 +28,9 @@ class CommunicationDBGroup(BaseCommunicationGroup): def __init__(self, collection_path: str, data_map: dict, data_type: str): super().__init__(collection_path, data_map, data_type) + self.communication_bandwidth_info = [] + self.communication_time_info = [] + self.matrix_info = [] def read_communication_func(self, params: tuple): if len(params) < 3: @@ -66,13 +69,22 @@ class CommunicationDBGroup(BaseCommunicationGroup): conn, cursor = DBManager.create_connect_db(result_db) for data_type, data_list in self.communication_group.items(): for data in data_list: - data = [data_type, data] + rank_set = "(" + ",".join(str(i) for i in data) + ")" + data = [data_type, rank_set] res.append(data) if res: sql = "insert into {} values ({value})".format(self.COMMUNICATION_GROUP_TABLE, value="?," * (len(res[0]) - 1) + "?") - DBManager.execute_sql(conn, sql, 
res) + DBManager.executemany_sql(conn, sql, res) DBManager.destroy_db_connect(conn, cursor) + comm_data_dict = { + Constant.COLLECTIVE_GROUP: self.collective_group_dict, + Constant.COMMUNICATION_TIME_INFO: self.communication_time_info, + Constant.COMMUNICATION_BANDWIDTH_INFO: self.communication_bandwidth_info, + Constant.MATRIX_OPS: self.matrix_info, + Constant.COMMUNICATION_GROUP: self.communication_group + } + return comm_data_dict def analyze_communication_data(self): for rank_id, time_data, bandwidth_data, matrix_data in self.rank_comm_dir_dict: @@ -82,8 +94,12 @@ class CommunicationDBGroup(BaseCommunicationGroup): if data.type == Constant.COLLECTIVE: self.collective_group_dict[data.group_name].add(rank_id) setattr(data, "rank_id", rank_id) + setattr(data, "step", step) + self.communication_time_info.append(data) for data in bandwidth_data[step]: setattr(data, "rank_id", rank_id) + setattr(data, "step", step) + self.communication_bandwidth_info.append(data) def add_p2p_and_rank(self, rank_id: int, step: str, data_dict: dict): data_list = data_dict[step] @@ -92,6 +108,8 @@ class CommunicationDBGroup(BaseCommunicationGroup): print(f"[WARNING] Unknown communication operators type!") continue setattr(data, "rank_id", rank_id) + setattr(data, "step", step) + self.matrix_info.append(data) if data.type == Constant.P2P: if data.src_rank != data.dst_rank: rank_set = {data.src_rank, data.dst_rank} diff --git a/profiler/cluster_analyse/communication_group/communication_group_generator.py b/profiler/cluster_analyse/communication_group/communication_group_generator.py index 1cec39c11c..eac377f470 100644 --- a/profiler/cluster_analyse/communication_group/communication_group_generator.py +++ b/profiler/cluster_analyse/communication_group/communication_group_generator.py @@ -19,7 +19,7 @@ from communication_group.communication_json_group import CommunicationJsonGroup class CommunicationGroupGenerator: - COMMUNICATION_GROUP_JSON = "communication_group.json" + GROUP_MAP = { 
Constant.DB: CommunicationDBGroup, Constant.TEXT: CommunicationJsonGroup diff --git a/profiler/cluster_analyse/communication_group/communication_json_group.py b/profiler/cluster_analyse/communication_group/communication_json_group.py index 7a6d3df712..1b725add26 100644 --- a/profiler/cluster_analyse/communication_group/communication_json_group.py +++ b/profiler/cluster_analyse/communication_group/communication_json_group.py @@ -14,11 +14,10 @@ # limitations under the License. import os -from collections import defaultdict from common_func.constant import Constant from common_func.file_manager import FileManager -from communication_group.base_communication_group import BaseCommunicationGroup, UnionFind +from communication_group.base_communication_group import BaseCommunicationGroup class CommunicationJsonGroup(BaseCommunicationGroup): @@ -26,7 +25,6 @@ class CommunicationJsonGroup(BaseCommunicationGroup): def __init__(self, collection_path: str, data_map: dict, data_type: str): super().__init__(collection_path, data_map, data_type) - self.p2p_group_dict = defaultdict(list) self.communication_ops = [] self.matrix_ops = [] -- Gitee From 54a5b6101aefea525ca102e74f4cff94aa5a6187 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Wed, 28 Feb 2024 11:01:26 +0800 Subject: [PATCH 5/7] add table config --- .../cluster_analyse/common_func/tables_config.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/profiler/cluster_analyse/common_func/tables_config.py b/profiler/cluster_analyse/common_func/tables_config.py index 851b224d56..8331dbaf90 100644 --- a/profiler/cluster_analyse/common_func/tables_config.py +++ b/profiler/cluster_analyse/common_func/tables_config.py @@ -62,5 +62,18 @@ class TablesConfig: ("bandwidth", "NUMERIC, null"), ("transport_type", "TEXT, null"), ("op_name", "TEXT, null") + ], + "ClusterStepTraceTimeMap": [ + ("step", "TEXT, null"), + ("type", "TEXT, null"), + ("index", "TEXT, null"), + ("computing", "NUMERIC, null"), + 
("communication_not_overlapped", "NUMERIC, null"), + ("overlapped", "NUMERIC, null"), + ("communication", "NUMERIC, null"), + ("free", "NUMERIC, null"), + ("stage", "NUMERIC, null"), + ("bubble", "NUMERIC, null"), + ("communication_not_overlapped_and_exclude_receive", "NUMERIC, null") ] } -- Gitee From 65a7569c80cc21ec0651b32ca9c83949e9a8b012 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Wed, 28 Feb 2024 20:24:35 +0800 Subject: [PATCH 6/7] bugfix --- .../analysis/analysis_facade.py | 6 +- ...json_analysis.py => base_analysis_json.py} | 15 +- ...analysis.py => comm_analysis_generator.py} | 6 +- .../communication_analysis_db.py | 177 ++++++++++-------- .../communication_analysis_json.py | 4 +- .../comm_matrix_analysis_db.py | 120 ++++++------ .../comm_matrix_analysis_json.py | 16 +- ...x_analysis.py => comm_matrix_generator.py} | 6 +- .../analysis/step_trace/__init__.py | 0 .../analysis/step_trace_time_analysis.py | 16 +- profiler/cluster_analyse/cluster_analysis.py | 1 + .../cluster_analyse/common_func/constant.py | 16 +- .../{DBManager.py => db_manager.py} | 100 +++------- .../common_func/empty_class.py | 10 - .../common_func/file_manager.py | 16 +- .../common_func/table_constant.py | 27 +++ .../communication_db_group.py | 50 +++-- .../communication_bandwidth_info_bean.py | 34 ---- .../prof_bean/communication_matrix_bean.py | 33 ---- .../prof_bean/communication_time_info_bean.py | 34 ---- 20 files changed, 285 insertions(+), 402 deletions(-) rename profiler/cluster_analyse/analysis/{base_json_analysis.py => base_analysis_json.py} (86%) rename profiler/cluster_analyse/analysis/communication/{communication_analysis.py => comm_analysis_generator.py} (90%) rename profiler/cluster_analyse/analysis/communication_matrix/{comm_matrix_analysis.py => comm_matrix_generator.py} (90%) delete mode 100644 profiler/cluster_analyse/analysis/step_trace/__init__.py rename profiler/cluster_analyse/common_func/{DBManager.py => db_manager.py} (65%) create mode 100644 
profiler/cluster_analyse/common_func/table_constant.py delete mode 100644 profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py delete mode 100644 profiler/cluster_analyse/prof_bean/communication_matrix_bean.py delete mode 100644 profiler/cluster_analyse/prof_bean/communication_time_info_bean.py diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index 60acc4ccd9..0b870bbaaf 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -14,13 +14,13 @@ # limitations under the License. from multiprocessing import Process -from analysis.communication.communication_analysis import CommunicationAnalysis -from analysis.communication_matrix.comm_matrix_analysis import CommMatrixAnalysis +from analysis.communication.comm_analysis_generator import CommunicationAnalysisGenerator +from analysis.communication_matrix.comm_matrix_generator import CommMatrixAnalysisGenerator from analysis.step_trace_time_analysis import StepTraceTimeAnalysis class AnalysisFacade: - analysis_module = {CommunicationAnalysis, StepTraceTimeAnalysis, CommMatrixAnalysis} + analysis_module = {CommunicationAnalysisGenerator, StepTraceTimeAnalysis, CommMatrixAnalysisGenerator} def __init__(self, params: dict): self.params = params diff --git a/profiler/cluster_analyse/analysis/base_json_analysis.py b/profiler/cluster_analyse/analysis/base_analysis_json.py similarity index 86% rename from profiler/cluster_analyse/analysis/base_json_analysis.py rename to profiler/cluster_analyse/analysis/base_analysis_json.py index 64f26d965b..be8e42d3d5 100644 --- a/profiler/cluster_analyse/analysis/base_json_analysis.py +++ b/profiler/cluster_analyse/analysis/base_analysis_json.py @@ -18,7 +18,7 @@ from common_func.constant import Constant from common_func.file_manager import FileManager -class BaseJsonAnalysis: +class BaseAnalysisJson: def __init__(self, param: dict): 
self.collection_path = param.get(Constant.COLLECTION_PATH) @@ -34,6 +34,19 @@ class BaseJsonAnalysis: else: return round(dividend / divisor, 4) + @staticmethod + def check_add_op(op_name: str): + """ + 兼容2个版本,判断是否需要将此算子信息相加 + """ + stat_list = ["middle", "top", "bottom", "total"] + total = "total" + for stat_name in stat_list: + if stat_name in op_name: + if stat_name != total: + return False + return True + @abstractmethod def run(self): pass diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis.py b/profiler/cluster_analyse/analysis/communication/comm_analysis_generator.py similarity index 90% rename from profiler/cluster_analyse/analysis/communication/communication_analysis.py rename to profiler/cluster_analyse/analysis/communication/comm_analysis_generator.py index 70d42d108f..2e727d9f2b 100644 --- a/profiler/cluster_analyse/analysis/communication/communication_analysis.py +++ b/profiler/cluster_analyse/analysis/communication/comm_analysis_generator.py @@ -14,15 +14,15 @@ # limitations under the License. 
from analysis.communication.communication_analysis_db import CommunicationAnalysisDB -from analysis.communication.communication_analysis_json import CommunicationJsonAnalysisJson +from analysis.communication.communication_analysis_json import CommunicationJsonAnalysis from common_func.constant import Constant -class CommunicationAnalysis: +class CommunicationAnalysisGenerator: GROUP_MAP = { Constant.DB: CommunicationAnalysisDB, - Constant.TEXT: CommunicationJsonAnalysisJson + Constant.TEXT: CommunicationJsonAnalysis } def __init__(self, params: dict): diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py index df393a88d0..92a4d8a09d 100644 --- a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py @@ -14,11 +14,10 @@ # limitations under the License. import os -from analysis.base_json_analysis import BaseJsonAnalysis -from common_func.DBManager import CustomizedNamedtupleFactory, DBManager +from analysis.base_analysis_json import BaseAnalysisJson +from common_func.db_manager import DBManager from common_func.constant import Constant -from prof_bean.communication_bandwidth_info_bean import CommunicationBandwidthInfo -from prof_bean.communication_time_info_bean import CommunicationTimeInfo +from common_func.table_constant import TableConstant class CommunicationAnalysisDB: @@ -39,8 +38,11 @@ class CommunicationAnalysisDB: self.res_comm_bandwidth = [] def run(self): - self.split_and_add_rank_set() - self.combine_total_info() + if not self.communication_time_info and not self.communication_bandwidth_info: + return + self.split_and_add_rank_set(self.communication_time_info, self.comm_time_struct) + self.split_and_add_rank_set(self.communication_bandwidth_info, self.comm_bandwidth_struct) + self.compute_total_info() self.dump_data() def dump_data(self): 
@@ -50,119 +52,128 @@ class CommunicationAnalysisDB: res_time, res_bandwidth = [], [] conn, cursor = DBManager.create_connect_db(result_db) for data in self.res_comm_time: - res_time.append(list(data[:3] + data[4:])) - for data in self.res_comm_bandwidth: - res_bandwidth.append(list(data[:3] + data[4:])) + res_time.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.RANK_ID], + data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME], + data[TableConstant.START_TIMESTAMP], data[TableConstant.ELAPSED_TIME], + data[TableConstant.TRANSIT_TIME], data[TableConstant.WAIT_TIME], + data[TableConstant.SYNCHRONIZATION_TIME], data[TableConstant.IDLE_TIME], + data[TableConstant.SYNCHRONIZATION_TIME_RATIO], data[TableConstant.WAIT_TIME_RATIO]]) if res_time: sql = "insert into {} values ({value})".format(self.COMMUNICATION_TIME_TABLE, value="?," * (len(res_time[0]) - 1) + "?") DBManager.executemany_sql(conn, sql, res_time) + for data in self.res_comm_bandwidth: + res_bandwidth.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.RANK_ID], + data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME], + data[TableConstant.BAND_TYPE], data[TableConstant.TRANSIT_SIZE], + data[TableConstant.TRANSIT_TIME], data[TableConstant.BANDWIDTH], + data[TableConstant.LARGE_PACKET_RATIO], data[TableConstant.PACKAGE_SIZE], + data[TableConstant.COUNT], data[TableConstant.TOTAL_DURATION]]) if res_bandwidth: sql = "insert into {} values ({value})".format(self.COMMUNICATION_BANDWIDTH_TABLE, - value="?," * (len(res_time[0]) - 1) + "?") - DBManager.executemany_sql(conn, sql, res_time) + value="?," * (len(res_bandwidth[0]) - 1) + "?") + DBManager.executemany_sql(conn, sql, res_bandwidth) DBManager.destroy_db_connect(conn, cursor) - def split_and_add_rank_set(self): - for data in self.communication_time_info: - if data.type == Constant.P2P: - rank_tuple = Constant.P2P - else: - rank_tuple = 
self.collective_group_dict.get(data.group_name) - self.comm_time_struct.setdefault(rank_tuple, {}).setdefault(data.step, []).append(data) - for data in self.communication_bandwidth_info: - if data.type == Constant.P2P: + def split_and_add_rank_set(self, data_list, res_dict): + for data in data_list: + if data[TableConstant.TYPE] == Constant.P2P: rank_tuple = Constant.P2P else: - rank_tuple = self.collective_group_dict.get(data.group_name) - self.comm_bandwidth_struct.setdefault(rank_tuple, {}).setdefault(data.step, []).append(data) + rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME])) + res_dict.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], []).append(data) - def combine_total_info(self): - for rank_tuple, op_dict in self.comm_time_struct: + def compute_total_info(self): + for rank_tuple, op_dict in self.comm_time_struct.items(): if rank_tuple != Constant.P2P: for step, data_list in op_dict.items(): - self.compute_rank_set_total_time_info(data_list, rank_tuple, step) + self.compute_rank_set_total_time_info(data_list, rank_tuple) else: rank_set = set() for step, data_list in op_dict.items(): - rank_set.add(data.rank_id for data in data_list) + rank_set.add(data[TableConstant.RANK_ID] for data in data_list) for step, data_list in op_dict.items(): - self.compute_rank_set_total_time_info(data_list, rank_set, step, True) - for rank_tuple, op_dict in self.comm_bandwidth_struct: + self.compute_rank_set_total_time_info(data_list, rank_set, True) + for rank_tuple, op_dict in self.comm_bandwidth_struct.items(): for step, data_list in op_dict.items(): if rank_tuple != Constant.P2P: - self.compute_rank_set_total_bandwidth_info(data_list, rank_tuple, step) + self.compute_rank_set_total_bandwidth_info(data_list, rank_tuple) else: - self.compute_rank_set_total_bandwidth_info(data_list, rank_tuple, step, True) + self.compute_rank_set_total_bandwidth_info(data_list, rank_tuple, True) - def compute_rank_set_total_bandwidth_info(self, 
data_list, rank_tuple, step, is_p2p=False): + def compute_rank_set_total_bandwidth_info(self, data_list, rank_tuple, is_p2p=False): if not data_list: return data_dict = {} - total_list = [] rank_tuple = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P for data in data_list: - data = data.replace(rank_set=rank_tuple) - rank_band_type = self.RANK_BAND_TYPE.format(data.rank_id, data.band_type) + data[TableConstant.RANK_SET] = rank_tuple + rank_band_type = self.RANK_BAND_TYPE.format(data[TableConstant.RANK_ID], data[TableConstant.BAND_TYPE]) data_dict.setdefault(rank_band_type, []).append(data) - self.res_comm_bandwidth.extend(data_list) + self.res_comm_bandwidth.append(data) for rank_band_type, bandwidth_list in data_dict.items(): package_set = set() for data in bandwidth_list: - package_set.add(data.package_size) + package_set.add(data[TableConstant.PACKAGE_SIZE]) for package in package_set: - transit_time, transit_size, count, total_duration = 0.0, 0.0, 0.0, 0.0 - total_comm_bandwidth_info = \ - CustomizedNamedtupleFactory.generate_named_tuple_from_dto(CommunicationBandwidthInfo) - total_comm_bandwidth_info = total_comm_bandwidth_info.replace(rankset=rank_tuple) + total_comm_bandwidth_info = dict() for data in bandwidth_list: - transit_time += data.transit_time - transit_size += data.transit_size - if data.package_size == package: - count += data.count - total_duration += data.total_duration - total_comm_bandwidth_info = total_comm_bandwidth_info.replace(package_size=package, - step=step, - transit_time=transit_time, - transit_size=transit_size, - count=count, - hccl_op_name=Constant.TOTAL_OP_INFO, - total_duration=total_duration, - band_type=bandwidth_list[0].band_type, - rank_id=bandwidth_list[0].rank_id) - bandwidth = BaseJsonAnalysis.compute_ratio(total_comm_bandwidth_info.transit_size, - total_comm_bandwidth_info.transit_time) - total_comm_bandwidth_info = total_comm_bandwidth_info.replace(bandwidth=bandwidth) - 
total_list.append(total_comm_bandwidth_info) - self.res_comm_bandwidth.extend(total_list) + self.compute_bandwidth(total_comm_bandwidth_info, data, package) + bandwidth = BaseAnalysisJson.compute_ratio(TableConstant.TRANSIT_SIZE, TableConstant.TRANSIT_TIME) + total_comm_bandwidth_info[TableConstant.BANDWIDTH] = bandwidth + total_comm_bandwidth_info[TableConstant.PACKAGE_SIZE] = package + total_comm_bandwidth_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO + total_comm_bandwidth_info[TableConstant.GROUP_NAME] = "" + total_comm_bandwidth_info[TableConstant.LARGE_PACKET_RATIO] = 0.0 + self.res_comm_bandwidth.append(total_comm_bandwidth_info) + + def compute_bandwidth(self, res_dict, data_dict, package): + for key in data_dict.keys(): + if key in [TableConstant.TRANSIT_TIME, TableConstant.TRANSIT_SIZE]: + if key not in res_dict.keys(): + res_dict[key] = 0.0 + res_dict[key] += data_dict[key] + elif key in [TableConstant.COUNT, TableConstant.TOTAL_DURATION]: + if data_dict[TableConstant.PACKAGE_SIZE] == package: + if key not in res_dict.keys(): + res_dict[key] = 0.0 + res_dict[key] += data_dict[key] + else: + res_dict[key] = 0.0 + else: + res_dict[key] = data_dict[key] + + def compute_time(self, res_dict, data_dict, dict_key): + if dict_key.endswith(self.TIME_EXTENSION): + if dict_key not in res_dict.keys(): + res_dict[dict_key] = 0.0 + res_dict[dict_key] += data_dict[dict_key] + else: + res_dict[dict_key] = data_dict[dict_key] - def compute_rank_set_total_time_info(self, data_list: list, rank_tuple: any, step: str, is_p2p: bool = False): + def compute_rank_set_total_time_info(self, data_list: list, rank_tuple: any, is_p2p: bool = False): if not data_list: return - total_list = [] rank_set = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P for rank_id in rank_tuple: - total_comm_time_info = CustomizedNamedtupleFactory.generate_named_tuple_from_dto(CommunicationTimeInfo) - total_comm_time_info = 
total_comm_time_info.replace(rank_set=rank_set) + total_comm_time_info = dict() for data in data_list: - if data.rank_id == rank_id: - data = data.replace(rank_set=rank_set) - for index, item in enumerate(data._fields): - if item.endwith(self.TIME_EXTENSION): - total_comm_time_info = (total_comm_time_info[:index] + (total_comm_time_info[index] + - data[index]) - + total_comm_time_info[index:]) - synchronization_time_ratio = BaseJsonAnalysis.compute_ratio(total_comm_time_info.synchronization_time, - total_comm_time_info.synchronization_time + - total_comm_time_info.transit_time) - wait_time_ratio = BaseJsonAnalysis.compute_ratio(total_comm_time_info.wait_time, - total_comm_time_info.wait_time + - total_comm_time_info.transit_time) - total_comm_time_info = total_comm_time_info.replace(hccl_op_name=Constant.TOTAL_OP_INFO, - step=step, - rank_id=rank_id, - wait_time_ratio=wait_time_ratio, - synchronization_time_ratio=synchronization_time_ratio) - total_list.append(total_comm_time_info) + if data[TableConstant.RANK_ID] == rank_id: + data[TableConstant.RANK_SET] = rank_set + data[TableConstant.SYNCHRONIZATION_TIME_RATIO] = 0.0 + data[TableConstant.WAIT_TIME_RATIO] = 0.0 + for key, value in data.items(): + self.compute_time(total_comm_time_info, data, key) + syn_ratio = BaseAnalysisJson.compute_ratio(total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME], + total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME] + + total_comm_time_info[TableConstant.TRANSIT_TIME]) + wait_time_ratio = BaseAnalysisJson.compute_ratio(total_comm_time_info[TableConstant.WAIT_TIME], + total_comm_time_info[TableConstant.WAIT_TIME] + + total_comm_time_info[TableConstant.TRANSIT_TIME]) + total_comm_time_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO + total_comm_time_info[TableConstant.GROUP_NAME] = "" + total_comm_time_info[TableConstant.START_TIMESTAMP] = 0.0 + total_comm_time_info[TableConstant.WAIT_TIME_RATIO] = wait_time_ratio + 
total_comm_time_info[TableConstant.SYNCHRONIZATION_TIME_RATIO] = syn_ratio + self.res_comm_time.append(total_comm_time_info) self.res_comm_time.extend(data_list) - self.res_comm_time.extend(total_list) diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py index e5739e1835..9b86eada4b 100644 --- a/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_json.py @@ -15,11 +15,11 @@ from collections import defaultdict -from analysis.base_json_analysis import BaseJsonAnalysis +from analysis.base_analysis_json import BaseAnalysisJson from common_func.constant import Constant -class CommunicationJsonAnalysisJson(BaseJsonAnalysis): +class CommunicationJsonAnalysis(BaseAnalysisJson): SAVED_JSON = "cluster_communication.json" def __init__(self, param: dict): diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py index 2fd502e940..0903bb639d 100644 --- a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py @@ -14,16 +14,14 @@ # limitations under the License. 
import os -from analysis.base_json_analysis import BaseJsonAnalysis -from common_func.DBManager import CustomizedNamedtupleFactory, DBManager +from analysis.base_analysis_json import BaseAnalysisJson +from common_func.db_manager import DBManager from common_func.constant import Constant -from prof_bean.communication_matrix_bean import CommunicationMatrix +from common_func.table_constant import TableConstant class CommMatrixAnalysisDB: COMMUNICATION_MATRIX_TABLE = "ClusterCommAnalyzerMatrix" - STAT_LIST = ["middle", "top", "bottom", "total"] - TOTAL = "total" def __init__(self, params: any): self.collection_path = params.get(Constant.COLLECTION_PATH) @@ -33,6 +31,8 @@ class CommMatrixAnalysisDB: self.res_comm_matrix = [] def run(self): + if not self.matrix_info: + return self.set_rank_tuple() self.combine_total_matrix_info() self.dump_data() @@ -44,23 +44,16 @@ class CommMatrixAnalysisDB: conn, cursor = DBManager.create_connect_db(result_db) res = [] for data in self.res_comm_matrix: - res.append(list(data[:2] + data[4:])) + res.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.HCCL_OP_NAME], + data[TableConstant.GROUP_NAME], data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK], + data[TableConstant.TRANSIT_SIZE], data[TableConstant.TRANSIT_TIME], + data[TableConstant.BANDWIDTH], data[TableConstant.TRANSPORT_TYPE], data[TableConstant.OPNAME]]) if res: sql = "insert into {} values ({value})".format(self.COMMUNICATION_MATRIX_TABLE, value="?," * (len(res[0]) - 1) + "?") DBManager.executemany_sql(conn, sql, res) DBManager.destroy_db_connect(conn, cursor) - def check_add_op(self: any, op_name: str): - """ - 兼容2个版本,判断是否需要将此算子信息相加 - """ - for stat_name in self.STAT_LIST: - if stat_name in op_name: - if stat_name != self.TOTAL: - return False - return True - def combine_total_matrix_info(self): for rank_tuple, group_dict in self.comm_matrix_struct.items(): if rank_tuple != Constant.P2P: @@ -71,60 +64,60 @@ class CommMatrixAnalysisDB: 
def combine_total_info(self, step_dict: dict): link_key_set = set() - total_dict = {} for op_name, matrix_dict in step_dict.items(): - if self.check_add_op(op_name): - self.res_comm_matrix.extend(*matrix_dict.values()) - link_key_set.add(*matrix_dict.keys()) + if BaseAnalysisJson.check_add_op(op_name): + self.res_comm_matrix.extend(matrix_dict.values()) + for key in matrix_dict.keys(): + link_key_set.add(key) for link_key in link_key_set: - total_matrix_info = CustomizedNamedtupleFactory.generate_named_tuple_from_dto(CommunicationMatrix) + total_matrix_info = dict() + total_matrix_info[TableConstant.TRANSIT_SIZE] = 0.0 + total_matrix_info[TableConstant.TRANSIT_TIME] = 0.0 for op_name, matrix_dict in step_dict.items(): - total_matrix_info = total_matrix_info.replace(transport_type=matrix_dict[link_key].transport_type, - op_name=matrix_dict[link_key].op_name, - hccl_op_name=Constant.TOTAL_OP_INFO, - rank_set=matrix_dict[link_key].rank_set, - step=matrix_dict[link_key].step, - src_rank=matrix_dict[link_key].src_rank, - dst_rank=matrix_dict[link_key].dst_rank, - transit_size=(total_matrix_info.transit_size + matrix_dict[link_key].transit_size), - transit_time=(total_matrix_info.transit_time + matrix_dict[link_key].transit_time)) - total_dict[link_key] = total_matrix_info - for key, data in total_dict: - bandwidth = BaseJsonAnalysis.compute_ratio(data.transit_size, data.transit_time) - total_dict[key] = data.replace(bandwidth=bandwidth) - self.res_comm_matrix.extend(total_dict.values()) + total_matrix_info[TableConstant.RANK_SET] = matrix_dict[link_key][TableConstant.RANK_SET] + self.combine_link_info(total_matrix_info, matrix_dict[link_key]) + bandwidth = BaseAnalysisJson.compute_ratio(total_matrix_info[TableConstant.TRANSIT_SIZE], + total_matrix_info[TableConstant.TRANSIT_TIME]) + total_matrix_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO + total_matrix_info[TableConstant.GROUP_NAME] = "" + total_matrix_info[TableConstant.BANDWIDTH] = bandwidth + 
self.res_comm_matrix.append(total_matrix_info) - def combine_link_info(self, link_info, data): - return link_info.replace(transport_type=data.transport_type, - op_name=data.op_name, - hccl_op_name=data.hccl_op_name, - group_name=data.group_name, - step=data.step, - transit_size=(link_info.transit_size + data.transit_size), - transit_time=(link_info.transit_time + data.transit_time)) + def combine_link_info(self, link_info, data: dict): + for col in data.keys(): + if col in [TableConstant.TRANSIT_TIME, TableConstant.TRANSIT_SIZE]: + link_info[col] += data[col] + else: + link_info[col] = data[col] def merge_same_info(self, step_dict: dict, rank_tuple): def process_matrix(): for data in op_list: - if data.src_rank == data.dst_rank: - if data.src_rank not in local_global_rank_map: - local_global_rank_map[data.src_rank] = data.rank_id - elif local_global_rank_map[data.src_rank] != data.rank_id: + if data[TableConstant.SRC_RANK] == data[TableConstant.DST_RANK]: + if data[TableConstant.SRC_RANK] not in local_global_rank_map: + local_global_rank_map[data[TableConstant.SRC_RANK]] = data[TableConstant.RANK_ID] + elif local_global_rank_map[data[TableConstant.SRC_RANK]] != data[TableConstant.RANK_ID]: print(f"[WARNING] In the same communication group, local ranks projecting to global ranks " f"repeat!") - if matrix_info.src_rank == data.src_rank and matrix_info.dst_rank == data.dst_rank: - new_matrix_list[link_key] = self.combine_link_info(matrix_info, data) + if (link_key.split('-')[0] == data[TableConstant.SRC_RANK] and + link_key.split('-')[1] == data[TableConstant.DST_RANK]): + self.combine_link_info(matrix_info, data) + new_matrix_list[link_key] = matrix_info def convert_local_to_global_rank(): - res_dict = {} - for key, new_matrix in new_matrix_list: - src_rank = new_matrix.src_rank - dst_rank = new_matrix.dst_rank + res_dict = dict() + for key, new_matrix in new_matrix_list.items(): + src_rank = new_matrix[TableConstant.SRC_RANK] + dst_rank = 
new_matrix[TableConstant.DST_RANK] src_rank = local_global_rank_map[src_rank] if src_rank in local_global_rank_map else src_rank dst_rank = local_global_rank_map[dst_rank] if dst_rank in local_global_rank_map else dst_rank - bandwidth = BaseJsonAnalysis.compute_ratio(new_matrix.transit_size, new_matrix.transit_time) + bandwidth = BaseAnalysisJson.compute_ratio(new_matrix[TableConstant.TRANSIT_SIZE], + new_matrix[TableConstant.TRANSIT_TIME]) key = f"{src_rank}-{dst_rank}" - res_dict[key] = new_matrix.replace(src_rank=src_rank, dst_rank=dst_rank, bandwidth=bandwidth) + new_matrix[TableConstant.SRC_RANK] = src_rank + new_matrix[TableConstant.DST_RANK] = dst_rank + new_matrix[TableConstant.BANDWIDTH] = bandwidth + res_dict[key] = new_matrix return res_dict local_global_rank_map = dict() @@ -132,20 +125,21 @@ class CommMatrixAnalysisDB: new_matrix_list = {} link_key_set = set() for op_data in op_list: - link_key_set.add(op_data.src_rank + "-" + op_data.dst_rank) + link_key_set.add(op_data[TableConstant.SRC_RANK] + "-" + op_data[TableConstant.DST_RANK]) for link_key in link_key_set: - matrix_info = CustomizedNamedtupleFactory.generate_named_tuple_from_dto(CommunicationMatrix) - matrix_info = matrix_info.replace(rank_set=rank_tuple, - src_rank=link_key.split('-')[0], dst_rank=link_key.split('-')[1]) + matrix_info = dict() + matrix_info[TableConstant.RANK_SET] = rank_tuple + matrix_info[TableConstant.TRANSIT_SIZE] = 0.0 + matrix_info[TableConstant.TRANSIT_TIME] = 0.0 process_matrix() step_dict[op_name] = convert_local_to_global_rank() def set_rank_tuple(self): for data in self.matrix_info: - op_name = data.hccl_op_name + "@" + data.group_name - if data.type == Constant.P2P: + op_name = data[TableConstant.HCCL_OP_NAME] + "@" + data[TableConstant.GROUP_NAME] + if data[TableConstant.STEP] == Constant.P2P: rank_tuple = Constant.P2P else: - rank_tuple = self.collective_group_dict.get(data.group_name) - self.comm_matrix_struct.setdefault(rank_tuple, {}).setdefault(data.step, 
{}). \ + rank_tuple = tuple(self.collective_group_dict.get(data[TableConstant.GROUP_NAME])) + self.comm_matrix_struct.setdefault(rank_tuple, {}).setdefault(data[TableConstant.STEP], {}). \ setdefault(op_name, []).append(data) diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py index b91f6a028f..3851161575 100644 --- a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_json.py @@ -15,14 +15,12 @@ from collections import defaultdict -from analysis.base_json_analysis import BaseJsonAnalysis +from analysis.base_analysis_json import BaseAnalysisJson from common_func.constant import Constant -class CommMatrixJsonAnalysisJson(BaseJsonAnalysis): +class CommMatrixAnalysisJson(BaseAnalysisJson): SAVED_JSON = "cluster_communication_matrix.json" - STAT_LIST = ["middle", "top", "bottom", "total"] - TOTAL = "total" def __init__(self, param: dict): super().__init__(param) @@ -105,13 +103,3 @@ class CommMatrixJsonAnalysisJson(BaseJsonAnalysis): self.compute_ratio(link_dict.get(Constant.TRANSIT_SIZE_MB, 0), link_dict.get(Constant.TRANSIT_TIME_MS, 0)) step_dict[Constant.TOTAL_OP_INFO] = total_op_info - - def check_add_op(self: any, op_name: str): - """ - 兼容2个版本,判断是否需要将此算子信息相加 - """ - for stat_name in self.STAT_LIST: - if stat_name in op_name: - if stat_name != self.TOTAL: - return False - return True diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_generator.py similarity index 90% rename from profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py rename to profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_generator.py index 69507807db..d943338aa7 100644 --- 
a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis.py +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_generator.py @@ -14,15 +14,15 @@ # limitations under the License. from analysis.communication_matrix.comm_matrix_analysis_db import CommMatrixAnalysisDB -from analysis.communication_matrix.comm_matrix_analysis_json import CommMatrixJsonAnalysisJson +from analysis.communication_matrix.comm_matrix_analysis_json import CommMatrixAnalysisJson from common_func.constant import Constant -class CommMatrixAnalysis: +class CommMatrixAnalysisGenerator: GROUP_MAP = { Constant.DB: CommMatrixAnalysisDB, - Constant.TEXT: CommMatrixJsonAnalysisJson + Constant.TEXT: CommMatrixAnalysisJson } def __init__(self, params: dict): diff --git a/profiler/cluster_analyse/analysis/step_trace/__init__.py b/profiler/cluster_analyse/analysis/step_trace/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py index 1d1eb87c7b..20a71df3c5 100644 --- a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py +++ b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py @@ -14,9 +14,8 @@ # limitations under the License. 
import os -from collections import defaultdict -from common_func.DBManager import DBManager +from common_func.db_manager import DBManager from common_func.constant import Constant from common_func.file_manager import FileManager from prof_bean.step_trace_time_bean import StepTraceTimeBean @@ -67,8 +66,6 @@ class StepTraceTimeAnalysis: DBManager.executemany_sql(conn, sql, self.step_data_list) DBManager.destroy_db_connect(conn, cursor) - - def load_step_trace_time_data(self): for rank_id, profiling_dir_path in self.data_map.items(): if self.data_type == Constant.TEXT: @@ -76,12 +73,13 @@ class StepTraceTimeAnalysis: if step_time_file: self.step_time_dict[rank_id] = FileManager.read_csv_file(step_time_file, StepTraceTimeBean) else: - step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, Constant.DB_COMMUNICATION_ANALYZER) + step_time_file = os.path.join(profiling_dir_path, Constant.SINGLE_OUTPUT, + Constant.DB_COMMUNICATION_ANALYZER) if step_time_file and DBManager.check_tables_in_db(step_time_file, Constant.TABLE_STEP_TRACE): conn, cursor = DBManager.create_connect_db(step_time_file) sql = "select * from {0}".format(Constant.TABLE_STEP_TRACE) - data = DBManager.fetch_all_data(cursor, sql) - self.step_time_dict[rank_id] = StepTraceTimeBean(data) + data = DBManager.fetch_all_data(cursor, sql, is_dict=False) + self.step_time_dict[rank_id] = data DBManager.destroy_db_connect(conn, cursor) if not self.step_time_dict.get(rank_id): print(f"[WARNING] Rank {rank_id} does not have a valid step_trace_time.json.") @@ -92,7 +90,7 @@ class StepTraceTimeAnalysis: if self.data_type == Constant.TEXT: self.step_data_list.append([data_bean.step, Constant.RANK, rank_id] + data_bean.row) else: - self.step_data_list.append([data_bean[0], Constant.RANK, rank_id] + data_bean[1:]) + self.step_data_list.append([data_bean[0], Constant.RANK, rank_id] + list(data_bean[1:])) stage_list = self.communication_group.get(Constant.P2P) if not stage_list: return @@ -110,7 +108,7 @@ 
class StepTraceTimeAnalysis: if self.data_type == Constant.TEXT: self.step_data_list.append([key[0], Constant.STAGE, key[1]] + self.get_max_data_row(data_group_list)) else: - index = "(" + ",".join(str(i) for i in key[1]) + index = "(" + ",".join(str(i) for i in key[1]) + ")" self.step_data_list.append([key[0], Constant.STAGE, index] + self.get_max_data_row(data_group_list)) def get_headers(self): diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index 39cbae963f..68eae526fb 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -45,6 +45,7 @@ class Interface: type_db_count = 0 type_text_count = 0 for _, folder_path in data_map.items(): + folder_path = os.path.join(folder_path, Constant.SINGLE_OUTPUT) db_files = glob.glob(os.path.join(folder_path, self.DB_RESULT_INFO)) all_files = glob.glob(os.path.join(folder_path, self.ALL_RESULT_INFO)) if all_files and db_files and len(all_files) != len(db_files): diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index 7b0468fc28..71caee40db 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -71,8 +71,8 @@ class Constant(object): ANALYSIS_MODE = "analysis_mode" # step time - RANK = 'rank' - STAGE = 'stage' + RANK = "rank" + STAGE = "stage" # epsilon EPS = 1e-15 @@ -86,11 +86,11 @@ class Constant(object): DB = "db" # db name - DB_COMMUNICATION_ANALYZER = 'analysis.db' - DB_CLUSTER_COMMUNICATION_ANALYZER = 'cluster_analysis.db' + DB_COMMUNICATION_ANALYZER = "analysis.db" + DB_CLUSTER_COMMUNICATION_ANALYZER = "cluster_analysis.db" # db tables - TABLE_COMM_ANALYZER_BANDWIDTH = 'CommAnalyzerBandwidth' - TABLE_COMM_ANALYZER_TIME = 'CommAnalyzerTime' - TABLE_COMM_ANALYZER_MATRIX = 'CommAnalyzerMatrix' - TABLE_STEP_TRACE = 'StepTraceTime' + TABLE_COMM_ANALYZER_BANDWIDTH = "CommAnalyzerBandwidth" + 
TABLE_COMM_ANALYZER_TIME = "CommAnalyzerTime" + TABLE_COMM_ANALYZER_MATRIX = "CommAnalyzerMatrix" + TABLE_STEP_TRACE = "StepTraceTime" diff --git a/profiler/cluster_analyse/common_func/DBManager.py b/profiler/cluster_analyse/common_func/db_manager.py similarity index 65% rename from profiler/cluster_analyse/common_func/DBManager.py rename to profiler/cluster_analyse/common_func/db_manager.py index 7fd8dafd8d..e7e50a572e 100644 --- a/profiler/cluster_analyse/common_func/DBManager.py +++ b/profiler/cluster_analyse/common_func/db_manager.py @@ -15,11 +15,10 @@ import os import sqlite3 -from collections import namedtuple, OrderedDict -from dataclasses import fields, MISSING from common_func.constant import Constant from common_func.empty_class import EmptyClass +from common_func.file_manager import check_db_path_valid from common_func.tables_config import TablesConfig @@ -76,7 +75,7 @@ class DBManager: if not isinstance(curs, sqlite3.Cursor): return False try: - curs.execute("select count(*) from sqlite_master where type='table' and name=?", table_name) + curs.execute("select count(*) from sqlite_master where type='table' and name=?", (table_name,)) return curs.fetchone()[0] except sqlite3.Error as err: print("[ERROR] {}".format(err)) @@ -90,6 +89,10 @@ class DBManager: if table_map in TablesConfig.DATA: items = TablesConfig.DATA[table_map] for item in items: + if item[0] == "index": + header_with_type_list.append('"' + item[0] + '" ' + item[1].split(",")[0]) + else: + header_with_type_list.append(item[0] + ' ' + item[1].split(",")[0]) header_with_type_list.append(item[0] + " " + item[1].split(",")[0]) header_with_type_begin += ",".join(header_with_type_list) header_with_type_begin += header_with_type_end @@ -97,17 +100,17 @@ class DBManager: @classmethod def check_tables_in_db(cls, db_path: any, *tables: any) -> bool: - check_db_path_valid(db_path) - conn, curs = cls.create_connect_db(db_path) - if not (conn and curs): - return False - res = True - for table in 
tables: - if not cls.judge_table_exists(curs, table): - res = False - break - cls.destroy_db_connect(conn, curs) - return res + if check_db_path_valid(db_path, True): + conn, curs = cls.create_connect_db(db_path) + if not (conn and curs): + return False + res = True + for table in tables: + if not cls.judge_table_exists(curs, table): + res = False + break + cls.destroy_db_connect(conn, curs) + return res @classmethod def create_tables(cls, db_path: any, *tables: any) -> bool: @@ -155,7 +158,7 @@ class DBManager: return False @classmethod - def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None, dto_class: any = None) -> list: + def fetch_all_data(cls: any, curs: any, sql: str, param: tuple = None, is_dict: bool = True) -> list: """ fetch 10000 num of data from db each time to get all data """ @@ -172,12 +175,11 @@ class DBManager: curs.row_factory = None return [] try: - if dto_class: - tuple_dto = CustomizedNamedtupleFactory.generate_named_tuple_from_db(dto_class, res.description) + description = res.description while True: res = curs.fetchmany(cls.FETCH_SIZE) - if dto_class: - data += [tuple_dto(*i) for i in res] + if is_dict: + data += CustomizedDictFactory.generate_dict_from_db(res, description) else: data += res if len(data) > cls.MAX_ROW_COUNT: @@ -192,56 +194,12 @@ class DBManager: curs.row_factory = None -class CustomizedNamedtupleFactory: - - @staticmethod - def generate_named_tuple_from_db(dto_class: type, description: any) -> any: - description_set = {i[0] for i in description} - extend_columns = OrderedDict() - for item in fields(dto_class): - if item.name not in description_set: - if item.default == MISSING: - extend_columns[item.name] = None - else: - extend_columns[item.name] = item.default - field_names = [i[0] for i in description] - field_names.extend(extend_columns.keys()) - defaults = [None] * len(description) - defaults.extend(extend_columns.values()) - base_tuple = namedtuple(dto_class.__name__, field_names, defaults=defaults) - 
extra_properties = {} - for name in dir(dto_class): - if isinstance(getattr(dto_class, name), property): - extra_properties[name] = getattr(dto_class, name) - return CustomizedNamedtupleFactory.enhance_namedtuple(base_tuple, extra_properties) - +class CustomizedDictFactory: @staticmethod - def enhance_namedtuple(tuple_type: type, function_dict: dict): - class_namespace = dict(tuple_type.__dict__) - class_namespace.update(function_dict) - class_namespace["replace"] = class_namespace.get("_replace") - return type(tuple_type.__name__, (tuple,), class_namespace) - - @staticmethod - def generate_named_tuple_from_dto(dto_class: type) -> any: - extend_columns = OrderedDict() - for item in fields(dto_class): - if item.default == MISSING: - extend_columns[item.name] = None - else: - extend_columns[item.name] = item.default - fields_name = list(extend_columns.keys()) - defaults = [None] * len(fields_name) - defaults.extend(extend_columns.values()) - base_tuple = namedtuple(dto_class.__name__, fields_name, defaults=defaults) - return CustomizedNamedtupleFactory.enhance_namedtuple(base_tuple, fields_name) - - -def check_db_path_valid(path: str, is_create: bool = False, max_size: int = Constant.MAX_READ_DB_FILE_BYTES) -> bool: - if os.path.islink(path): - print(f'[ERROR] The db file path: {path} is link. Please check the path') - return False - if not is_create and not os.path.exists(path) and os.path.getsize(path) > max_size: - print(f'[ERROR] The db file: {path} is too large to read. 
Please check the file') - return False - return True + def generate_dict_from_db(data_result: any, description: any) -> any: + description_set = [i[0] for i in description] + res = [] + for data in data_result: + data_dict = dict(zip(description_set, data)) + res.append(data_dict) + return res diff --git a/profiler/cluster_analyse/common_func/empty_class.py b/profiler/cluster_analyse/common_func/empty_class.py index 0141079a00..9d41eccbb5 100644 --- a/profiler/cluster_analyse/common_func/empty_class.py +++ b/profiler/cluster_analyse/common_func/empty_class.py @@ -15,9 +15,6 @@ class EmptyClass: - """ - Empty class - """ def __init__(self: any, info: str = "") -> None: self._info = info @@ -32,15 +29,8 @@ class EmptyClass: @property def info(self: any) -> str: - """ - get info - :return: _info - """ return self._info @staticmethod def is_empty() -> bool: - """ - return this is an empty class - """ return True diff --git a/profiler/cluster_analyse/common_func/file_manager.py b/profiler/cluster_analyse/common_func/file_manager.py index 3853c806f9..eefa80ceae 100644 --- a/profiler/cluster_analyse/common_func/file_manager.py +++ b/profiler/cluster_analyse/common_func/file_manager.py @@ -71,8 +71,8 @@ class FileManager: PathManager.check_path_writeable(output_path) try: with os.fdopen( - os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), - 'w', newline="" + os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), + 'w', newline="" ) as file: writer = csv.writer(file) if headers: @@ -91,7 +91,7 @@ class FileManager: PathManager.check_path_writeable(output_path) try: with os.fdopen( - os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), 'w' + os.open(output_file, os.O_WRONLY | os.O_CREAT, cls.DATA_FILE_AUTHORITY), 'w' ) as file: file.write(json.dumps(data)) except Exception as e: @@ -115,3 +115,13 @@ class FileManager: file_size = os.path.getsize(file_path) if file_size > limit_size: raise RuntimeError(f"The 
file({base_name}) size exceeds the preset max value.") + + +def check_db_path_valid(path: str, is_create: bool = False, max_size: int = Constant.MAX_READ_DB_FILE_BYTES) -> bool: + if os.path.islink(path): + print(f'[ERROR] The db file path: {path} is link. Please check the path') + return False + if not is_create and os.path.exists(path) and os.path.getsize(path) > max_size: + print(f'[ERROR] The db file: {path} is too large to read. Please check the file') + return False + return True diff --git a/profiler/cluster_analyse/common_func/table_constant.py b/profiler/cluster_analyse/common_func/table_constant.py new file mode 100644 index 0000000000..de6d47e97e --- /dev/null +++ b/profiler/cluster_analyse/common_func/table_constant.py @@ -0,0 +1,27 @@ +class TableConstant: + + RANK_SET = "rank_set" + STEP = "step" + RANK_ID = "rank_id" + TYPE = "type" + HCCL_OP_NAME = "hccl_op_name" + GROUP_NAME = "group_name" + START_TIMESTAMP = "start_timestamp" + ELAPSED_TIME = "elapse_time" + TRANSIT_TIME = "transit_time" + WAIT_TIME = "wait_time" + SYNCHRONIZATION_TIME = "synchronization_time" + IDLE_TIME = "idle_time" + SYNCHRONIZATION_TIME_RATIO = "synchronization_time_ratio" + WAIT_TIME_RATIO = "wait_time_ratio" + BAND_TYPE = "band_type" + TRANSIT_SIZE = "transit_size" + BANDWIDTH = "bandwidth" + LARGE_PACKET_RATIO = "large_packet_ratio" + PACKAGE_SIZE = "package_size" + COUNT = "count" + TOTAL_DURATION = "total_duration" + SRC_RANK = "src_rank" + DST_RANK = "dst_rank" + TRANSPORT_TYPE = "transport_type" + OPNAME = "op_name" diff --git a/profiler/cluster_analyse/communication_group/communication_db_group.py b/profiler/cluster_analyse/communication_group/communication_db_group.py index 941efc9df8..db1e8bd4f5 100644 --- a/profiler/cluster_analyse/communication_group/communication_db_group.py +++ b/profiler/cluster_analyse/communication_group/communication_db_group.py @@ -15,12 +15,10 @@ import os -from common_func.DBManager import DBManager +from common_func.db_manager import 
DBManager from common_func.constant import Constant +from common_func.table_constant import TableConstant from communication_group.base_communication_group import BaseCommunicationGroup -from prof_bean.communication_bandwidth_info_bean import CommunicationBandwidthInfo -from prof_bean.communication_matrix_bean import CommunicationMatrix -from prof_bean.communication_time_info_bean import CommunicationTimeInfo class CommunicationDBGroup(BaseCommunicationGroup): @@ -40,19 +38,18 @@ class CommunicationDBGroup(BaseCommunicationGroup): time_data = {} bandwidth_data = {} matrix_data = {} - if DBManager.check_tables_in_db(db_path, (Constant.TABLE_COMM_ANALYZER_TIME, - Constant.TABLE_COMM_ANALYZER_BANDWIDTH, - Constant.TABLE_COMM_ANALYZER_MATRIX)): + if DBManager.check_tables_in_db(db_path, Constant.TABLE_COMM_ANALYZER_TIME, + Constant.TABLE_COMM_ANALYZER_BANDWIDTH, + Constant.TABLE_COMM_ANALYZER_MATRIX): conn, cursor = DBManager.create_connect_db(db_path) time_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_TIME) bandwidth_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_BANDWIDTH) matrix_info_sql = "select * from {0}".format(Constant.TABLE_COMM_ANALYZER_MATRIX) if self.analysis_mode in ["all", "communication_time"]: - time_data = DBManager.fetch_all_data(cursor, time_info_sql, dto_class=CommunicationTimeInfo) - bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql, - dto_class=CommunicationBandwidthInfo) + time_data = DBManager.fetch_all_data(cursor, time_info_sql) + bandwidth_data = DBManager.fetch_all_data(cursor, bandwidth_info_sql) if self.analysis_mode in ["all", "communication_matrix"]: - matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql, dto_class=CommunicationMatrix) + matrix_data = DBManager.fetch_all_data(cursor, matrix_info_sql) DBManager.destroy_db_connect(conn, cursor) return (rank_id, self.data_group_by_step(time_data), self.data_group_by_step(bandwidth_data), 
self.data_group_by_step(matrix_data)) @@ -61,7 +58,7 @@ class CommunicationDBGroup(BaseCommunicationGroup): def data_group_by_step(data: any) -> any: res = {} for item in data: - res.setdefault(item.step, []).append(item) + res.setdefault(item[TableConstant.STEP], []).append(item) return res def dump_data(self): @@ -93,20 +90,19 @@ class CommunicationDBGroup(BaseCommunicationGroup): for rank_id, time_data, bandwidth_data, matrix_data in self.rank_comm_dir_dict: for step, data_list in time_data.items(): for data in data_list: - if data.type == Constant.COLLECTIVE: - self.collective_group_dict[data.group_name].add(rank_id) - data = data.replace(rank_id=rank_id) - self.communication_time_info.append(data) + self.compute_collective_group(data, rank_id, self.communication_time_info) for data in bandwidth_data[step]: - if data.type == Constant.COLLECTIVE: - self.collective_group_dict[data.group_name].add(rank_id) - data = data.replace(rank_id=rank_id) - self.communication_bandwidth_info.append(data) + self.compute_collective_group(data, rank_id, self.communication_bandwidth_info) for step, data_list in matrix_data.items(): self.add_p2p_and_rank(rank_id, step, matrix_data) for data in data_list: - if data.type == Constant.COLLECTIVE: - self.collective_group_dict[data.group_name].add(rank_id) + self.compute_collective_group(data, rank_id, self.matrix_info) + + def compute_collective_group(self, data, rank_id, res_list): + if data[TableConstant.TYPE] == Constant.COLLECTIVE: + self.collective_group_dict[data[TableConstant.GROUP_NAME]].add(rank_id) + data[TableConstant.RANK_ID] = rank_id + res_list.append(data) def add_p2p_and_rank(self, rank_id: int, step: str, data_dict: dict): data_list = data_dict[step] @@ -114,13 +110,11 @@ class CommunicationDBGroup(BaseCommunicationGroup): print(f"[WARNING] rank {rank_id} {step} don't have communication matrix ops data") return for data in data_list: - if data.type != Constant.COLLECTIVE and data.type != Constant.P2P: + if 
data[TableConstant.TYPE] != Constant.COLLECTIVE and data[TableConstant.TYPE] != Constant.P2P: print(f"[WARNING] Unknown communication operators type!") continue - data = data.replace(rank_id=rank_id) - self.matrix_info.append(data) - if data.type == Constant.P2P: - if data.src_rank != data.dst_rank: - rank_set = {data.src_rank, data.dst_rank} + if data[TableConstant.TYPE] == Constant.P2P: + if data[TableConstant.SRC_RANK] != data[TableConstant.DST_RANK]: + rank_set = {data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK]} if rank_set not in self.p2p_link: self.p2p_link.append(rank_set) diff --git a/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py deleted file mode 100644 index 877d7d2901..0000000000 --- a/profiler/cluster_analyse/prof_bean/communication_bandwidth_info_bean.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -class CommunicationBandwidthInfo: - """ - The class represents a Communication Bandwidth Info object - """ - rank_set: str = "" - step: str = "" - rank_id: int = 0 - type: str = "" - hccl_op_name: str = "" - group_name: str = "" - band_type: str = "" - transit_size: float = 0.0 - transit_time: float = 0.0 - bandwidth: float = 0.0 - large_packet_ratio: float = 0.0 - package_size: float = 0.0 - count: float = 0.0 - total_duration: float = 0.0 diff --git a/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py b/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py deleted file mode 100644 index d3b9bdddb6..0000000000 --- a/profiler/cluster_analyse/prof_bean/communication_matrix_bean.py +++ /dev/null @@ -1,33 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -class CommunicationMatrix: - """ - The class represents a Communication Matrix object - """ - rank_set: str = "" - step: str = "" - rank_id: int = 0 - type: str = "" - hccl_op_name: str = "" - group_name: str = "" - src_rank: str = "" - dst_rank: str = "" - transit_size: float = 0.0 - transit_time: float = 0.0 - bandwidth: float = 0.0 - transport_type: str = "" - op_name: str = "" diff --git a/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py b/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py deleted file mode 100644 index c51856569d..0000000000 --- a/profiler/cluster_analyse/prof_bean/communication_time_info_bean.py +++ /dev/null @@ -1,34 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- - -class CommunicationTimeInfo: - """ - The class represents a Communication Time Info object - """ - rank_set: str = "" - step: str = "" - rank_id: int = 0 - type: str = "" - hccl_op_name: str = "" - group_name: str = "" - start_timestamp: float = 0.0 - elapse_time: float = 0.0 - transit_time: float = 0.0 - wait_time: float = 0.0 - synchronization_time: float = 0.0 - idle_time: float = 0.0 - synchronization_time_ratio: float = 0.0 - wait_time_ratio: float = 0.0 -- Gitee From 4f12e6f1e436dd23133cb1b806ec8f184d4b56e0 Mon Sep 17 00:00:00 2001 From: w00800385 Date: Tue, 5 Mar 2024 20:44:01 +0800 Subject: [PATCH 7/7] bugfix --- .../analysis/communication/communication_analysis_db.py | 8 +++++--- .../communication_matrix/comm_matrix_analysis_db.py | 8 +++++--- profiler/cluster_analyse/common_func/db_manager.py | 1 - 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py index 92a4d8a09d..017ae20512 100644 --- a/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py +++ b/profiler/cluster_analyse/analysis/communication/communication_analysis_db.py @@ -65,7 +65,7 @@ class CommunicationAnalysisDB: for data in self.res_comm_bandwidth: res_bandwidth.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.RANK_ID], data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME], - data[TableConstant.BAND_TYPE], data[TableConstant.TRANSIT_SIZE], + data[TableConstant.TRANSPORT_TYPE], data[TableConstant.TRANSIT_SIZE], data[TableConstant.TRANSIT_TIME], data[TableConstant.BANDWIDTH], data[TableConstant.LARGE_PACKET_RATIO], data[TableConstant.PACKAGE_SIZE], data[TableConstant.COUNT], data[TableConstant.TOTAL_DURATION]]) @@ -108,7 +108,8 @@ class CommunicationAnalysisDB: rank_tuple = "(" + ",".join(str(i) for i in rank_tuple) + ")" if not is_p2p else Constant.P2P for 
data in data_list: data[TableConstant.RANK_SET] = rank_tuple - rank_band_type = self.RANK_BAND_TYPE.format(data[TableConstant.RANK_ID], data[TableConstant.BAND_TYPE]) + rank_band_type = self.RANK_BAND_TYPE.format(data[TableConstant.RANK_ID], + data[TableConstant.TRANSPORT_TYPE]) data_dict.setdefault(rank_band_type, []).append(data) self.res_comm_bandwidth.append(data) for rank_band_type, bandwidth_list in data_dict.items(): @@ -119,7 +120,8 @@ class CommunicationAnalysisDB: total_comm_bandwidth_info = dict() for data in bandwidth_list: self.compute_bandwidth(total_comm_bandwidth_info, data, package) - bandwidth = BaseAnalysisJson.compute_ratio(TableConstant.TRANSIT_SIZE, TableConstant.TRANSIT_TIME) + bandwidth = BaseAnalysisJson.compute_ratio(total_comm_bandwidth_info[TableConstant.TRANSIT_SIZE], + total_comm_bandwidth_info[TableConstant.TRANSIT_TIME]) total_comm_bandwidth_info[TableConstant.BANDWIDTH] = bandwidth total_comm_bandwidth_info[TableConstant.PACKAGE_SIZE] = package total_comm_bandwidth_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO diff --git a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py index 0903bb639d..afc1d40d8f 100644 --- a/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py +++ b/profiler/cluster_analyse/analysis/communication_matrix/comm_matrix_analysis_db.py @@ -44,10 +44,11 @@ class CommMatrixAnalysisDB: conn, cursor = DBManager.create_connect_db(result_db) res = [] for data in self.res_comm_matrix: + op_name = data.get(TableConstant.OPNAME) if data.get(TableConstant.OPNAME) is not None else "" res.append([data[TableConstant.RANK_SET], data[TableConstant.STEP], data[TableConstant.HCCL_OP_NAME], data[TableConstant.GROUP_NAME], data[TableConstant.SRC_RANK], data[TableConstant.DST_RANK], data[TableConstant.TRANSIT_SIZE], data[TableConstant.TRANSIT_TIME], - data[TableConstant.BANDWIDTH], 
data[TableConstant.TRANSPORT_TYPE], data[TableConstant.OPNAME]]) + data[TableConstant.BANDWIDTH], data[TableConstant.TRANSPORT_TYPE], op_name]) if res: sql = "insert into {} values ({value})".format(self.COMMUNICATION_MATRIX_TABLE, value="?," * (len(res[0]) - 1) + "?") @@ -74,8 +75,9 @@ class CommMatrixAnalysisDB: total_matrix_info[TableConstant.TRANSIT_SIZE] = 0.0 total_matrix_info[TableConstant.TRANSIT_TIME] = 0.0 for op_name, matrix_dict in step_dict.items(): - total_matrix_info[TableConstant.RANK_SET] = matrix_dict[link_key][TableConstant.RANK_SET] - self.combine_link_info(total_matrix_info, matrix_dict[link_key]) + if link_key in matrix_dict.keys(): + total_matrix_info[TableConstant.RANK_SET] = matrix_dict[link_key][TableConstant.RANK_SET] + self.combine_link_info(total_matrix_info, matrix_dict[link_key]) bandwidth = BaseAnalysisJson.compute_ratio(total_matrix_info[TableConstant.TRANSIT_SIZE], total_matrix_info[TableConstant.TRANSIT_TIME]) total_matrix_info[TableConstant.HCCL_OP_NAME] = Constant.TOTAL_OP_INFO diff --git a/profiler/cluster_analyse/common_func/db_manager.py b/profiler/cluster_analyse/common_func/db_manager.py index e7e50a572e..5ec6e10a9e 100644 --- a/profiler/cluster_analyse/common_func/db_manager.py +++ b/profiler/cluster_analyse/common_func/db_manager.py @@ -93,7 +93,6 @@ class DBManager: header_with_type_list.append('"' + item[0] + '" ' + item[1].split(",")[0]) else: header_with_type_list.append(item[0] + ' ' + item[1].split(",")[0]) - header_with_type_list.append(item[0] + " " + item[1].split(",")[0]) header_with_type_begin += ",".join(header_with_type_list) header_with_type_begin += header_with_type_end return header_with_type_begin -- Gitee