From 80c487204df2a82de91519c8d3d8068da1ca856e Mon Sep 17 00:00:00 2001
From: wuyulong11
Date: Mon, 27 Nov 2023 09:53:24 +0800
Subject: [PATCH] [Change info][tbplugin][Issue] Add adaptation for collecting user_annotation type operators in PyTorch GPU profiling [Modified by] wuyulong 30031080
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../profiler/communication.py           |  4 +++
 .../profiler/event_parser.py            | 32 ++++++++++---------
 .../torch_tb_profiler/profiler/node.py  |  7 +++-
 .../profiler/run_generator.py           |  4 +--
 .../torch_tb_profiler/profiler/trace.py | 19 ++++++++---
 5 files changed, 44 insertions(+), 22 deletions(-)

diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py
index 0894ea3966..00f8dc9813 100644
--- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py
+++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/communication.py
@@ -68,6 +68,10 @@ def analyze_communication_nodes(comm_node_list: List[CommunicationNode])\
                 bytes_one_value = 4
             elif comm_node.input_type[i] == 'c10::Half':
                 bytes_one_value = 2
+            elif comm_node.input_type[i] == 'c10::BFloat16':
+                bytes_one_value = 2
+            elif comm_node.input_type[i] == 'unsigned char':
+                bytes_one_value = 1
             else:
                 logger.warning('Found an unknown tensor type: {}'.format(comm_node.input_type[i]))
                 bytes_one_value = 0
diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py
index 30c1878cfe..061db7a4e0 100644
--- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py
+++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/event_parser.py
@@ -12,12 +12,10 @@ from .node import (CommunicationNode, DeviceNode, ModuleNode, OperatorNode, PLMo
                    ProfilerStepNode, RuntimeNode, create_operator_node)
 from .op_tree import OpTreeBuilder
 from .range_utils import merge_ranges
-from .trace import BaseEvent, DurationEvent, EventTypes, KernelEvent
+from .trace import BaseEvent, DurationEvent, EventTypes, KernelEvent, NcclOpNameSet, GlooOpNameSet
 
 logger = utils.get_logger()
 
-NcclOpNameSet = ['nccl:broadcast', 'nccl:reduce', 'nccl:all_reduce', 'nccl:all_gather', 'nccl:reduce_scatter']
-GlooOpNameSet = ['gloo:broadcast', 'gloo:reduce', 'gloo:all_reduce', 'gloo:all_gather', 'gloo:reduce_scatter']
 
 CommLibTypes = IntEnum('CommLibTypes', ['Nccl', 'Gloo'], start=0)
 
@@ -173,7 +171,8 @@ class NodeParserMixin:
                             EventTypes.OPERATOR,
                             EventTypes.PL_MODULE,
                             EventTypes.PROFILER_STEP,
-                            EventTypes.MODULE]:
+                            EventTypes.MODULE,
+                            EventTypes.USER_ANNOTATION]:
             if event.type == EventTypes.PROFILER_STEP:
                 op_node = ProfilerStepNode.create(event)
             elif event.type == EventTypes.MODULE:
@@ -188,16 +187,17 @@ class NodeParserMixin:
                     self.comm_lib.add(CommLibTypes.Nccl)
                 if event.name in GlooOpNameSet:
                     self.comm_lib.add(CommLibTypes.Gloo)
-                    ts = event.ts
-                    dur = event.duration
-                    comm_node.kernel_ranges.append((ts, ts + dur))
-                    comm_node.total_time = dur
+                ts = event.ts
+                dur = event.duration
+                comm_node.kernel_ranges.append((ts, ts + dur))
+                comm_node.total_time = dur
                 self.communication_data[op_node.external_id] = comm_node
             if event.name == 'DataParallel.forward':
                 self.use_dp = True
             if event.name == 'DistributedDataParallel.forward':
                 self.use_ddp = True
-            tid2list[int(tid)].append(op_node)
+            if op_node:
+                tid2list[int(tid)].append(op_node)
         elif event.type == EventTypes.PL_PROFILE:
             op_node = PLProfileNode.create(event)
             pl_tid2list[int(tid)].append(op_node)
@@ -260,6 +260,10 @@ class StepParser:
         return bool(self.role_ranges[ProfileRole.Memcpy] or self.role_ranges[ProfileRole.Memset])
 
     def _parse_step(self, event: DurationEvent, comm_nodes: Dict[int, CommunicationNode]):
+        def check_name(name: str):
+            return (name.startswith('enumerate(DataLoader)#') and name.endswith('.__next__')) or name.startswith(
+                'enumerate(DataPipe)#')
+
         ts = event.ts
         dur = event.duration
         evt_type = event.type
@@ -274,15 +278,13 @@ class StepParser:
             self.role_ranges[ProfileRole.Memset].append((ts, ts + dur))
         elif evt_type == EventTypes.RUNTIME:
             self.role_ranges[ProfileRole.Runtime].append((ts, ts + dur))
-        elif evt_type == EventTypes.OPERATOR:
-            if ((event.name.startswith('enumerate(DataLoader)#') and event.name.endswith('.__next__'))
-                    or event.name.startswith('enumerate(DataPipe)#')):
-                self.role_ranges[ProfileRole.DataLoader].append((ts, ts + dur))
+        elif evt_type in [EventTypes.OPERATOR, EventTypes.USER_ANNOTATION] and check_name(event.name):
+            self.role_ranges[ProfileRole.DataLoader].append((ts, ts + dur))
         elif event.type == EventTypes.PROFILER_STEP:
             self.steps.append((ts, ts + dur))
             self.steps_names.append(str(event.step))
-        elif evt_type in [EventTypes.PYTHON, EventTypes.OPERATOR]:
-            if event.name in GlooOpNameSet:
+        elif evt_type in [EventTypes.PYTHON, EventTypes.OPERATOR, EventTypes.USER_ANNOTATION]:
+            if event.name in GlooOpNameSet or event.name in NcclOpNameSet:
                 self.role_ranges[ProfileRole.Communication].append((ts, ts + dur))
             else:
                 self.role_ranges[ProfileRole.CpuOp].append((ts, ts + dur))
diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py
index 824b809497..80860e5366 100644
--- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py
+++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/node.py
@@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
 from .. import utils
 from .tensor_core import TC_Allowlist, TC_OP_Allowlist
 from .trace import (DurationEvent, EventTypes, KernelEvent, ModuleEvent,
-                    OperatorEvent, PLProfileEvent)
+                    OperatorEvent, PLProfileEvent, NcclOpNameSet, GlooOpNameSet)
 
 logger = utils.get_logger()
 
@@ -296,6 +296,11 @@ def create_operator_node(event: OperatorEvent):
         return DataLoaderNode.create(event)
     elif event.name.startswith('Optimizer.step'):
         return OptimizerNode.create(event)
+    elif event.type == EventTypes.USER_ANNOTATION:
+        if event.name in GlooOpNameSet or event.name in NcclOpNameSet:
+            return OperatorNode.create(event)
+        else:
+            return None
     else:
         return OperatorNode.create(event)
 
diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py
index 0f00b40a38..f13948aa37 100644
--- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py
+++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py
@@ -1166,7 +1166,7 @@ class DistributedRunGenerator(object):
             else:
                 DistributedRunGenerator._get_npu_overlap_data(data, steps_to_overlap)
 
-            steps_to_overlap['all'][data.worker] = [x / step_number for x in steps_to_overlap['all'][data.worker]]
+            steps_to_overlap['all'][data.worker] = [int(x / step_number) for x in steps_to_overlap['all'][data.worker]]
         for k, v in steps_to_overlap.items():
             steps_to_overlap[k] = OrderedDict(sorted(v.items()))
         result['data'] = steps_to_overlap
@@ -1226,7 +1226,7 @@ class DistributedRunGenerator(object):
                 ]
                 steps_to_wait['all'][data.worker] = [
                     sum(x) for x in zip(steps_to_wait['all'][data.worker], steps_to_wait[step][data.worker])]
-            steps_to_wait['all'][data.worker] = [x / step_number for x in steps_to_wait['all'][data.worker]]
+            steps_to_wait['all'][data.worker] = [int(x / step_number) for x in steps_to_wait['all'][data.worker]]
 
     def _generate_wait_graph(self):
         result = dict()
diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py
index 3657fb11fb..8ce3dc63b5 100644
--- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py
+++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py
@@ -10,6 +10,9 @@ __all__ = ['EventTypes', 'create_event']
 
 logger = utils.get_logger()
 
+NcclOpNameSet = ['nccl:broadcast', 'nccl:reduce', 'nccl:all_reduce', 'nccl:all_gather', 'nccl:reduce_scatter']
+GlooOpNameSet = ['gloo:broadcast', 'gloo:reduce', 'gloo:all_reduce', 'gloo:all_gather', 'gloo:reduce_scatter']
+
 
 class DeviceType(IntEnum):
     CPU = 0
@@ -30,6 +33,7 @@ class EventTypes(object):
     MODULE = 'Module'
     PL_PROFILE = 'pl_profile'
     PL_MODULE = 'pl_module'
+    USER_ANNOTATION = 'user_annotation'
 
 
 EventTypeMap = {
@@ -44,7 +48,9 @@ EventTypeMap = {
     'gpu_memset': EventTypes.MEMSET,
     'python': EventTypes.PYTHON,
     'memory': EventTypes.MEMORY,
-    'python_function': EventTypes.PYTHON_FUNCTION
+    'python_function': EventTypes.PYTHON_FUNCTION,
+    'user_annotation': EventTypes.USER_ANNOTATION,
+    'gpu_user_annotation': EventTypes.USER_ANNOTATION
 }
 
 
@@ -179,7 +185,13 @@ def create_event(event, is_pytorch_lightning) -> Optional[BaseEvent]:
 def create_trace_event(event, is_pytorch_lightning) -> Optional[BaseEvent]:
     category = event.get('cat')
     event_type = EventTypeMap.get(category.lower()) if category else None
-    if event_type == EventTypes.OPERATOR:
+    if event_type == EventTypes.USER_ANNOTATION:
+        name = event.get('name')
+        if name and name.startswith('ProfilerStep#'):
+            return ProfilerStepEvent(event)
+        if name in GlooOpNameSet or name in NcclOpNameSet:
+            return OperatorEvent(event_type, event)
+    elif event_type == EventTypes.OPERATOR:
         name = event.get('name')
         if name and name.startswith('ProfilerStep#'):
             return ProfilerStepEvent(event)
@@ -203,8 +215,7 @@ def create_trace_event(event, is_pytorch_lightning) -> Optional[BaseEvent]:
             return PythonFunctionEvent(event_type, event)
     elif event_type is not None:
         return DurationEvent(event_type, event)
-    else:
-        return None
+    return None
 
 
 def create_association_events(events) -> Dict[int, int]:
-- 
Gitee
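
For a quick sanity check of the new user_annotation handling outside the plugin, here is a minimal sketch that mirrors (rather than imports) the updated EventTypeMap entries and name-set checks from trace.py; the classify helper and the sample event dicts are hypothetical, and the real create_trace_event returns ProfilerStepEvent / OperatorEvent objects rather than strings.

# Standalone sketch (not part of the patch): mirrors the user_annotation branch
# added to create_trace_event. Assumes Chrome-trace-style event dicts with
# 'cat' and 'name' keys, which is the shape the plugin already parses.
NcclOpNameSet = ['nccl:broadcast', 'nccl:reduce', 'nccl:all_reduce', 'nccl:all_gather', 'nccl:reduce_scatter']
GlooOpNameSet = ['gloo:broadcast', 'gloo:reduce', 'gloo:all_reduce', 'gloo:all_gather', 'gloo:reduce_scatter']

# Both CPU-side and GPU-side annotations map to the same USER_ANNOTATION type.
EVENT_TYPE_MAP = {'user_annotation': 'user_annotation', 'gpu_user_annotation': 'user_annotation'}


def classify(event: dict) -> str:
    """Label how the patched create_trace_event would treat this event (hypothetical helper)."""
    event_type = EVENT_TYPE_MAP.get(event.get('cat', '').lower())
    if event_type != 'user_annotation':
        return 'not a user annotation'
    name = event.get('name', '')
    if name.startswith('ProfilerStep#'):
        return 'ProfilerStepEvent'
    if name in NcclOpNameSet or name in GlooOpNameSet:
        return 'OperatorEvent (communication op)'
    return 'dropped: create_trace_event falls through to return None'


if __name__ == '__main__':
    print(classify({'cat': 'gpu_user_annotation', 'name': 'nccl:all_reduce'}))  # OperatorEvent (communication op)
    print(classify({'cat': 'user_annotation', 'name': 'ProfilerStep#3'}))       # ProfilerStepEvent
    print(classify({'cat': 'user_annotation', 'name': 'my_marker'}))            # dropped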