diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py index 9222754d5dd7afa9a9fff10c5a117c2e450db50f..d5ed2b3d4c574a3d21a0ba50dc1f05d3ceb39365 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/__init__.py @@ -4,4 +4,4 @@ # Entry point for Pytorch TensorBoard plugin package. -__version__ = '0.4.0.3' +__version__ = '0.4.0.4' diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py index ba423019a0d83547afe8f9382a80c1ee595fbf2c..ee1ce5b62deebb5d63dca8fc54f43f0cd2e1a6da 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/data.py @@ -27,6 +27,7 @@ import tempfile from json.decoder import JSONDecodeError from typing import Dict, List, Optional +from .op_tree import OpTreeBuilder from .. import io, utils from ..utils import href from . 
import trace @@ -163,6 +164,15 @@ class RunProfileData(object): break profile = RunProfileData(worker, span, trace_json) + with utils.timing('EventParser.parse'): + parser = EventParser() + with utils.timing('EventParser: parse nodes'): + tid2list, tid2zero_rt_list, staled_device_nodes, _ = parser.parse_nodes(profile.events) + + with utils.timing('EventParser: build operator tree'): + builder = OpTreeBuilder() + profile.tid2tree = builder.build_tree(tid2list, tid2zero_rt_list, staled_device_nodes, + fwd_bwd_map=profile.forward_backward_events, is_ascend=True) profile.trace_file_path = trace_path profile.has_trace = has_trace if math.isinf(profile.profiler_start_ts): diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py index 5639c666aadc32accc4d3548c19bcd4fa4ad4294..55e264617d835fb5bf94819b329fdbd2ee1c53f6 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/op_tree.py @@ -1,6 +1,7 @@ # ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. 
# ------------------------------------------------------------------------- +import math import sys from collections import defaultdict from typing import Dict, Iterable, List, Optional, Tuple @@ -25,10 +26,11 @@ class OpTreeBuilder: tid2list: Dict[int, List[OperatorNode]], tid2zero_rt_list: Dict[int, List[RuntimeNode]], staled_device_nodes: List[DeviceNode], - fwd_bwd_map: Dict[int, int]): + fwd_bwd_map: Dict[int, int], + is_ascend=False): """Construct the BackwardNode and replace the original backward nodes """ - self.tid2tree = self._build_tree(tid2list, tid2zero_rt_list, staled_device_nodes) + self.tid2tree = self._build_tree(tid2list, tid2zero_rt_list, staled_device_nodes, is_ascend) # if could not find any forward/backward association, skip the processing if not fwd_bwd_map: @@ -55,7 +57,7 @@ class OpTreeBuilder: return self.tid2tree - def _build_tree(self, tid2list: Dict[int, List[OperatorNode]], tid2zero_rt_list, staled_device_nodes): + def _build_tree(self, tid2list: Dict[int, List[OperatorNode]], tid2zero_rt_list, staled_device_nodes, is_ascend): tid2tree = {} for tid, op_list in tid2list.items(): @@ -66,9 +68,9 @@ class OpTreeBuilder: if main_tid: # only append the staled device nodes into main thread self.main_tid = op_list[0].tid - root_node = self._build_tree_internal(op_list, zero_rt_list, tid, staled_device_nodes) + root_node = self._build_tree_internal(op_list, zero_rt_list, tid, staled_device_nodes, is_ascend) else: - root_node = self._build_tree_internal(op_list, zero_rt_list, tid, []) + root_node = self._build_tree_internal(op_list, zero_rt_list, tid, [], is_ascend) tid2tree[int(tid)] = root_node return tid2tree @@ -95,7 +97,7 @@ class OpTreeBuilder: return None - def _build_tree_internal(self, host_node_list, zero_rt_list, tid, staled_device_nodes): + def _build_tree_internal(self, host_node_list, zero_rt_list, tid, staled_device_nodes, is_ascend): """host_node_list: list of OperatorNode and ProfilerStepNode. 
zero_rt_list: list of RuntimeNode with external_id=0.""" @@ -125,7 +127,8 @@ class OpTreeBuilder: while True: # break loop when the node is inserted. tail_node = node_stack[-1] if node.start_time < tail_node.end_time: - if node.end_time <= tail_node.end_time: + if node.end_time <= tail_node.end_time or ( + is_ascend and math.isclose(node.end_time, tail_node.end_time, rel_tol=0, abs_tol=1)): tail_node.children.append(node) # node.parent_node = weakref.ref(tail_node) node_stack.append(node) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py index 4184e3830a6a00c735db22c279f8bd3a55e10747..0186139a59d43e54ba462e74063428c74033ea1c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -93,7 +93,6 @@ class RunGenerator(object): self.profile_data.gpu_metrics_parser.get_gpu_metrics_data_tooltip( gpu_infos, self.profile_data.tc_ratio) - profile_run.tid2tree = self.profile_data.tid2tree profile_run.pl_tid2tree = self.profile_data.pl_tid2tree profile_run.module_stats = aggegate_module_view(self.profile_data.tid2tree, self.profile_data.events) @@ -131,6 +130,7 @@ class RunGenerator(object): profile_run.step_to_overlap = self._npu_get_overlap() profile_run.step_to_wait, profile_run.comm_op = self._npu_get_wait_table() + profile_run.tid2tree = self.profile_data.tid2tree if self.profile_data.has_trace: profile_run.views.append(consts.TRACE_VIEW) profile_run.trace_file_path = self.profile_data.trace_file_path