class TestDatabaseUtils(unittest.TestCase):
    """Unit tests for the SQLite export helpers in ``msprobe.visualization.db_utils``."""

    def setUp(self):
        # Reserve a real file path; the handle is closed right away so that
        # sqlite3 can open the (empty) file on every platform.
        self.db_file = tempfile.NamedTemporaryFile(delete=False)
        self.db_file.close()
        self.db_name = self.db_file.name

    def tearDown(self):
        if os.path.exists(self.db_name):
            os.unlink(self.db_name)

    def _fetch_all(self, query):
        """Run *query* against the temporary database and return all rows."""
        conn = sqlite3.connect(self.db_name)
        try:
            return conn.execute(query).fetchall()
        finally:
            conn.close()

    def test_get_graph_unique_id(self):
        # The graph id is the plain concatenation of data source, step and rank.
        graph = MagicMock()
        graph.data_source = "source"
        graph.step = 1
        graph.rank = 0
        self.assertEqual(get_graph_unique_id(graph), "source10")

    def test_get_node_unique_id(self):
        # The node id is the graph id followed by the node's own id.
        graph = MagicMock()
        graph.data_source = "source"
        graph.step = 1
        graph.rank = 0

        node = MagicMock()
        node.id = "node1"
        self.assertEqual(get_node_unique_id(graph, node), "source10node1")

    # Patch the name where it is looked up (db_utils), not where it is defined;
    # otherwise the "no error logged" assertion can never fail.
    @patch('msprobe.visualization.db_utils.logger')
    def test_to_db_single_record(self, mock_logger):
        # Insert a single row.
        create_table_sql = "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
        insert_sql = "INSERT INTO test (id, name) VALUES (?, ?)"
        data = [(1, "test")]

        to_db(self.db_name, create_table_sql, insert_sql, data)

        # Verify that the row was written.
        self.assertEqual(self._fetch_all("SELECT * FROM test"), [(1, "test")])
        mock_logger.error.assert_not_called()

    @patch('msprobe.visualization.db_utils.logger')
    def test_to_db_multiple_records(self, mock_logger):
        # Insert several rows with a batch size smaller than the row count.
        create_table_sql = "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
        insert_sql = "INSERT INTO test (id, name) VALUES (?, ?)"
        data = [(1, "test1"), (2, "test2"), (3, "test3")]

        to_db(self.db_name, create_table_sql, insert_sql, data, db_insert_size=2)

        # Verify that every batch was written, including the final partial one.
        self.assertEqual(self._fetch_all("SELECT * FROM test ORDER BY id"), data)
        mock_logger.error.assert_not_called()

    @patch('msprobe.visualization.db_utils.logger')
    def test_to_db_sql_error(self, mock_logger):
        # A broken INSERT statement must be logged once and surface as RuntimeError.
        create_table_sql = "CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)"
        insert_sql = "INVALID SQL"  # deliberately invalid SQL
        data = [(1, "test")]

        with self.assertRaises(RuntimeError):
            to_db(self.db_name, create_table_sql, insert_sql, data)

        mock_logger.error.assert_called_once()

    @patch('msprobe.visualization.db_utils.to_db')
    @patch('msprobe.visualization.db_utils.get_graph_unique_id')
    @patch('msprobe.visualization.db_utils.get_node_unique_id')
    @patch('msprobe.visualization.db_utils.format_node_data')
    def test_node_to_db(self, mock_format, mock_node_id, mock_graph_id, mock_to_db):
        # Every node of the graph must become one row handed to to_db.
        graph = MagicMock()
        graph.data_source = "source"
        graph.step = 1
        graph.rank = 0
        graph.data_path = "/path"

        node1 = MagicMock()
        node1.id = "node1"
        node1.op.value = "OP1"
        node1.upnode = None
        node1.data = {GraphConst.JSON_INDEX_KEY: 1, GraphConst.OVERFLOW_LEVEL: "LOW"}
        node1.micro_step_id = 1
        node1.matched_node_link = ["link1"]
        node1.stack_info = ["stack1"]
        node1.matched_distributed = ["dist1"]
        node1.input_data = {"input": "data"}
        node1.output_data = {"output": "data"}

        node2 = MagicMock()
        node2.id = "node2"
        node2.op.value = "OP2"
        node2.upnode = node1
        node2.data = {}
        node2.micro_step_id = None
        node2.matched_node_link = []
        node2.stack_info = []
        node2.matched_distributed = []
        node2.input_data = {}
        node2.output_data = {}

        graph.node_map = {"node1": node1, "node2": node2}

        mock_graph_id.return_value = "graph_id"
        # Derive the id from the node itself so the expectation does not depend
        # on the order in which node_to_db happens to request the ids.
        mock_node_id.side_effect = lambda _graph, node: f"uid_{node.id}"
        # Pass the payload through unchanged; formatting is not under test here.
        mock_format.side_effect = lambda node_data: node_data

        node_to_db(graph, self.db_name)

        # Verify the arguments handed to to_db.
        expected_data = [
            (
                "uid_node1", "graph_id", 0, "node1", "OP1", "", 1, "LOW", 1,
                json.dumps(["link1"]), json.dumps(["stack1"]), json.dumps(["dist1"]), 0,
                json.dumps({"input": "data"}), json.dumps({"output": "data"}),
                "source", "/path", 1, 0
            ),
            (
                # node2's parent is node1, so its upnode column holds node1's id.
                "uid_node2", "graph_id", 1, "node2", "OP2", "uid_node1", None, None, 0,
                json.dumps([]), json.dumps([]), json.dumps([]), 0,
                json.dumps({}), json.dumps({}),
                "source", "/path", 1, 0
            )
        ]

        mock_to_db.assert_called_once()
        args, _ = mock_to_db.call_args
        self.assertEqual(args[0], self.db_name)
        self.assertIn("CREATE TABLE IF NOT EXISTS tb_nodes", args[1])
        self.assertIn("INSERT INTO tb_nodes", args[2])
        self.assertEqual(args[3], expected_data)
        # Input and output data of both nodes go through format_node_data.
        self.assertEqual(mock_format.call_count, 4)

    @patch('msprobe.visualization.db_utils.to_db')
    def test_config_to_db(self, mock_to_db):
        # Without a bench graph the configuration row is written in "build" mode.
        config = MagicMock()
        config.graph_b = None
        config.task = "task"
        config.tool_tip = "tooltip"
        config.micro_steps = 5
        config.overflow_check = 1
        config.node_colors = {"type1": "red"}

        config_to_db(config, self.db_name)

        # Verify the arguments handed to to_db.
        expected_data = [
            (
                "1", "build", "task", "tooltip", 5, 1,
                json.dumps({"type1": "red"})
            )
        ]

        mock_to_db.assert_called_once()
        args, _ = mock_to_db.call_args
        self.assertEqual(args[0], self.db_name)
        self.assertIn("CREATE TABLE IF NOT EXISTS tb_config", args[1])
        self.assertIn("INSERT OR IGNORE INTO tb_config", args[2])
        self.assertEqual(args[3], expected_data)

    @patch('msprobe.visualization.db_utils.to_db')
    def test_config_to_db_compare(self, mock_to_db):
        # With a bench graph present the configuration row is written in "compare" mode.
        config = MagicMock()
        config.graph_b = MagicMock()  # graph_b exists, so compare mode is expected
        config.task = "task"
        config.tool_tip = "tooltip"
        config.micro_steps = 5
        config.overflow_check = 1
        config.node_colors = {"type1": "red"}

        config_to_db(config, self.db_name)

        # Verify the arguments handed to to_db.
        expected_data = [
            (
                "1", "compare", "task", "tooltip", 5, 1,
                json.dumps({"type1": "red"})
            )
        ]

        mock_to_db.assert_called_once()
        args, _ = mock_to_db.call_args
        self.assertEqual(args[3], expected_data)
__init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None, task='', - overflow_check=False, compare_mode=None): + overflow_check=False, compare_mode=None, step=0, rank=0): self.graph_n = graph_n self.graph_b = graph_b self.tool_tip = tool_tip @@ -288,6 +302,8 @@ class GraphExportConfig: self.task = task self.overflow_check = overflow_check self.compare_mode = compare_mode + self.step = step + self.rank = rank @dataclass diff --git a/debug/accuracy_tools/msprobe/visualization/db_utils.py b/debug/accuracy_tools/msprobe/visualization/db_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..62a12494d8ac59eec7af072609a2291d9da9fcc0 --- /dev/null +++ b/debug/accuracy_tools/msprobe/visualization/db_utils.py @@ -0,0 +1,114 @@ +# Copyright (c) 2024-2025, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def to_db(db_name, create_table_sql, insert_sql, data, db_insert_size=1000):
    """
    Create a table if necessary and insert rows into a SQLite database in batches.

    Args:
        db_name: Database path handed to ``sqlite3.connect``.
        create_table_sql: DDL statement executed once before any row is inserted.
        insert_sql: Parameterized INSERT statement executed for every row.
        data: Sequence of parameter tuples, one per row. May be empty, in which
            case only the table is created.
        db_insert_size: Maximum number of rows per ``executemany`` call; every
            batch is committed on its own.

    Raises:
        ValueError: If ``db_insert_size`` is not positive.
        RuntimeError: If any sqlite3 operation fails. The sqlite3 error is logged
            and attached as ``__cause__``.
    """
    if db_insert_size <= 0:
        # A negative step would make range() empty and silently drop every row.
        raise ValueError(f"db_insert_size must be a positive integer, got {db_insert_size}")

    conn = None
    try:
        conn = sqlite3.connect(db_name)
        cursor = conn.cursor()
        cursor.execute(create_table_sql)
        for start in range(0, len(data), db_insert_size):
            cursor.executemany(insert_sql, data[start:start + db_insert_size])
            conn.commit()
    except sqlite3.Error as e:
        logger.error(f"An sqlite3 error occurred: {e}")
        raise RuntimeError(f"Failed to write data to database {db_name}") from e
    finally:
        # Release the connection on every path, including a failing CREATE TABLE.
        if conn is not None:
            conn.close()


def node_to_db(graph, db_name):
    """
    Store every node of ``graph`` as one row of the ``tb_nodes`` table.

    Args:
        graph: Graph object providing ``node_map``, ``data_source``, ``data_path``,
            ``step`` and ``rank``.
        db_name: Path of the SQLite database file.

    Raises:
        RuntimeError: Propagated from ``to_db`` when the database write fails.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS tb_nodes (
            id TEXT PRIMARY KEY,
            graph_id TEXT NOT NULL,
            node_order INTEGER NOT NULL,
            node_name TEXT NOT NULL,
            node_type TEXT NOT NULL,
            upnode TEXT NOT NULL,
            precision_index INTEGER,
            overflow_level TEXT,
            micro_step_id INTEGER NOT NULL,
            matched_node_link TEXT,
            stack_info TEXT,
            matched_distributed TEXT,
            modified INTEGER NOT NULL,
            input_data TEXT,
            output_data TEXT,
            data_source TEXT,
            dump_data_dir TEXT,
            step INTEGER NOT NULL,
            rank INTEGER NOT NULL
        );
    """
    insert_sql = """
        INSERT INTO tb_nodes (id, graph_id, node_order, node_name, node_type, upnode, precision_index,
        overflow_level, micro_step_id, matched_node_link, stack_info, matched_distributed, modified,
        input_data, output_data, data_source, dump_data_dir, step, rank)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """
    # The graph id is identical for every node, so compute it once.
    graph_id = get_graph_unique_id(graph)
    data = []
    for order, node in enumerate(graph.node_map.values()):
        node_id = get_node_unique_id(graph, node)
        # Nodes without a parent store an empty string because the column is NOT NULL.
        upnode_id = get_node_unique_id(graph, node.upnode) if node.upnode else ''
        # The column is NOT NULL, so a missing micro step is stored as 0.
        micro_step_id = node.micro_step_id if node.micro_step_id is not None else 0
        data.append((
            node_id, graph_id, order, node.id, node.op.value, upnode_id,
            node.data.get(GraphConst.JSON_INDEX_KEY), node.data.get(GraphConst.OVERFLOW_LEVEL),
            micro_step_id, json.dumps(node.matched_node_link),
            json.dumps(node.stack_info), json.dumps(node.matched_distributed),
            0,  # "modified" column: every freshly exported node is written as 0
            json.dumps(format_node_data(node.input_data)), json.dumps(format_node_data(node.output_data)),
            graph.data_source, graph.data_path, graph.step, graph.rank,
        ))
    to_db(db_name, create_table_sql, insert_sql, data)


def config_to_db(config, db_name):
    """
    Store the export configuration as a single row of the ``tb_config`` table.

    The row always uses id ``"1"`` together with ``INSERT OR IGNORE``, so a
    configuration row that already exists in the database is left untouched.

    Args:
        config: Export configuration providing ``graph_b``, ``task``, ``tool_tip``,
            ``micro_steps``, ``overflow_check`` and ``node_colors``.
        db_name: Path of the SQLite database file.

    Raises:
        RuntimeError: Propagated from ``to_db`` when the database write fails.
    """
    create_table_sql = """
        CREATE TABLE IF NOT EXISTS tb_config (
            id TEXT PRIMARY KEY,
            graph_type TEXT NOT NULL,
            task TEXT,
            tool_tip TEXT,
            micro_steps INTEGER,
            overflow_check INTEGER,
            node_colors TEXT NOT NULL
        );
    """
    insert_sql = """
        INSERT OR IGNORE INTO tb_config (id, graph_type, task, tool_tip, micro_steps, overflow_check,
        node_colors)
        VALUES (?, ?, ?, ?, ?, ?, ?)
    """
    # A second (bench) graph means the export is a comparison, otherwise a plain build.
    graph_type = "compare" if config.graph_b else "build"
    data = [("1", graph_type, config.task, config.tool_tip, config.micro_steps,
             config.overflow_check, json.dumps(config.node_colors))]
    to_db(db_name, create_table_sql, insert_sql, data)


def get_graph_unique_id(graph):
    """
    Return the graph identifier: data source, step and rank concatenated.

    NOTE(review): the parts are joined without a separator, so e.g. step=1/rank=10
    and step=11/rank=0 produce the same id. The format is kept unchanged because
    consumers of the database may rely on it; confirm whether a separator is needed.
    """
    return f'{graph.data_source}{graph.step}{graph.rank}'


def get_node_unique_id(graph, node):
    """Return the node identifier: the graph identifier followed by ``node.id``."""
    return f'{get_graph_unique_id(graph)}{node.id}'
export_config = GraphExportConfig(graphs[0], graphs[1], graph_comparator.ma.get_tool_tip(), NodeColors.get_node_colors(graph_comparator.ma.compare_mode), micro_steps, task, - args.overflow_check, graph_comparator.ma.compare_mode) + args.overflow_check, graph_comparator.ma.compare_mode, result.rank, result.step) try: GraphBuilder.to_json(output_path, export_config) + GraphBuilder.to_db(output_db_path, export_config) logger.info(f'Exporting compare graph result successfully, the result file is saved in {output_path}') return '' except RuntimeError as e: @@ -182,10 +185,14 @@ def _export_build_graph_result(args, result): if not output_file_name: output_file_name = f'build_{current_time}.vis' logger.info(f'Start exporting graph for {output_file_name}...') + output_db_name = f'build_{current_time}.vis.db' output_path = os.path.join(out_path, output_file_name) + output_db_path = os.path.join(out_path, output_db_name) + config = GraphExportConfig(graph, micro_steps=micro_steps, overflow_check=overflow_check, rank=result.rank, + step=result.step) try: - GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps, - overflow_check=overflow_check)) + GraphBuilder.to_json(output_path, config) + GraphBuilder.to_db(output_db_path, config) logger.info(f'Model graph exported successfully, the result file is saved in {output_path}') return None except RuntimeError as e: @@ -282,6 +289,9 @@ def _get_compare_graph_results(input_param, serializable_args, step, pool, err_c br), error_callback=err_call)) compare_graph_results = [task.get() for task in compare_graph_tasks] + if step is not None: + for result in compare_graph_results: + result.step = step return compare_graph_results @@ -323,6 +333,10 @@ def _build_graph_ranks(dump_ranks_path, args, step=None): error_callback=err_call)) build_graph_results = [task.get() for task in build_graph_tasks] + if step is not None: + for result in build_graph_results: + result.step = step + if args.parallel_params: 
class CompareGraphResult:
    """Value holder for one graph comparison: both graphs, the comparator and export metadata."""

    def __init__(self, graph_n, graph_b, graph_comparator, micro_steps, rank=0, step=0, output_file_name=''):
        # The two graphs and the comparator that relates them.
        self.graph_n, self.graph_b = graph_n, graph_b
        self.graph_comparator = graph_comparator
        # Export metadata; step defaults to 0 and can be reassigned after construction.
        self.micro_steps = micro_steps
        self.rank, self.step = rank, step
        self.output_file_name = output_file_name


class BuildGraphResult:
    """Value holder for one built graph together with its export metadata."""

    def __init__(self, graph, micro_steps=0, rank=0, step=0, output_file_name=''):
        self.graph = graph
        # Export metadata; step defaults to 0 and can be reassigned after construction.
        self.micro_steps = micro_steps
        self.rank, self.step = rank, step
        self.output_file_name = output_file_name