diff --git a/debug/accuracy_tools/msprobe/visualization/db_utils.py b/debug/accuracy_tools/msprobe/visualization/db_utils.py index 6d9e65d8fcf6fce77a52ef525b87871c44217b01..3988ebecc88863b7e8d65724719cf132971f116e 100644 --- a/debug/accuracy_tools/msprobe/visualization/db_utils.py +++ b/debug/accuracy_tools/msprobe/visualization/db_utils.py @@ -66,7 +66,12 @@ config_columns = { } indexes = { - "index1": ["step", "rank", "data_source", "up_node", "node_name"] + "index1": ["step", "rank", "data_source", "up_node", "node_order"], + "index2": ["step", "rank", "data_source", "node_name"], + "index3": ["step", "rank", "data_source", "node_order"], + "index4": ["step", "rank", "node_order"], + "index5": ["step", "rank", "micro_step_id", "node_order"], + "index6": ["step", "rank", "modified", "matched_node_link"] } SAFE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$') @@ -201,7 +206,8 @@ def node_to_db(graph, db_name): json.dumps(node.stack_info), json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '', json.dumps(node.matched_distributed), 0, - json.dumps(format_node_data(node.input_data)), json.dumps(format_node_data(node.output_data)), + json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)), + json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)), graph.data_source, graph.data_path, graph.step, graph.rank)) to_db(db_name, create_table_sql, insert_sql, data) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/visualization/graph/graph.py index dac99514628cfe9acbf44587de526be123c3cc18..4bec2a65d481f9eddc38a4cf7978a2599285fd63 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/graph.py @@ -194,6 +194,15 @@ class Graph: graph_other: 可选参数,另一个graph Returns: 分批的数量 """ + + @recursion_depth_decorator( + 'msprobe.visualization.graph.graph.Graph.paging_by_micro_step.propagate_micro_step_id', max_depth=500) + def propagate_micro_step_id(node): + if node.upnode is not None and node.micro_step_id is None: + node.micro_step_id = node.upnode.micro_step_id + for sub_node in node.subnodes: + propagate_micro_step_id(sub_node) + batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes) for batch_number, nodes in batches_n.items(): for node in nodes: @@ -203,6 +212,7 @@ class Graph: node_other = graph_other.get_node(node.matched_node_link[-1]) if node_other: node_other.micro_step_id = batch_number + propagate_micro_step_id(self.root) # 遍历graph_other根节点下的所有子节点,确保未匹配节点也有micro_step_id if graph_other: for node in graph_other.root.subnodes: @@ -212,6 +222,7 @@ class Graph: except ValueError: micro_step_id = 0 node.micro_step_id = micro_step_id + propagate_micro_step_id(graph_other.root) return len(batches_n) def overflow_check(self):