diff --git a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py index ed9b099ccba6725d3bbb8aab0685ecb2a535aad9..4bdfef709cd220cf6e487f370309992e58689756 100644 --- a/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py +++ b/plugins/tensorboard-plugins/tb_plugin/torch_tb_profiler/profiler/run_generator.py @@ -151,6 +151,8 @@ class RunGenerator(object): for step in data[1:]: key = step[0] + if key == '': + key = 'all' overlap = [float(step[int(title_name[0])]), float(step[int(title_name[1])]), float(step[int(title_name[2])]), float(step[int(title_name[3])])] if key in overlap_by_steps: @@ -165,7 +167,7 @@ class RunGenerator(object): length = len(title) if length < 5: return - key = ["compute time", "overlapped time", "communication time not overlapped", "free time"] + key = ["computing", "overlapped", "communication(not overlapped)", "free"] get_key = list() for j in key: for i in range(length): @@ -197,14 +199,17 @@ class RunGenerator(object): table_ops: Dict[str, List[float]] = OrderedDict() if len(communication_json) <= 0: return wait_by_step, table_ops - for data in communication_json: - step = data.get("step_id") + for step in communication_json: + step_id = re.sub(r'step', '', step) + if step_id == '': + step_id = 'all' + data = communication_json.get(step) collection_ops = data.get("collective") p2p_ops = data.get("p2p") coll_total_trans, coll_total_synchronize = RunGenerator._get_wait_table_by_ops(collection_ops, table_ops) p2p_total_trans, p2p_total_synchronize = RunGenerator._get_wait_table_by_ops(p2p_ops, table_ops) - wait_by_step[step] = { + wait_by_step[step_id] = { "trans": coll_total_trans + p2p_total_trans, "Synchronize": coll_total_synchronize + p2p_total_synchronize }