From b916553562203ecaceb5850ff906e2c7aeb9862d Mon Sep 17 00:00:00 2001 From: wuyulong11 Date: Tue, 6 Feb 2024 10:20:44 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90=E4=BF=AE=E6=94=B9=E4=BF=A1=E6=81=AF?= =?UTF-8?q?=E3=80=91=E3=80=90tbplugin=E3=80=91Memory=E8=A7=86=E5=9B=BE?= =?UTF-8?q?=E9=80=82=E9=85=8D=E6=96=B0=E7=9A=84=E9=87=87=E9=9B=86=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=20=E3=80=90=E4=BF=AE=E6=94=B9=E4=BA=BA=E3=80=91=20wuy?= =?UTF-8?q?ulong=2030031080?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../torch_tb_profiler/profiler/run_generator.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py index df983cd7ce..f2ab0452ec 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -442,15 +442,22 @@ class RunGenerator(object): else: # Convert time metric table['columns'].append({'name': column.replace('(us)', '(ms)'), 'type': 'number'}) + required_column_idxs = {key: -1 for key in display_columns} + (name_idx, size_idx, allocation_idx, release_idx, duration_idx), column_exist_count = \ + RunGenerator._check_csv_columns(datas[0], required_column_idxs) + if column_exist_count < len(required_column_idxs): + logger.error('Required column is missing in file "operator_memory.csv"') for idx, ls in enumerate(datas[1:]): device_type = ls[self.device_type_form_idx] # convert time metric 'us' to 'ms' # some operators may not have the following columns try: - nums = [ls[0] if ls[0] else '', abs(float(ls[1])), - round((float(ls[2]) - self.profile_data.profiler_start_ts) / 1000, 3) if ls[2] else None, - round((float(ls[3]) - self.profile_data.profiler_start_ts) / 1000, 3) if ls[3] else None, - round(float(ls[4]) / 1000, 3) if ls[4] else None] + nums = [ls[name_idx] if ls[name_idx] else '', abs(float(ls[size_idx])), + round((float(ls[allocation_idx]) - self.profile_data.profiler_start_ts) / 1000, 3) if ls[ + allocation_idx] else None, + round((float(ls[release_idx]) - self.profile_data.profiler_start_ts) / 1000, 3) if ls[ + release_idx] else None, + round(float(ls[duration_idx]) / 1000, 3) if ls[duration_idx] else None] display_datas[device_type].append(nums) except ValueError: logger.error(f'File "{path}" has wrong data format in row {idx + 2} and will skip it.') -- Gitee