From 0aad679be8b688f9a85ce3d3291f8c92e943e47f Mon Sep 17 00:00:00 2001 From: zhouxianqi <13165993773@163.com> Date: Thu, 24 Aug 2023 16:44:53 +0800 Subject: [PATCH] communication_adapt --- .../comparator/index_comparator.py | 4 +- .../compare_tools/comparator/op_comparator.py | 6 +-- profiler/compare_tools/performance_compare.py | 9 ++-- profiler/compare_tools/utils/file_reader.py | 4 +- .../compare_tools/utils/profiling_parser.py | 48 ++++++++++++++----- .../compare_tools/utils/trace_event_data.py | 42 ++++++++++++++++ 6 files changed, 92 insertions(+), 21 deletions(-) create mode 100644 profiler/compare_tools/utils/trace_event_data.py diff --git a/profiler/compare_tools/comparator/index_comparator.py b/profiler/compare_tools/comparator/index_comparator.py index d122e3ea3c..c33261a5ce 100644 --- a/profiler/compare_tools/comparator/index_comparator.py +++ b/profiler/compare_tools/comparator/index_comparator.py @@ -14,7 +14,7 @@ class IndexComparator: def compare(self) -> list: base_data, comparison_data = [], [] if not self._base_profiling.communication_data: - print(f"[warning] Can't find any communication op in the file: {self._base_profiling.json_path}") + print(f"[WARNING] Can't find any communication op in the file: {self._base_profiling.json_path}") for data in self._base_profiling.communication_data: name_list = data.get("name", "").split("_") if len(name_list) >= 2: @@ -29,7 +29,7 @@ class IndexComparator: comparison_data = [] else: if not self._comparison_profiling.communication_data: - print(f"[warning] Can't find any communication op in the file: {self._comparison_profiling.json_path}") + print(f"[WARNING] Can't find any communication op in the file: {self._comparison_profiling.json_path}") for data in self._comparison_profiling.communication_data: name_list = data.get("name", "").split("_") if len(name_list) >= 2: diff --git a/profiler/compare_tools/comparator/op_comparator.py b/profiler/compare_tools/comparator/op_comparator.py index 4b552ac258..89bfc1a692 100644 --- a/profiler/compare_tools/comparator/op_comparator.py +++ b/profiler/compare_tools/comparator/op_comparator.py @@ -86,18 +86,18 @@ class OpComparator: def _get_top_layer_ops(self, profiling_instance: any) -> any: torch_op_data = profiling_instance.torch_op_data if not torch_op_data: - print(f"[warning] Can't find any torch op in the file: {profiling_instance.json_path}") + print(f"[WARNING] Can't find any torch op in the file: {profiling_instance.json_path}") root_node = TreeBuilder.build_tree(torch_op_data) kernel_dict, memory_list = {}, [] if not self._args.disable_operator_compare: kernel_dict = profiling_instance.kernel_dict if not kernel_dict: - print(f"[warning] Can't find any flow event in the file: {profiling_instance.json_path}") + print(f"[WARNING] Can't find any flow event in the file: {profiling_instance.json_path}") if not self._args.disable_memory_compare: memory_list = profiling_instance.memory_list if not memory_list: - print(f"[warning] Can't find any memory event in the file: {profiling_instance.file_path}") + print(f"[WARNING] Can't find any memory event in the file: {profiling_instance.file_path}") TreeBuilder.update_tree_node(root_node, kernel_dict, memory_list) level1_child_nodes = root_node.child_nodes diff --git a/profiler/compare_tools/performance_compare.py b/profiler/compare_tools/performance_compare.py index 3a7f391309..885f9b44b7 100644 --- a/profiler/compare_tools/performance_compare.py +++ b/profiler/compare_tools/performance_compare.py @@ -52,17 +52,20 @@ def main(): try: performance_compare(args) except Exception: - print("profiling analyze failed.") + print("[WARNING] Profiling failed to analyze.") + + print("[INFO] Start to compare performance data, please wait.") dir_path = args.output_path if args.output_path else "./" file_name = "performance_comparison_result_{}.xlsx".format( time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) - result_file_path = os.path.join(dir_path, file_name) + result_file_path = os.path.realpath(os.path.join(dir_path, file_name)) ComparisonGenerator(args).create_excel(result_file_path) + print(f"[INFO] The comparison result file has been generated: {result_file_path}") if __name__ == "__main__": start_time = datetime.datetime.now() main() end_time = datetime.datetime.now() - print(f'The comparison task has been completed in a total time of {end_time - start_time}') + print(f'[INFO] The comparison task has been completed in a total time of {end_time - start_time}') diff --git a/profiler/compare_tools/utils/file_reader.py b/profiler/compare_tools/utils/file_reader.py index b536fce0f7..4658e0e776 100644 --- a/profiler/compare_tools/utils/file_reader.py +++ b/profiler/compare_tools/utils/file_reader.py @@ -16,7 +16,7 @@ class FileReader: if file_size <= 0: return [] if file_size > Constant.MAX_FILE_SIZE: - print(f"The file size exceeds the preset value {Constant.MAX_FILE_SIZE / 1024 / 1024}MB, " + print(f"[WARNING] The file size exceeds the preset value {Constant.MAX_FILE_SIZE / 1024 / 1024}MB, " f"please check the file: {file_path}") return [] try: @@ -35,7 +35,7 @@ class FileReader: if file_size <= 0: return [] if file_size > Constant.MAX_FILE_SIZE: - print(f"[WARN] The file size exceeds the preset value {Constant.MAX_FILE_SIZE / 1024 / 1024}MB, " + print(f"[WARNING] The file size exceeds the preset value {Constant.MAX_FILE_SIZE / 1024 / 1024}MB, " f"please check the file: {file_path}") return [] result_data = [] diff --git a/profiler/compare_tools/utils/profiling_parser.py b/profiler/compare_tools/utils/profiling_parser.py index 231f91f2b7..8a94cb695d 100644 --- a/profiler/compare_tools/utils/profiling_parser.py +++ b/profiler/compare_tools/utils/profiling_parser.py @@ -4,6 +4,7 @@ from math import ceil from utils.compare_event import KernelEvent from utils.constant import Constant from utils.file_reader import FileReader +from utils.trace_event_data import TraceEventData class ProfilingParser(metaclass=ABCMeta): @@ -270,27 +271,52 @@ class NPUProfilingParser(ProfilingParser): def get_communication_data(self): self._communication_data, self._communication_task_data = [], {} - pid, tid = None, None json_data = FileReader.read_trace_file(self._json_path) + pid = None for data in json_data: - if data.get("ph", "") == "M" and data.get("name", "") == "thread_name" \ - and data.get("args", {}).get("name", "") == "Communication OP": - pid = data.get("pid", "") - tid = data.get("tid", "") - if not pid or not tid: + trace_event = TraceEventData(data) + if not trace_event.is_process_meta(): + continue + if trace_event.is_hccl_process(): + pid = trace_event.pid + break + if pid is None: + return + tid_list = [] + for data in json_data: + trace_event = TraceEventData(data) + if not trace_event.is_thread_meta(): + continue + if trace_event.pid != pid: + continue + if trace_event.is_communication_op_thread(): + tid_list.append(trace_event.tid) + + if not tid_list: return + for data in json_data: - if data.get("ph", "") == "X" and data.get("pid", "") == pid and data.get("tid", "") == tid: + trace_event = TraceEventData(data) + if not trace_event.is_x_mode(): + continue + if trace_event.pid != pid: + continue + if trace_event.tid in tid_list: self._communication_data.append(data) if not self._communication_data: return for data in json_data: - if data.get("ph", "") != "X" or data.get("pid", "") != pid or data.get("tid", "") == tid: + trace_event = TraceEventData(data) + if not trace_event.is_x_mode(): + continue + if trace_event.pid != pid: + continue + if trace_event.tid in tid_list: continue - ts = data.get("ts", 0) + ts = trace_event.start_time for communication_op in self._communication_data: - if ts < communication_op.get("ts", 0) or ts - communication_op.get("ts", 0) > communication_op.get( - "dur", 0): + comm_op_event = TraceEventData(communication_op) + if ts < comm_op_event.start_time or ts > comm_op_event.end_time: continue name_list = communication_op.get("name", "").split("_") if len(name_list) >= 2: diff --git a/profiler/compare_tools/utils/trace_event_data.py b/profiler/compare_tools/utils/trace_event_data.py new file mode 100644 index 0000000000..71030ab90a --- /dev/null +++ b/profiler/compare_tools/utils/trace_event_data.py @@ -0,0 +1,42 @@ +class TraceEventData: + + def __init__(self, event: dict): + self._event = event + + @property + def pid(self) -> int: + return self._event.get("pid", "") + + @property + def tid(self) -> int: + return self._event.get("tid", "") + + @property + def process_name(self) -> int: + return self._event.get("args", {}).get("name", "") + + @property + def start_time(self) -> float: + return self._event.get("ts", 0) + + @property + def end_time(self) -> float: + return self._event.get("ts", 0) + self._event.get("dur", 0) + + def is_m_mode(self) -> bool: + return self._event.get("ph", "") == "M" + + def is_x_mode(self) -> bool: + return self._event.get("ph", "") == "X" + + def is_process_meta(self) -> bool: + return self.is_m_mode() and self._event.get("name", "") == "process_name" + + def is_thread_meta(self) -> bool: + return self.is_m_mode() and self._event.get("name", "") == "thread_name" + + def is_communication_op_thread(self) -> bool: + return self._event.get("args", {}).get("name", "").find("Communication") != -1 + + def is_hccl_process(self) -> bool: + return self.process_name == "HCCL" -- Gitee