diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index 2465cf1ae0dcc62aaaa2f82c3c864107245a01fa..d275416d100a4430c88bfbc945129df9a600a244 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -52,12 +52,12 @@ class GraphComparator: compare_in_dict[item[0]] = item precision_status, precision_index, other_dict = ( self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) - node.data[GraphConst.JSON_STATUS_KEY] = precision_status - node.data[GraphConst.JSON_INDEX_KEY] = precision_index - node.data.update(other_dict) + node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) + node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) + node.update_data(other_dict) if not precision_status: - self.ma.add_error_key(node.output_data) - node.get_suggestions() + self.ma.add_error_key(node.get_output_data()) + node.add_suggestions() def _parse_param(self, data_paths, stack_path, output_path): self.dump_path_param = { @@ -81,11 +81,11 @@ class GraphComparator: compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} for node in self.ma.compare_nodes: precision_status, precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) - node.data[GraphConst.JSON_STATUS_KEY] = precision_status - node.data[GraphConst.JSON_INDEX_KEY] = precision_index + node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) + node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) if not precision_status: - self.ma.add_error_key(node.output_data) - node.get_suggestions() + self.ma.add_error_key(node.get_output_data()) + node.add_suggestions() def _compare_nodes(self, node_n): #递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py index b3f286c28f2d4c1393d21fd14cb17be4ba7ce60d..d3d181286ea81c42cb8186d64d06734e96f7b0e2 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py @@ -119,21 +119,21 @@ class ModeAdapter: """ other_dict = {} if self.is_md5_compare(): - precision_status_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1]) + precision_status_in = ModeAdapter._add_md5_compare_data(node.get_input_data(), compare_data_dict[0]) + precision_status_out = ModeAdapter._add_md5_compare_data(node.get_output_data(), compare_data_dict[1]) # 所有输入输出md5对比通过,这个节点才算通过 precision_status = precision_status_in and precision_status_out precision_index = 1 if precision_status else 0 other_result = CompareConst.PASS if precision_status else CompareConst.DIFF other_dict[CompareConst.RESULT] = other_result elif self.is_summary_compare(): - precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1]) + precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.get_input_data(), compare_data_dict[0]) + precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.get_output_data(), compare_data_dict[1]) precision_status = precision_status_in and precision_status_out precision_index = min(precision_index_in, precision_index_out) else: - min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0]) - min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0]) + min_thousandth_in = ModeAdapter._add_real_compare_data(node.get_input_data(), compare_data_dict[0]) + min_thousandth_out = ModeAdapter._add_real_compare_data(node.get_output_data(), compare_data_dict[0]) if min_thousandth_in and min_thousandth_out: change_percentage = abs(min_thousandth_in - min_thousandth_out) else: diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py index e8c86e243e0ce2f34363f55e39ab344e989e205a..76de28a622643945c3c7ca807c28a87a096787de 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py @@ -14,6 +14,7 @@ # limitations under the License. from msprobe.pytorch.visualization.graph.node_op import NodeOp +from msprobe.pytorch.visualization.graph.multi_data import MultiData from msprobe.pytorch.visualization.utils import Suggestions, GraphConst from msprobe.pytorch.visualization.builder.msprobe_adapter import format_node_data, compare_data @@ -22,14 +23,11 @@ class BaseNode: def __init__(self, node_op, node_id, up_node=None): self.op = node_op self.id = node_id - self.data = {} - self.output_data = {} - self.input_data = {} self.upnode = None self.add_upnode(up_node) self.subnodes = [] self.matched_node_link = [] - self.suggestions = {} + self.multi_data = MultiData() def __str__(self): info = f'id:\t{self.id}' @@ -39,26 +37,27 @@ class BaseNode: """ 用来判断两个节点是否可以被匹配上,认为结构上是否一致 """ - if not compare_data(self.input_data, other.input_data): + if not compare_data(self.get_input_data(), other.get_input_data()): return False - if not compare_data(self.output_data, other.output_data): + if not compare_data(self.get_output_data(), other.get_output_data()): return False return True - def get_suggestions(self): + def add_suggestions(self): """ 精度疑似有问题时,提供一些建议 """ + suggestions = self.get_suggestions() if self.op == NodeOp.module: - self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module - self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL + suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module + suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL elif self.op == NodeOp.function_api: - self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API - self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL + suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API + suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL def set_input_output(self, input_data, output_data): - self.input_data = input_data - self.output_data = output_data + self.set_input_data(input_data) + self.set_output_data(output_data) def add_upnode(self, node): """ @@ -86,13 +85,15 @@ class BaseNode: result = {} result['id'] = self.id result['node_type'] = self.op.value - result['data'] = self.data - result['output_data'] = format_node_data(self.output_data) - result['input_data'] = format_node_data(self.input_data) + result['data'] = self.get_data_dict() + output_data = self.get_output_data_dict() + result['output_data'] = [format_node_data(data) for data in output_data] + input_data = self.get_input_data_dict() + result['input_data'] = [format_node_data(data) for data in input_data] result['upnode'] = self.upnode.id if self.upnode else 'None' result['subnodes'] = [node.id for node in self.subnodes] result['matched_node_link'] = self.matched_node_link - result['suggestions'] = self.suggestions + result['suggestions'] = self.get_suggestions_dict() return result def get_ancestors(self): @@ -105,3 +106,38 @@ class BaseNode: ancestors.append(current_node.id) current_node = current_node.upnode return list(reversed(ancestors)) + + # 提供封装好的四个数据获取,方便外部调用 + def get_input_data(self): + return self.multi_data.get_data(GraphConst.MULTI_INPUT_KEY) + + def get_output_data(self): + return self.multi_data.get_data(GraphConst.MULTI_OUTPUT_KEY) + + def get_suggestions(self): + return self.multi_data.get_data(GraphConst.MULTI_SUGGEST_KEY) + + def get_data(self): + return self.multi_data.get_data(GraphConst.MULTI_DATA_KEY) + + def set_input_data(self, value): + return self.multi_data.set_data(GraphConst.MULTI_INPUT_KEY, value) + + def set_output_data(self, value): + return self.multi_data.set_data(GraphConst.MULTI_OUTPUT_KEY, value) + + def update_data(self, new_dict): + data = self.get_data() + data.update(new_dict) + + def get_input_data_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_INPUT_KEY) + + def get_output_data_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_OUTPUT_KEY) + + def get_data_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_DATA_KEY) + + def get_suggestions_dict(self): + return self.multi_data.to_dict(GraphConst.MULTI_SUGGEST_KEY) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py new file mode 100644 index 0000000000000000000000000000000000000000..0b628c3332fe8b85253bbc4fd40e6f23e37627c8 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py @@ -0,0 +1,58 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.pytorch.visualization.utils import GraphConst + + +class MultiData: + def __init__(self): + self._data = [] + self.set_repeat_time(1) + + def set_repeat_time(self, repeat_time): + if repeat_time <= 0: + return + self._repeat_time = repeat_time + self._data.clear() + self._data = [{ + GraphConst.MULTI_DATA_KEY: {}, + GraphConst.MULTI_OUTPUT_KEY: {}, + GraphConst.MULTI_INPUT_KEY: {}, + GraphConst.MULTI_SUGGEST_KEY: {}, + } for _ in range(repeat_time)] + self.set_index(0) + + def get_repeat_time(self): + return self._repeat_time + + def set_index(self, index): + if index < 0 or index >= self._repeat_time: + raise Exception('The index is out of range.') + self._index = index + + def get_data(self, key): + if key not in self._data[self._index]: + raise Exception(f'{key} is not in the key of MultiData') + return self._data[self._index].get(key, {}) + + def set_data(self, key, value): + if key not in self._data[self._index]: + raise Exception(f'{key} is not in the key of MultiData') + self._data[self._index][key] = value + + def to_dict(self, key): + if key not in self._data[self._index]: + raise Exception(f'{key} is not in the key of MultiData') + return [data[key] for data in self._data] diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py index a1b72ee850041df6200f4cd2cc7624a59b29756f..6d607804e31c02f51b9c7814cbe5efdbb7ba84a5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py @@ -127,3 +127,7 @@ class GraphConst: SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF, CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR] + MULTI_DATA_KEY = 'data' + MULTI_OUTPUT_KEY = 'output_data' + MULTI_INPUT_KEY = 'input_data' + MULTI_SUGGEST_KEY = 'suggestions' diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py index 66eceea4b2a1ccf48ac95491c1a2cdca718a403a..4cc627704802127913e1af088930e7b58e1143b8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/builder/test_graph_builder.py @@ -48,5 +48,5 @@ class TestGraphBuilder(unittest.TestCase): data_dict = {"node1": {}} node = GraphBuilder._create_or_get_node(self.graph, data_dict, node_op, "node1") self.assertIn("node1", self.graph.node_map) - self.assertEqual(node.input_data, {}) - self.assertEqual(node.output_data, {}) + self.assertEqual(node.get_input_data(), {}) + self.assertEqual(node.get_output_data(), {}) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py index bece5380f04836a232a8c154a606c1cb68759b1c..80f9e742b11edec17dee7901def80a67282cf9d9 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_graph_comparator.py @@ -2,6 +2,8 @@ import unittest from unittest.mock import patch from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator from msprobe.pytorch.visualization.graph.graph import Graph +from msprobe.pytorch.visualization.graph.base_node import BaseNode +from msprobe.pytorch.visualization.graph.node_op import NodeOp from msprobe.pytorch.visualization.utils import GraphConst @@ -30,3 +32,17 @@ class TestGraphComparator(unittest.TestCase): 'is_print_compare_log': True }) self.assertEqual(self.comparator.output_path, self.output_path) + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_bad_precision_status(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + node = BaseNode(NodeOp.function_api, 'test_api') + self.comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + self.comparator.ma.add_error_key(node.get_output_data()) + node.add_suggestions() + suggestions = node.get_suggestions() + self.assertEqual(suggestions['text'], '此api精度比对结果疑似异常,请使用msprobe工具的预检功能对api进行精度检测') diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py index 7883a09a34115132ac2b8b217de434e32e58c279..fd63e973b5c757da4893ae0459062c3066a16f4d 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py @@ -1,4 +1,5 @@ import unittest +import json from unittest.mock import patch, MagicMock from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp @@ -58,4 +59,4 @@ class TestModeAdapter(unittest.TestCase): def test_get_tool_tip(self): self.adapter.compare_mode = GraphConst.MD5_COMPARE tips = self.adapter.get_tool_tip() - self.assertEqual(tips, {'md5': ToolTip.MD5}) + self.assertEqual(tips, json.dumps({'md5': ToolTip.MD5})) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py index 544950f35881e19eb449a138a4b0937ca91eb1d7..1115a84cec1f98c3b153966f403ed7b688b6a2cc 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py @@ -21,15 +21,15 @@ class TestBaseNode(unittest.TestCase): self.assertEqual(self.node, other_node) def test_get_suggestions(self): - self.node.get_suggestions() - self.assertIn(GraphConst.SUGGEST_KEY, self.node.suggestions) + self.node.add_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, self.node.get_suggestions()) def test_set_input_output(self): input_data = {'input1': 'value1'} output_data = {'output1': 'value2'} self.node.set_input_output(input_data, output_data) - self.assertEqual(self.node.input_data, input_data) - self.assertEqual(self.node.output_data, output_data) + self.assertEqual(self.node.get_input_data(), input_data) + self.assertEqual(self.node.get_output_data(), output_data) def test_add_upnode(self): self.node = BaseNode(self.node_op, self.node_id) @@ -49,13 +49,13 @@ class TestBaseNode(unittest.TestCase): expected_result = { 'id': self.node_id, 'node_type': self.node_op.value, - 'data': {}, - 'output_data': {}, - 'input_data': {}, + 'data': [{}], + 'output_data': [{}], + 'input_data': [{}], 'upnode': self.up_node.id, 'subnodes': [], 'matched_node_link': [], - 'suggestions': {} + 'suggestions': [{}] } self.assertEqual(self.node.to_dict(), expected_result) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py index 19d098743458a61d13146b6da1b65098f90171b7..0811851b1820d535dc8d54143d51885ca505eee8 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py @@ -45,6 +45,6 @@ class TestGraph(unittest.TestCase): node_a = BaseNode(self.node_op, self.node_id) result = {} graph.dfs(node_a, result) - self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, - 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], - 'matched_node_link': [], 'suggestions': {}}}) + self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': [{}], + 'output_data': [{}], 'input_data': [{}], 'upnode': 'None', 'subnodes': [], + 'matched_node_link': [], 'suggestions': [{}]}}) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py new file mode 100644 index 0000000000000000000000000000000000000000..99c37321f40aad7c6b5f753de561ee7e324ab493 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_multi_data.py @@ -0,0 +1,61 @@ +import unittest +import pytest +from msprobe.pytorch.visualization.graph.multi_data import MultiData +from msprobe.pytorch.visualization.utils import GraphConst + +class TestMultiData(unittest.TestCase): + + def setUp(self): + self.multi_data = MultiData() + self.key_list = [GraphConst.MULTI_DATA_KEY, GraphConst.MULTI_OUTPUT_KEY, + GraphConst.MULTI_INPUT_KEY, GraphConst.MULTI_SUGGEST_KEY] + + def test_init(self): + self.assertEqual(self.multi_data._repeat_time, 1) + for key in self.key_list: + self.assertEqual(self.multi_data._data[0][key], {}) + + def test_independence_in_repeats(self): + test_key = self.key_list[1] + self.multi_data.set_repeat_time(3) + # index0下写入数据 + self.multi_data.set_data(test_key, {'index0':1}) + self.assertEqual(self.multi_data.get_data(test_key), {'index0':1}) + # 切换到index2下检查数据,写入数据 + self.multi_data.set_index(2) + self.assertEqual(self.multi_data.get_data(test_key), {}) + self.multi_data.set_data(test_key, {'index2':1}) + self.assertEqual(self.multi_data.get_data(test_key), {'index2':1}) + # 切换会index0,检查数据独立性 + self.multi_data.set_index(0) + self.assertEqual(self.multi_data.get_data(test_key), {'index0':1}) + + def test_bad_case(self): + self.assertNotEqual(self.multi_data._repeat_time, 5) + # 正确的重复次数设置 + self.multi_data.set_repeat_time(5) + self.assertEqual(self.multi_data._repeat_time, 5) + # 不正确的重复次数设置 + self.multi_data.set_repeat_time(0) + self.assertEqual(self.multi_data._repeat_time, 5) + # 正确的index设置 + self.assertNotEqual(self.multi_data._index, 2) + self.multi_data.set_index(2) + self.assertEqual(self.multi_data._index, 2) + # 不正确的index设置 + with pytest.raises(Exception): + self.multi_data.set_index(5) + self.assertEqual(self.multi_data._index, 2) + + def test_data_output(self): + test_key = self.key_list[1] + self.multi_data.set_repeat_time(3) + # 数据写入 + self.multi_data.set_data(test_key, {'index0':1}) + self.multi_data.set_index(1) + self.multi_data.set_data(test_key, {'index1':1}) + self.multi_data.set_index(2) + self.multi_data.set_data(test_key, {'index2':1}) + # 批量数据输出 + result = self.multi_data.to_dict(test_key) + self.assertEqual(result, [{'index0':1}, {'index1':1}, {'index2':1}])