From 9107b1352c99ec9021a02d4415e01441b911bf08 Mon Sep 17 00:00:00 2001 From: wuyulong11 Date: Wed, 6 Dec 2023 19:47:51 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BF=AE=E6=94=B9=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E3=80=91=E3=80=90tbplugin=E3=80=91=E3=80=90ISSUE=20#I8MC2O?= =?UTF-8?q?=E3=80=91=E4=BF=AE=E5=A4=8DNPU=20Profiling=E6=95=B0=E6=8D=AEDif?= =?UTF-8?q?f=E6=AF=94=E5=AF=B9=E5=8A=9F=E8=83=BD=20=E3=80=90=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=BA=BA=E3=80=91=20wuyulong=2030031080?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tb_plugin/torch_tb_profiler/__init__.py | 2 +- .../torch_tb_profiler/profiler/data.py | 10 ++++++++++ .../torch_tb_profiler/profiler/op_tree.py | 17 ++++++++++------- .../torch_tb_profiler/profiler/run_generator.py | 2 +- 4 files changed, 22 insertions(+), 9 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py index 9222754d5dd..d5ed2b3d4c5 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py @@ -4,4 +4,4 @@ # Entry point for Pytorch TensorBoard plugin package. -__version__ = '0.4.0.3' +__version__ = '0.4.0.4' diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py index ba423019a0d..ee1ce5b62de 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py @@ -27,6 +27,7 @@ import tempfile from json.decoder import JSONDecodeError from typing import Dict, List, Optional +from .op_tree import OpTreeBuilder from .. import io, utils from ..utils import href from . import trace @@ -163,6 +164,15 @@ class RunProfileData(object): break profile = RunProfileData(worker, span, trace_json) + with utils.timing('EventParser.parse'): + parser = EventParser() + with utils.timing('EventParser: parse nodes'): + tid2list, tid2zero_rt_list, staled_device_nodes, _ = parser.parse_nodes(profile.events) + + with utils.timing('EventParser: build operator tree'): + builder = OpTreeBuilder() + profile.tid2tree = builder.build_tree(tid2list, tid2zero_rt_list, staled_device_nodes, + fwd_bwd_map=profile.forward_backward_events, is_ascend=True) profile.trace_file_path = trace_path profile.has_trace = has_trace if math.isinf(profile.profiler_start_ts): diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py index 5639c666aad..55e264617d8 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py @@ -1,6 +1,7 @@ # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # ------------------------------------------------------------------------- +import math import sys from collections import defaultdict from typing import Dict, Iterable, List, Optional, Tuple @@ -25,10 +26,11 @@ class OpTreeBuilder: tid2list: Dict[int, List[OperatorNode]], tid2zero_rt_list: Dict[int, List[RuntimeNode]], staled_device_nodes: List[DeviceNode], - fwd_bwd_map: Dict[int, int]): + fwd_bwd_map: Dict[int, int], + is_ascend=False): """Construct the BackwardNode and replace the original backward nodes """ - self.tid2tree = self._build_tree(tid2list, tid2zero_rt_list, staled_device_nodes) + self.tid2tree = self._build_tree(tid2list, tid2zero_rt_list, staled_device_nodes, is_ascend) # if could not find any forward/backward association, skip the processing if not fwd_bwd_map: @@ -55,7 +57,7 @@ class OpTreeBuilder: return self.tid2tree - def _build_tree(self, tid2list: Dict[int, List[OperatorNode]], tid2zero_rt_list, staled_device_nodes): + def _build_tree(self, tid2list: Dict[int, List[OperatorNode]], tid2zero_rt_list, staled_device_nodes, is_ascend): tid2tree = {} for tid, op_list in tid2list.items(): @@ -66,9 +68,9 @@ class OpTreeBuilder: if main_tid: # only append the staled device nodes into main thread self.main_tid = op_list[0].tid - root_node = self._build_tree_internal(op_list, zero_rt_list, tid, staled_device_nodes) + root_node = self._build_tree_internal(op_list, zero_rt_list, tid, staled_device_nodes, is_ascend) else: - root_node = self._build_tree_internal(op_list, zero_rt_list, tid, []) + root_node = self._build_tree_internal(op_list, zero_rt_list, tid, [], is_ascend) tid2tree[int(tid)] = root_node return tid2tree @@ -95,7 +97,7 @@ class OpTreeBuilder: return None - def _build_tree_internal(self, host_node_list, zero_rt_list, tid, staled_device_nodes): + def _build_tree_internal(self, host_node_list, zero_rt_list, tid, staled_device_nodes, is_ascend): """host_node_list: list of OperatorNode and ProfilerStepNode. zero_rt_list: list of RuntimeNode with external_id=0.""" @@ -125,7 +127,8 @@ class OpTreeBuilder: while True: # break loop when the node is inserted. tail_node = node_stack[-1] if node.start_time < tail_node.end_time: - if node.end_time <= tail_node.end_time: + if node.end_time <= tail_node.end_time or ( + is_ascend and math.isclose(node.end_time, tail_node.end_time, rel_tol=1)): tail_node.children.append(node) # node.parent_node = weakref.ref(tail_node) node_stack.append(node) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py index 4184e3830a6..0186139a59d 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -93,7 +93,6 @@ class RunGenerator(object): self.profile_data.gpu_metrics_parser.get_gpu_metrics_data_tooltip( gpu_infos, self.profile_data.tc_ratio) - profile_run.tid2tree = self.profile_data.tid2tree profile_run.pl_tid2tree = self.profile_data.pl_tid2tree profile_run.module_stats = aggegate_module_view(self.profile_data.tid2tree, self.profile_data.events) @@ -131,6 +130,7 @@ class RunGenerator(object): profile_run.step_to_overlap = self._npu_get_overlap() profile_run.step_to_wait, profile_run.comm_op = self._npu_get_wait_table() + profile_run.tid2tree = self.profile_data.tid2tree if self.profile_data.has_trace: profile_run.views.append(consts.TRACE_VIEW) profile_run.trace_file_path = self.profile_data.trace_file_path -- Gitee