From a617adb88e63020fb23322308fa7888cff5d6821 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Fri, 22 Aug 2025 14:57:32 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90bugfix=E3=80=91=E5=88=86=E7=BA=A7?= =?UTF-8?q?=E5=8F=AF=E8=A7=86=E5=8C=96=E8=BD=ACdb=E8=81=94=E8=B0=83?= =?UTF-8?q?=E8=8B=A5=E5=B9=B2=E9=97=AE=E9=A2=98=E5=A4=84=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/msprobe/visualization/db_utils.py | 10 ++++++++-- .../msprobe/visualization/graph/graph.py | 11 +++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/visualization/db_utils.py b/debug/accuracy_tools/msprobe/visualization/db_utils.py index 6d9e65d8fc..3988ebecc8 100644 --- a/debug/accuracy_tools/msprobe/visualization/db_utils.py +++ b/debug/accuracy_tools/msprobe/visualization/db_utils.py @@ -66,7 +66,12 @@ config_columns = { } indexes = { - "index1": ["step", "rank", "data_source", "up_node", "node_name"] + "index1": ["step", "rank", "data_source", "up_node", "node_order"], + "index2": ["step", "rank", "data_source", "node_name"], + "index3": ["step", "rank", "data_source", "node_order"], + "index4": ["step", "rank", "node_order"], + "index5": ["step", "rank", "micro_step_id", "node_order"], + "index6": ["step", "rank", "modified", "matched_node_link"] } SAFE_NAME_PATTERN = re.compile(r'^[a-zA-Z0-9_]+$') @@ -201,7 +206,8 @@ def node_to_db(graph, db_name): json.dumps(node.stack_info), json.dumps(node.parallel_merge_info) if node.parallel_merge_info else '', json.dumps(node.matched_distributed), 0, - json.dumps(format_node_data(node.input_data)), json.dumps(format_node_data(node.output_data)), + json.dumps(format_node_data(node.input_data, node.id, graph.compare_mode)), + json.dumps(format_node_data(node.output_data, node.id, graph.compare_mode)), graph.data_source, graph.data_path, graph.step, graph.rank)) to_db(db_name, create_table_sql, insert_sql, data) diff --git a/debug/accuracy_tools/msprobe/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/visualization/graph/graph.py index dac9951462..4bec2a65d4 100644 --- a/debug/accuracy_tools/msprobe/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/visualization/graph/graph.py @@ -194,6 +194,15 @@ class Graph: graph_other: 可选参数,另一个graph Returns: 分批的数量 """ + + @recursion_depth_decorator( + 'msprobe.visualization.graph.graph.Graph.paging_by_micro_step.propagate_micro_step_id', max_depth=500) + def propagate_micro_step_id(node): + if node.upnode is not None and node.micro_step_id is None: + node.micro_step_id = node.upnode.micro_step_id + for sub_node in node.subnodes: + propagate_micro_step_id(sub_node) + batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes) for batch_number, nodes in batches_n.items(): for node in nodes: @@ -203,6 +212,7 @@ class Graph: node_other = graph_other.get_node(node.matched_node_link[-1]) if node_other: node_other.micro_step_id = batch_number + propagate_micro_step_id(self.root) # 遍历graph_other根节点下的所有子节点,确保未匹配节点也有micro_step_id if graph_other: for node in graph_other.root.subnodes: @@ -212,6 +222,7 @@ class Graph: except ValueError: micro_step_id = 0 node.micro_step_id = micro_step_id + propagate_micro_step_id(graph_other.root) return len(batches_n) def overflow_check(self): -- Gitee