From c50dfeedf832b359738ed2a98cb5d413b7c7cff8 Mon Sep 17 00:00:00 2001 From: wuyulong11 Date: Fri, 9 Jun 2023 17:00:49 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90=E4=BF=AE=E6=94=B9=E8=AF=B4?= =?UTF-8?q?=E6=98=8E=E3=80=91=20Memory=20View=E7=95=8C=E9=9D=A2=E4=BF=AE?= =?UTF-8?q?=E6=94=B9=20=E3=80=90=E4=BF=AE=E6=94=B9=E4=BA=BA=E3=80=91=20wuy?= =?UTF-8?q?ulong=2030031080?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../tb_plugin/torch_tb_profiler/plugin.py | 20 ++++++++++++---- .../profiler/run_generator.py | 23 +++++++++++-------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py index 592c395c9bc..0150fcdcf1e 100644 --- a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py +++ b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/plugin.py @@ -206,7 +206,7 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): group_by = request.args.get('group_by') input_shape = request.args.get('input_shape') if group_by == 'OperationAndInputShape': - return self.respond_as_json(profile.operation_stack_by_name_input[str(op_name)+'###'+str(input_shape)]) + return self.respond_as_json(profile.operation_stack_by_name_input[str(op_name) + '###' + str(input_shape)]) else: return self.respond_as_json(profile.operation_stack_by_name[str(op_name)]) @@ -293,7 +293,7 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): end_ts = int(end_ts) return self.respond_as_json( - profile.get_memory_stats(start_ts=start_ts, end_ts=end_ts, memory_metric=memory_metric), True) + profile.get_memory_stats(start_ts=start_ts, end_ts=end_ts, memory_metric=memory_metric), True) @wrappers.Request.application def memory_curve_route(self, request: werkzeug.Request): @@ -315,10 +315,20 @@ class TorchProfilerPlugin(base_plugin.TBPlugin): memory_metric = request.args.get('memory_metric', 'KB') if profile.device_target == 'Ascend': operator_memory_events = profile.memory_events['operator']['rows'] - start_ts = int(start_ts) if start_ts is not None else 0 - end_ts = int(end_ts) if end_ts is not None else float('inf') + if start_ts is not None: + start_ts = int(start_ts) + if end_ts is not None: + end_ts = int(end_ts) for key in operator_memory_events: - operator_memory_events[key] = [i for i in operator_memory_events[key] if start_ts <= i[2] <= end_ts] + if start_ts is not None and end_ts is not None: + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and start_ts <= i[2] <= end_ts] + elif start_ts is not None: + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and start_ts <= i[2]] + elif end_ts is not None: + operator_memory_events[key] = [i for i in operator_memory_events[key] if + i[2] and end_ts >= i[2]] return self.respond_as_json(profile.memory_events, True) else: if start_ts is not None: diff --git a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py index 7fd44d6e6e8..846c6db4294 100644 --- a/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/tb_plugins/profiling/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -119,8 +119,12 @@ class RunGenerator(object): if len(datas) <= 1: return operator_by_name, operator_by_name_and_input_shapes for ls in datas[1:]: - temp: list = [ls[0], RunGenerator._trans_shape(str(ls[1])), ls[2], float(ls[3]), float(ls[4]), - float(ls[5]), float(ls[6]), float(ls[7]), float(ls[8])] + try: + temp: list = [ls[0], RunGenerator._trans_shape(str(ls[1])), ls[2], float(ls[3]), float(ls[4]), + float(ls[5]), float(ls[6]), float(ls[7]), float(ls[8])] + except (ValueError, IndexError): + logger.error('Data in file "operator_details.csv" has wrong format.') + return operator_by_name, operator_by_name_and_input_shapes operator_by_name[ls[0]].append(temp) key = "{}###{}".format(str(ls[0]), RunGenerator._trans_shape(str(ls[1]))) operator_by_name_and_input_shapes[key].append(temp) @@ -279,7 +283,7 @@ class RunGenerator(object): return temp def _get_memory_event(self, peak_memory_events: dict): - display_columns = ('Operator', 'Size(KB)', 'Allocation Time(us)', 'Release Time(us)', 'Duration(us)') + display_columns = ('Name', 'Size(KB)', 'Allocation Time(us)', 'Release Time(us)', 'Duration(us)') path = self.profile_data.memory_operator_path display_datas = defaultdict(list) devices_type = [] @@ -296,7 +300,7 @@ class RunGenerator(object): if column == 'Device Type': self.device_type_form_idx = idx if column in display_columns: - if column == 'Operator': + if column == 'Name': table['columns'].append({'name': column, 'type': 'string'}) elif column == 'Size(KB)': table['columns'].append({'name': column, 'type': 'number'}) @@ -306,12 +310,11 @@ class RunGenerator(object): for ls in datas[1:]: device_type = ls[self.device_type_form_idx] # convert time metric 'us' to 'ms' - nums = [ls[0], float(ls[1]), round((float(ls[2]) - self.profile_data.start_ts) / 1000, 3)] - # some operators may not have column[3] or column[4] - if ls[3]: - nums.append(round((float(ls[3]) - self.profile_data.start_ts) / 1000, 3)) - if ls[4]: - nums.append(round(float(ls[4]) / 1000, 2)) + # some operators may not have the following columns + nums = [ls[0] if ls[0] else '', abs(float(ls[1])), + round((float(ls[2]) - self.profile_data.start_ts) / 1000, 2) if ls[2] else None, + round((float(ls[3]) - self.profile_data.start_ts) / 1000, 2) if ls[3] else None, + round(float(ls[4]) / 1000, 2) if ls[4] else None] display_datas[device_type].append(nums) table['rows'] = display_datas for name in display_datas: -- Gitee From 5cb9679765bdf3c4cd8a670dfc9992994d10abcf Mon Sep 17 00:00:00 2001 From: wuyulong11 Date: Fri, 9 Jun 2023 18:08:08 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E3=80=90=E4=BF=AE=E6=94=B9=E8=AF=B4?= =?UTF-8?q?=E6=98=8E=E3=80=91=20Memory=20View=20Name=E6=90=9C=E7=B4=A2?= =?UTF-8?q?=E6=A1=86=E5=8A=9F=E8=83=BD=E4=BF=AE=E6=94=B9=20=E3=80=90?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=BA=E3=80=91=20wuyulong=2030031080?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../profiling/tb_plugin/fe/src/components/MemoryView.tsx | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx b/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx index e3c7599f981..060a18d1679 100644 --- a/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx +++ b/tb_plugins/profiling/tb_plugin/fe/src/components/MemoryView.tsx @@ -176,6 +176,7 @@ export const MemoryView: React.FC = React.memo((props) => { const getName = React.useCallback((row: any) => row[searchIndex], [ searchIndex ]) + const getNameAscend = (row: any) => row[0] const [searchedTableDataRows] = useSearchDirectly( searchOperatorName, getName, @@ -183,7 +184,7 @@ export const MemoryView: React.FC = React.memo((props) => { ) const [searchedEventsTableDataRows] = useSearchDirectly( searchEventOperatorName, - getName, + deviceTarget === 'Ascend' ? getNameAscend : getName, filterByEventSize( memoryEventsData?.rows[device], filterEventSize[device] ?? [0, Infinity] -- Gitee