From 9569265ea2bf7373130136bb787351db0c1cdc40 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Wed, 25 Sep 2024 17:55:37 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E7=BA=A7=E5=8F=AF=E8=A7=86=E5=8C=96ut?= =?UTF-8?q?=E8=A1=A5=E5=85=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../builder/test_graph_builder.py | 49 +++++++- .../builder/test_msprobe_adapter.py | 37 +++++- .../compare/test_graph_comparator.py | 113 +++++++++++++++++- .../compare/test_mode_adapter.py | 42 ++++++- .../visualization/graph/test_base_node.py | 14 ++- .../visualization/graph/test_graph.py | 29 +++++ .../pytorch_ut/visualization/mapping.yaml | 2 + .../visualization/test_mapping_config.py | 52 ++++++++ .../pytorch_ut/visualization/test_utils.py | 27 +++++ 9 files changed, 358 insertions(+), 7 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml create mode 100644 debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py create mode 100644 debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py index 66eceea4b2..9433bc1363 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py @@ -1,6 +1,6 @@ import unittest from unittest.mock import MagicMock, patch -from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder, Graph +from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder, Graph, BaseNode, NodeOp class TestGraphBuilder(unittest.TestCase): @@ -50,3 +50,50 @@ class TestGraphBuilder(unittest.TestCase): self.assertIn("node1", self.graph.node_map) self.assertEqual(node.input_data, {}) self.assertEqual(node.output_data, {}) + + def test__handle_backward_upnode_missing(self): + construct_dict = {'Module.module.a.forward.0': 'Module.root.forward.0', 'Module.module.a.backward.0': None, + 'Module.root.forward.0': None, 'Module.root.backward.0': None, + 'Module.module.b.forward.0': 'Module.root.forward.0', + 'Module.module.b.backward.0': 'Module.root.backward.0', 'Module.module.c.backward.0': None} + node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward.0', None) + self.assertEqual(node_id_a, 'Module.root.backward.0') + node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.b.backward.0', + 'Module.root.backward.0') + self.assertEqual(node_id_b, 'Module.root.backward.0') + node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward.0', None) + self.assertIsNone(node_id_c) + + def test__collect_apis_between_modules_only_apis(self): + graph = Graph('TestNet') + graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0')] + GraphBuilder._collect_apis_between_modules(graph) + self.assertEqual(len(graph.root.subnodes), 1) + self.assertEqual(graph.root.subnodes[0].op, NodeOp.api_collection) + self.assertEqual(len(graph.root.subnodes[0].subnodes), 2) + self.assertEqual(graph.root.subnodes[0].id, 'Apis_Between_Modules.0') + + def test__collect_apis_between_modules_mixed_nodes(self): + graph = Graph('TestNet') + graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.module, 'Module.a.0'), + BaseNode(NodeOp.module, 'Module.b.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0'), + BaseNode(NodeOp.function_api, 'Tensor.c.0'), BaseNode(NodeOp.module, 'Module.a.1')] + GraphBuilder._collect_apis_between_modules(graph) + self.assertEqual(len(graph.root.subnodes), 5) + self.assertEqual(graph.root.subnodes[0].op, NodeOp.function_api) + self.assertEqual(graph.root.subnodes[1].op, NodeOp.module) + self.assertEqual(graph.root.subnodes[3].op, NodeOp.api_collection) + self.assertEqual(len(graph.root.subnodes[3].subnodes), 2) + self.assertEqual(graph.root.subnodes[3].id, 'Apis_Between_Modules.0') + + def test__collect_apis_between_modules_only_modules(self): + graph = Graph('TestNet') + graph.root.subnodes = [BaseNode(NodeOp.module, 'Module.a.0'), BaseNode(NodeOp.module, 'Module.b.0'), + BaseNode(NodeOp.module, 'Module.a.1')] + GraphBuilder._collect_apis_between_modules(graph) + self.assertEqual(len(graph.root.subnodes), 3) + self.assertEqual(graph.root.subnodes[0].op, NodeOp.module) + self.assertEqual(graph.root.subnodes[1].op, NodeOp.module) + self.assertEqual(graph.root.subnodes[2].op, NodeOp.module) + self.assertEqual(len(graph.root.subnodes[0].subnodes), 0) + self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py index 12ae24279f..f023128b8c 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_msprobe_adapter.py @@ -8,7 +8,8 @@ from msprobe.pytorch.visualization.builder.msprobe_adapter import ( format_node_data, compare_node, _format_decimal_string, - _format_data + _format_data, + compare_mapping_data ) from msprobe.pytorch.visualization.utils import GraphConst @@ -44,7 +45,11 @@ class TestMsprobeAdapter(unittest.TestCase): def test_compare_data(self): data_dict_list1 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}} data_dict_list2 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}} + data_dict_list3 = {'key1': {'type': 'Type2', 'dtype': 'DType1', 'shape': 'Shape1'}} + data_dict_list4 = {} self.assertTrue(compare_data(data_dict_list1, data_dict_list2)) + self.assertFalse(compare_data(data_dict_list1, data_dict_list3)) + self.assertFalse(compare_data(data_dict_list1, data_dict_list4)) def test_format_node_data(self): data_dict = {'node1': {'data_name': 'data1', 'full_op_name': 'op1'}} @@ -66,8 +71,34 @@ class TestMsprobeAdapter(unittest.TestCase): s = "0.123456789%" formatted_s = _format_decimal_string(s) self.assertIn("0.123457%", formatted_s) + self.assertEqual('0.123457', _format_decimal_string('0.12345678')) + self.assertEqual('-1', _format_decimal_string('-1')) + self.assertEqual('0.0.25698548%', _format_decimal_string('0.0.25698548%')) def test__format_data(self): - data_dict = {'value': 0.123456789} + data_dict = {'value': 0.123456789, 'value1': None, 'value2': "", 'value3': 1.123123123123e-11, + 'value4': torch.inf, 'value5': -1} _format_data(data_dict) - self.assertEqual(data_dict['value'], '0.123457') \ No newline at end of file + self.assertEqual(data_dict['value'], '0.123457') + self.assertEqual(data_dict['value1'], 'null') + self.assertEqual(data_dict['value2'], '') + self.assertEqual(data_dict['value3'], '1.123123e-11') + self.assertEqual(data_dict['value4'], 'inf') + self.assertEqual(data_dict['value5'], '-1') + + all_none_dict = {'a': None, 'b': None, 'c': None, 'd': None, 'e': None} + _format_data(all_none_dict) + self.assertEqual({'value': 'null'}, all_none_dict) + + def test_compare_mapping_data(self): + dict1 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}} + dict2 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}} + dict3 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}} + dict4 = {'a': {'shape': [2, 1, 3]}, 'b': {'shape': [1, 2, 3]}} + dict5 = {'a': {'shape': [2, 2, 3]}, 'b': {'shape': [1, 2, 3]}} + dict6 = {'a': {'type': 'str'}} + self.assertTrue(compare_mapping_data(dict1, dict2)) + self.assertTrue(compare_mapping_data(dict1, dict3)) + self.assertTrue(compare_mapping_data(dict1, dict4)) + self.assertFalse(compare_mapping_data(dict1, dict5)) + self.assertTrue(compare_mapping_data(dict1, dict6)) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py index bece5380f0..cb69ac7723 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py @@ -1,7 +1,8 @@ import unittest from unittest.mock import patch +from unittest.mock import MagicMock from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator -from msprobe.pytorch.visualization.graph.graph import Graph +from msprobe.pytorch.visualization.graph.graph import Graph, BaseNode, NodeOp from msprobe.pytorch.visualization.utils import GraphConst @@ -30,3 +31,113 @@ class TestGraphComparator(unittest.TestCase): 'is_print_compare_log': True }) self.assertEqual(self.comparator.output_path, self.output_path) + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_compare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator._compare_nodes = MagicMock() + comparator._postcompare = MagicMock() + + comparator.compare() + + comparator._compare_nodes.assert_called_once() + comparator._postcompare.assert_called_once() + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_add_compare_result_to_node(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + node = MagicMock() + compare_result_list = [("output1", "data1"), ("input1", "data2")] + + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.ma = MagicMock() + comparator.ma.prepare_real_data.return_value = True + + comparator.add_compare_result_to_node(node, compare_result_list) + comparator.ma.prepare_real_data.assert_called_once_with(node) + node.data.update.assert_not_called() + + @patch('msprobe.pytorch.visualization.graph.node_colors.NodeColors.get_node_error_status') + @patch('msprobe.pytorch.visualization.utils.get_csv_df') + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.run_real_data') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__postcompare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode, + mock_run_real_data, mock_get_csv_df, mock_get_node_error_status): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + mock_df = MagicMock() + mock_df.iterrows = MagicMock(return_value=[(None, MagicMock())]) + mock_run_real_data.return_value = mock_df + mock_get_csv_df.return_value = mock_df + mock_get_node_error_status.return_value = True + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.ma = MagicMock() + comparator.ma.is_real_data_compare.return_value = True + comparator._handle_api_collection_index = MagicMock() + comparator.ma.compare_nodes = [MagicMock()] + comparator.ma.parse_result = MagicMock(return_value=(0.9, None)) + + comparator._postcompare() + + comparator._handle_api_collection_index.assert_called_once() + comparator.ma.add_error_key.assert_called() + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__handle_api_collection_index(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0') + api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0') + api1.data = {GraphConst.JSON_INDEX_KEY: 0.9} + api2 = BaseNode(NodeOp.function_api, 'Tensor.b.0') + api2.data = {GraphConst.JSON_INDEX_KEY: 0.6} + apis.subnodes = [api1, api2] + sub_nodes = [BaseNode(NodeOp.module, 'Module.a.0'), apis, BaseNode(NodeOp.module, 'Module.a.1')] + comparator.graph_n.root.subnodes = sub_nodes + comparator._handle_api_collection_index() + self.assertEqual(comparator.graph_n.root.subnodes[1].data.get(GraphConst.JSON_INDEX_KEY), 0.6) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_node') + @patch('msprobe.pytorch.visualization.graph.graph.Graph.match') + @patch('msprobe.pytorch.visualization.graph.graph.Graph.mapping_match') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__compare_nodes(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode, + mock_mapping_match, mock_match, mock_compare_node): + node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') + node_b = BaseNode(NodeOp.function_api, 'Tensor.b.0') + mock_load_data_json_file.return_value = {} + mock_load_json_file.return_value = {} + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + mock_mapping_match.return_value = (node_b, [], []) + mock_compare_node.return_value = ['result'] + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.mapping_config = True + comparator._compare_nodes(node_n) + self.assertEqual(node_n.matched_node_link, ['Tensor.b.0']) + self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + comparator.mapping_config = False + node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') + node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0') + mock_match.return_value = (node_b, []) + comparator._compare_nodes(node_n) + self.assertEqual(node_n.matched_node_link, ['Tensor.a.0']) + self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py index e136d3e152..da76d8e0d5 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch, MagicMock from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter @@ -30,11 +31,27 @@ class TestModeAdapter(unittest.TestCase): self.assertEqual(precision_index, 0.5) self.assertEqual(other_dict, {}) + mock_mode_adapter._add_md5_compare_data.return_value = 1 + self.adapter.compare_mode = GraphConst.MD5_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 1) + self.assertEqual(other_dict, {'Result': 'pass'}) + + mock_mode_adapter._add_real_compare_data.return_value = 0.6 + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 0.0) + self.assertEqual(other_dict, {}) + def test_prepare_real_data(self): self.adapter.is_real_data_compare = MagicMock(return_value=True) result = self.adapter.prepare_real_data(self.node) self.assertTrue(result) + self.adapter.is_real_data_compare = MagicMock(return_value=False) + result = self.adapter.prepare_real_data(self.node) + self.assertFalse(result) + def test_compare_mode_methods(self): self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE self.assertTrue(self.adapter.is_summary_compare()) @@ -52,8 +69,31 @@ class TestModeAdapter(unittest.TestCase): self.adapter.add_error_key(node_data) self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]) + node_data = {'key': {}} + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + self.adapter.add_error_key(node_data) + self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], + [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, + CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]) def test_get_tool_tip(self): self.adapter.compare_mode = GraphConst.MD5_COMPARE tips = self.adapter.get_tool_tip() - self.assertEqual(tips, {'md5': ToolTip.MD5}) + self.assertEqual(tips, json.dumps({'md5': ToolTip.MD5})) + + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({ + CompareConst.MAX_DIFF: ToolTip.MAX_DIFF, + CompareConst.MIN_DIFF: ToolTip.MIN_DIFF, + CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF, + CompareConst.NORM_DIFF: ToolTip.NORM_DIFF})) + + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({ + CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO, + CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.COSINE: ToolTip.COSINE, + CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR, + CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR})) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py index 544950f358..3f0f12582a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import patch from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp from msprobe.pytorch.visualization.utils import GraphConst @@ -24,6 +25,10 @@ class TestBaseNode(unittest.TestCase): self.node.get_suggestions() self.assertIn(GraphConst.SUGGEST_KEY, self.node.suggestions) + node = BaseNode(NodeOp.function_api, "up_node_1") + node.get_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, node.suggestions) + def test_set_input_output(self): input_data = {'input1': 'value1'} output_data = {'output1': 'value2'} @@ -55,10 +60,17 @@ class TestBaseNode(unittest.TestCase): 'upnode': self.up_node.id, 'subnodes': [], 'matched_node_link': [], - 'suggestions': {} + 'suggestions': {}, + 'stack_info': [] } self.assertEqual(self.node.to_dict(), expected_result) def test_get_ancestors(self): expected_ancestors = ['up_node_1'] self.assertEqual(self.node.get_ancestors(), expected_ancestors) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_mapping_data') + def test_compare_mapping_node(self, mock_compare_mapping_data): + mock_compare_mapping_data.return_value = True + result = self.node.compare_mapping_node(BaseNode(NodeOp.function_api, "up_node_1")) + self.assertTrue(result) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py index 19d0987434..10a15c6c58 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py @@ -1,4 +1,5 @@ import unittest +from unittest.mock import MagicMock from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp from msprobe.pytorch.visualization.graph.base_node import BaseNode from msprobe.pytorch.visualization.utils import GraphConst @@ -17,6 +18,15 @@ class TestGraph(unittest.TestCase): self.assertIsNotNone(node) self.assertIn(self.node_id, self.graph.node_map) + node_id = "api" + graph = Graph("model_name") + for i in range(0, 9): + graph.add_node(NodeOp.function_api, node_id, id_accumulation=True) + self.assertEqual(len(graph.node_map), 10) + self.assertIn("api.0", graph.node_map) + self.assertIn("api.8", graph.node_map) + self.assertNotIn("api", graph.node_map) + def test_to_dict(self): self.graph.add_node(self.node_op, self.node_id) result = self.graph.to_dict() @@ -38,6 +48,13 @@ class TestGraph(unittest.TestCase): self.assertIsNone(matched_node) self.assertEqual(ancestors, []) + graph_b.add_node(NodeOp.module, "node_id_a") + graph_a.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a")) + graph_b.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a")) + matched_node, ancestors = Graph.match(graph_a, graph_a.get_node("node_id_a_1"), graph_b) + self.assertIsNotNone(matched_node) + self.assertEqual(ancestors, ['node_id_a']) + def test_dfs(self): graph = Graph("model_name") graph.add_node(NodeOp.module, "node_a") @@ -48,3 +65,15 @@ class TestGraph(unittest.TestCase): self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], 'matched_node_link': [], 'suggestions': {}}}) + + def test_mapping_match(self): + mapping_config = MagicMock() + graph_a = Graph("model_name_a") + graph_b = Graph("model_name_b") + graph_a.add_node(NodeOp.module, "a1", BaseNode(NodeOp.module, "root")) + graph_b.add_node(NodeOp.module, "b1", BaseNode(NodeOp.module, "root")) + mapping_config.get_mapping_string.return_value = "b1" + node_b, ancestors_n, ancestors_b = Graph.mapping_match(graph_a.get_node("a1"), graph_b, mapping_config) + self.assertIsNotNone(node_b) + self.assertEqual(ancestors_n, ["root"]) + self.assertEqual(ancestors_b, ["root"]) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml new file mode 100644 index 0000000000..8b2f85ebf8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml @@ -0,0 +1,2 @@ +- vision_model: "language_model.vision_encoder" +- vision_projection: "language_model.projection" \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py new file mode 100644 index 0000000000..010a4f6861 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py @@ -0,0 +1,52 @@ +import os +import unittest +from msprobe.pytorch.visualization.mapping_config import MappingConfig + + +class TestMappingConfig(unittest.TestCase): + + def setUp(self): + self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") + + def test_validate(self): + with self.assertRaises(ValueError): + MappingConfig.validate(123, "some value") + with self.assertRaises(ValueError): + MappingConfig.validate("some key", 456) + self.assertEqual(MappingConfig.validate("key", "value"), "value") + + def test_convert_to_regex(self): + regex = MappingConfig.convert_to_regex("hello{world}") + self.assertEqual(regex, ".*hello\\{world\\}.*") + + def test_replace_parts(self): + result = MappingConfig._replace_parts('hello world', 'world', 'everyone') + self.assertEqual(result, 'hello everyone') + result = MappingConfig._replace_parts('radio_model.layers.0.input_norm', 'radio_model.layers.{}.input_norm', + 'radio_model.transformer.layers.{}.input_layernorm') + self.assertEqual(result, 'radio_model.transformer.layers.0.input_layernorm') + + def test_get_mapping_string(self): + mc = MappingConfig(self.yaml_path) + mc.classify_config = { + 'category1': [('category1.key1', 'replacement1')], + 'category2': [('category2.key1', 'replacement2')] + } + result = mc.get_mapping_string("some category1.key1 text") + self.assertEqual(result, "some replacement1 text") + + def test_long_string(self): + long_string = "x" * (MappingConfig.MAX_STRING_LEN + 1) + mc = MappingConfig(self.yaml_path) + result = mc.get_mapping_string(long_string) + self.assertEqual(result, long_string) + + def test__classify_and_sort_keys(self): + mc = MappingConfig(self.yaml_path) + result = mc._classify_and_sort_keys() + self.assertEqual(result, {'vision_model': [('vision_model', 'language_model.vision_encoder')], + 'vision_projection': [('vision_projection', 'language_model.projection')]}) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py new file mode 100644 index 0000000000..56493235cd --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py @@ -0,0 +1,27 @@ +import os +import unittest +from msprobe.pytorch.visualization.utils import (load_json_file, load_data_json_file, str2float) + + +class TestMappingConfig(unittest.TestCase): + + def setUp(self): + self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") + + def test_load_json_file(self): + result = load_json_file(self.yaml_path) + self.assertEqual(result, {}) + + def test_load_data_json_file(self): + result = load_data_json_file(self.yaml_path) + self.assertEqual(result, {}) + + def test_str2float(self): + result = str2float('23.4%') + self.assertAlmostEqual(result, 0.234) + result = str2float('2.3.4%') + self.assertAlmostEqual(result, 0) + + +if __name__ == '__main__': + unittest.main() -- Gitee