From bf9be5bb954a7296bca4568e7b6926de3a70143a Mon Sep 17 00:00:00 2001 From: cabbage Date: Tue, 15 Aug 2023 11:39:17 +0800 Subject: [PATCH] =?UTF-8?q?debug=E4=BF=AE=E5=A4=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../torch_tb_profiler/profiler/run_generator.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py index ed9b099ccb..4bdfef709c 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -151,6 +151,8 @@ class RunGenerator(object): for step in data[1:]: key = step[0] + if key == '': + key = 'all' overlap = [float(step[int(title_name[0])]), float(step[int(title_name[1])]), float(step[int(title_name[2])]), float(step[int(title_name[3])])] if key in overlap_by_steps: @@ -165,7 +167,7 @@ class RunGenerator(object): length = len(title) if length < 5: return - key = ["compute time", "overlapped time", "communication time not overlapped", "free time"] + key = ["computing", "overlapped", "communication(not overlapped)", "free"] get_key = list() for j in key: for i in range(length): @@ -197,14 +199,17 @@ class RunGenerator(object): table_ops: Dict[str, List[float]] = OrderedDict() if len(communication_json) <= 0: return wait_by_step, table_ops - for data in communication_json: - step = data.get("step_id") + for step in communication_json: + step_id = re.sub(r'step', '', step) + if step_id == '': + step_id = 'all' + data = communication_json.get(step) collection_ops = data.get("collective") p2p_ops = data.get("p2p") coll_total_trans, coll_total_synchronize = RunGenerator._get_wait_table_by_ops(collection_ops, table_ops) p2p_total_trans, p2p_total_synchronize = RunGenerator._get_wait_table_by_ops(p2p_ops, table_ops) - wait_by_step[step] = { + wait_by_step[step_id] = { "trans": coll_total_trans + p2p_total_trans, "Synchronize": coll_total_synchronize + p2p_total_synchronize } -- Gitee