From c18a3fc2fbc82f3b15b9a183d1f572c247ad75c1 Mon Sep 17 00:00:00 2001 From: wuyulong11 Date: Thu, 9 Nov 2023 21:30:20 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BF=AE=E6=94=B9=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E3=80=91=20=E3=80=90tbplugin=E3=80=91=E3=80=90issue=20#I8F38Y?= =?UTF-8?q?=E3=80=91=E7=AE=97=E5=AD=90=E7=B1=BB=E5=9E=8B=E8=BD=AC=E4=B8=BA?= =?UTF-8?q?=E5=B0=8F=E5=86=99=E5=88=A4=E6=96=AD=20=E3=80=90=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=E4=BA=BA=E3=80=91=20wuyulong=2030031080?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../torch_tb_profiler/profiler/trace.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py index 32c5e1ad3e..3657fb11fb 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/trace.py @@ -33,17 +33,17 @@ class EventTypes(object): EventTypeMap = { - 'Trace': EventTypes.TRACE, + 'trace': EventTypes.TRACE, 'cpu_op': EventTypes.OPERATOR, - 'Operator': EventTypes.OPERATOR, - 'Runtime': EventTypes.RUNTIME, - 'Kernel': EventTypes.KERNEL, - 'Memcpy': EventTypes.MEMCPY, + 'operator': EventTypes.OPERATOR, + 'runtime': EventTypes.RUNTIME, + 'kernel': EventTypes.KERNEL, + 'memcpy': EventTypes.MEMCPY, 'gpu_memcpy': EventTypes.MEMCPY, - 'Memset': EventTypes.MEMSET, + 'memset': EventTypes.MEMSET, 'gpu_memset': EventTypes.MEMSET, - 'Python': EventTypes.PYTHON, - 'Memory': EventTypes.MEMORY, + 'python': EventTypes.PYTHON, + 'memory': EventTypes.MEMORY, 'python_function': EventTypes.PYTHON_FUNCTION } @@ -178,7 +178,7 @@ def create_event(event, is_pytorch_lightning) -> Optional[BaseEvent]: def create_trace_event(event, is_pytorch_lightning) -> Optional[BaseEvent]: category = event.get('cat') - event_type = EventTypeMap.get(category) + event_type = EventTypeMap.get(category.lower()) if category else None if event_type == EventTypes.OPERATOR: name = event.get('name') if name and name.startswith('ProfilerStep#'): -- Gitee