diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py index 7d5c156030df580cb1beb10b21bb03c593a9348e..157ac7a6defd2a1b39fee655fae3e1e1fe04850c 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py @@ -38,7 +38,7 @@ class GraphBuilder: return graph @staticmethod - def to_json(filename, graph_n, graph_b=None, tool_tip=None): + def to_json(filename, graph_n, graph_b=None, tool_tip=None, node_colors=None): """ 将graph导出成.vis文件的接口 Args: @@ -46,6 +46,7 @@ class GraphBuilder: graph_n: Graph graph_b: bench Graph,为空是只输出graph_b,不为空会同时输出两个graph,作为对比的结果 tool_tip: 在对比模型下输出的意见 + node_colors: 在对比模型下节点的颜色说明 """ result = {} if graph_b: @@ -55,6 +56,8 @@ class GraphBuilder: result = graph_n.to_dict() if tool_tip: result[GraphConst.JSON_TIP_KEY] = tool_tip + if node_colors: + result[GraphConst.COLORS] = node_colors save_json_file(filename, result) @staticmethod diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index 86ba1b1f869983daaccbc7642dfa560d9d92f75c..84ee8189e78037a1f939f82dfb59485076996920 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -16,6 +16,7 @@ from msprobe.pytorch.visualization.builder.msprobe_adapter import compare_node, get_compare_mode, run_real_data from msprobe.pytorch.visualization.utils import GraphConst, load_json_file, load_data_json_file, get_csv_df from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp +from msprobe.pytorch.visualization.graph.node_colors import NodeColors from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter @@ -50,11 +51,11 @@ 
class GraphComparator: compare_out_dict[item[0]] = item else: compare_in_dict[item[0]] = item - precision_status, precision_index, other_dict = ( + precision_index, other_dict = ( self.ma.parse_result(node, [compare_in_dict, compare_out_dict])) node.data[GraphConst.JSON_INDEX_KEY] = precision_index node.data.update(other_dict) - if not precision_status: + if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index): self.ma.add_error_key(node.output_data) node.get_suggestions() @@ -80,9 +81,9 @@ class GraphComparator: df = run_real_data(self.dump_path_param, df) compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} for node in self.ma.compare_nodes: - precision_status, precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) + precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) node.data[GraphConst.JSON_INDEX_KEY] = precision_index - if not precision_status: + if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index): self.ma.add_error_key(node.output_data) node.get_suggestions() diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py index f73824bcd0fbca7a422e12eec9000dc556dcdd4a..746b67575943b66addc0e9f325eed3fd4590f2c5 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py @@ -26,7 +26,7 @@ class ModeAdapter: @staticmethod def _add_md5_compare_data(node_data, compare_data_dict): - precision_status = True + precision_index = 0 for key, value in node_data.items(): if not isinstance(value, dict): continue @@ -37,9 +37,9 @@ class ModeAdapter: ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list) # md5比对是否通过 if value.get(CompareConst.RESULT) != CompareConst.PASS: - precision_status = False + precision_index = 1 node_data[key] = value - return 
precision_status + return precision_index @staticmethod def _add_real_compare_data(node_data, compare_data_dict): @@ -72,7 +72,6 @@ class ModeAdapter: @staticmethod def _add_summary_compare_data( node_data, compare_data_dict): - precision_status = True max_relative_err = 0 for key, value in node_data.items(): if not isinstance(value, dict): @@ -93,11 +92,8 @@ class ModeAdapter: relative_err = str2float(value.get(item)) max_relative_err = max(max_relative_err, relative_err) node_data[key] = value - if max_relative_err > GraphConst.MAX_RELATIVE_ERR_TH: - precision_status = False max_relative_err = 1 if max_relative_err > 1 else max_relative_err - precision_index = 1 - max_relative_err - return precision_status, precision_index + return max_relative_err @staticmethod def _match_data(data_dict, compare_data, key_list, id_list): @@ -115,22 +111,20 @@ class ModeAdapter: def parse_result(self, node, compare_data_dict): """ - 根据结果返回数据,分别是precision_status,precision_index,和附加数据 + 根据结果返回数据,分别是precision_index,和附加数据 """ other_dict = {} if self.is_md5_compare(): - precision_status_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1]) + precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0]) + precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1]) # 所有输入输出md5对比通过,这个节点才算通过 - precision_status = precision_status_in and precision_status_out - precision_index = 1 if precision_status else 0 - other_result = CompareConst.PASS if precision_status else CompareConst.DIFF + precision_index = max(precision_index_in, precision_index_out) + other_result = CompareConst.PASS if precision_index == 1 else CompareConst.DIFF other_dict[CompareConst.RESULT] = other_result elif self.is_summary_compare(): - precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, 
compare_data_dict[0]) - precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1]) - precision_status = precision_status_in and precision_status_out - precision_index = min(precision_index_in, precision_index_out) + precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0]) + precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1]) + precision_index = max(precision_index_in, precision_index_out) else: min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0]) min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0]) @@ -138,11 +132,8 @@ class ModeAdapter: change_percentage = abs(min_thousandth_in - min_thousandth_out) else: change_percentage = 0 - precision_status = True - if change_percentage > GraphConst.REAL_DATA_TH: - precision_status = False - precision_index = 0 if change_percentage > 1 else 1 - change_percentage - return precision_status, precision_index, other_dict + precision_index = 1 if change_percentage > 1 else change_percentage + return precision_index, other_dict def prepare_real_data(self, node): """
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd4f7b609ad088fc53a76ad9d8a88435cf628560
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/node_colors.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from enum import Enum
+from msprobe.pytorch.visualization.utils import GraphConst
+
+SUMMARY_DESCRIPTION = "此节点所有输入输出的统计量相对误差."
+REAL_DATA_DESCRIPTION = "此节点所有输入的最小双千分之一和所有输出的最小双千分之一的差值的绝对值, 代表双千指标的变化情况."
+MD5_DESCRIPTION_N = "此节点任意输入输出的md5值不同."
+MD5_DESCRIPTION_Y = "此节点所有输入输出的md5值相同."
+NOT_MATCHED = "比对过程中节点未匹配上."
+
+
+class NodeColors(Enum):
+    """Colour legend used when rendering compared graph nodes.
+
+    Every member is declared as ``(hex_value, mode_info)``. ``mode_info`` is either
+    a dict keyed by compare mode, each entry holding ``GraphConst.VALUE`` and
+    ``GraphConst.DESCRIPTION``, or (GREY) one such entry shared by all modes.
+    """
+    # The smaller the numeric suffix of a member name, the lighter the colour.
+    # VALUE is a left-closed, right-open range; two equal bounds mean a fixed value.
+    # NOTE(review): the md5 entries read index 1 as "all md5 equal" and 0 as
+    # "some md5 differ" - confirm the producer of precision_index agrees.
+    YELLOW_1 = ("#FFFCF3", {
+        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0, 0.2], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0, 0.1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
+        GraphConst.MD5_COMPARE: {GraphConst.VALUE: [1, 1], GraphConst.DESCRIPTION: MD5_DESCRIPTION_Y},
+    })
+    YELLOW_2 = ("#FFEDBE", {
+        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.2, 0.4], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.1, 0.2], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
+    })
+    ORANGE_1 = ("#FFDC7F", {
+        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.4, 0.6], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.2, 0.3], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
+    })
+    ORANGE_2 = ("#FFC62E", {
+        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.6, 0.8], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.3, 0.4], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
+    })
+    RED = ("#E32020", {
+        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.8, 1], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
+        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.4, 1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
+        GraphConst.MD5_COMPARE: {GraphConst.VALUE: [0, 0], GraphConst.DESCRIPTION: MD5_DESCRIPTION_N},
+    })
+    GREY = ("#C7C7C7", {
+        GraphConst.VALUE: [], GraphConst.DESCRIPTION: NOT_MATCHED
+    })
+
+    def __init__(self, hex_value, mode_info):
+        # Unpack the member's tuple value into named attributes.
+        self.hex_value = hex_value
+        self.mode_info = mode_info
+
+    @staticmethod
+    def get_node_colors(mode):
+        """
+        Build the colour legend for one compare mode.
+
+        Args:
+            mode: compare mode used to select each colour's entry
+        Returns: dict mapping hex colour -> that colour's info; colours without
+            an entry for ``mode`` are left out
+        """
+        legend = {}
+        for color in NodeColors:
+            # Look the entry up once per colour instead of twice.
+            info = color.get_info_by_mode(mode)
+            if info:
+                legend[color.hex_value] = info
+        return legend
+
+    @staticmethod
+    def get_node_error_status(mode, value):
+        """
+        Tell whether a precision index exceeds the baseline for ``mode``.
+
+        The baseline is the lower bound of ORANGE_1's range for that mode.
+
+        Args:
+            mode: compare mode
+            value: precision index of the node
+        Returns: bool; False when ORANGE_1 has no usable range for ``mode``
+            (md5 compare has no ORANGE_1 entry in the table above)
+        """
+        info = NodeColors.ORANGE_1.get_info_by_mode(mode)
+        value_range = info.get(GraphConst.VALUE) if info else None
+        if not value_range:
+            # No entry for this mode, or an empty range: nothing to compare against.
+            return False
+        return value > value_range[0]
+
+    def get_info_by_mode(self, mode):
+        """
+        Return this colour's info for ``mode``.
+
+        Args:
+            mode: compare mode
+        Returns: the mode-specific entry, the shared entry when the colour is not
+            mode-specific, or {} when nothing applies
+        """
+        if not isinstance(self.mode_info, dict) or not self.mode_info:
+            # An empty dict would make next() below raise StopIteration.
+            return {}
+        first_entry = next(iter(self.mode_info.values()))
+        if isinstance(first_entry, dict):
+            # Mode-specific info: one nested dict per compare mode.
+            return self.mode_info.get(mode, {})
+        # Flat info shared by every mode.
+        return self.mode_info
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py index 0e4fdab7cf832725ae53978c0049e52a6e08ee13..eabc50999746f3187adc313331d15583e0e627b8 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph_service.py @@ -18,6 +18,7 @@ import time from msprobe.pytorch.visualization.compare.graph_comparator import GraphComparator from msprobe.pytorch.visualization.utils import GraphConst from msprobe.pytorch.visualization.builder.graph_builder import GraphBuilder +from msprobe.pytorch.visualization.graph.node_colors import NodeColors from
msprobe.core.common.log import logger current_time = time.strftime("%Y%m%d%H%M%S") @@ -38,7 +39,8 @@ def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model'): graph_comparator = GraphComparator([graph_n, graph_b], [data_path_n, data_path_b], stack_path, out_path) graph_comparator.compare() output_path = os.path.join(out_path, f'compare_{current_time}.vis') - GraphBuilder.to_json(output_path, graph_n, graph_b, graph_comparator.ma.get_tool_tip()) + GraphBuilder.to_json(output_path, graph_n, graph_b, graph_comparator.ma.get_tool_tip(), + NodeColors.get_node_colors(graph_comparator.ma.compare_mode)) logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}') diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py index 4699215866784ab01d61aae0e35e8650236979ef..f7cbaa58b7a810074a5a5776b0bb6703e1de2e41 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py @@ -110,7 +110,6 @@ class GraphConst: REAL_DATA_TH = 0.1 MAX_RELATIVE_ERR_TH = 0.5 ROUND_TH = 6 - JSON_STATUS_KEY = 'precision_status' JSON_INDEX_KEY = 'precision_index' SUGGEST_KEY = 'text' TAG_NA = 'na' @@ -131,3 +130,5 @@ class GraphConst: NULL = 'null' NONE = 'None' VALUE = 'value' + DESCRIPTION = 'description' + COLORS = 'Colors' diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py index 7883a09a34115132ac2b8b217de434e32e58c279..e136d3e1521ce83842b846ce103d9eb043cae13a 100644 --- a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py @@ -19,16 +19,14 @@ class TestModeAdapter(unittest.TestCase): def test_add_md5_compare_data(self): node_data = 
{'md5_key': 'some_md5_value'} compare_data_dict = {'md5_key': 'expected_md5_value'} - precision_status = ModeAdapter._add_md5_compare_data(node_data, compare_data_dict) - self.assertTrue(precision_status) + precision_index = ModeAdapter._add_md5_compare_data(node_data, compare_data_dict) + self.assertEqual(precision_index, 0) @patch('msprobe.pytorch.visualization.compare.mode_adapter.ModeAdapter') def test_parse_result(self, mock_mode_adapter): - mock_mode_adapter._add_summary_compare_data.return_value = (True, 0.5) + mock_mode_adapter._add_summary_compare_data.return_value = 0.5 self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE - precision_status, precision_index, other_dict = self.adapter.parse_result( - self.node, self.compare_data_dict) - self.assertEqual(precision_status, True) + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) self.assertEqual(precision_index, 0.5) self.assertEqual(other_dict, {})