diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index 3d5f2972468adab8a436167d2f50eab9ace05873..28fec165d8042e98e4c5be4583b010b6a070fdfb 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -51,12 +51,12 @@ class GraphComparator: else: compare_in_dict[item[0]] = item precision_status, precision_index, other_dict = self.ma.parse_result(node, [compare_in_dict, compare_out_dict]) - node.data[GraphConst.JSON_STATUS_KEY] = precision_status - node.data[GraphConst.JSON_INDEX_KEY] = precision_index - node.data.update(other_dict) + node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) + node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) + node.update_data(other_dict) if not precision_status: - self.ma.add_error_key(node.output_data) - node.get_suggestions() + self.ma.add_error_key(node.get_output_data()) + node.add_suggestions() def _parse_param(self, data_paths, stack_path, output_path): self.dump_path_param = { @@ -80,11 +80,11 @@ class GraphComparator: compare_data_dict = {row[0]: row.tolist() for _, row in df.iterrows()} for node in self.ma.compare_nodes: precision_status, precision_index, _ = self.ma.parse_result(node, [compare_data_dict]) - node.data[GraphConst.JSON_STATUS_KEY] = precision_status - node.data[GraphConst.JSON_INDEX_KEY] = precision_index + node.update_data({GraphConst.JSON_STATUS_KEY: precision_status}) + node.update_data({GraphConst.JSON_INDEX_KEY: precision_index}) if not precision_status: - self.ma.add_error_key(node.output_data) - node.get_suggestions() + self.ma.add_error_key(node.get_output_data()) + node.add_suggestions() def _compare_nodes(self, node_n): #递归遍历NPU树中的节点,如果在Bench中找到具有相同名称的节点,检查他们的祖先和参数信息,检查一致则及逆行精度数据对比 diff --git 
a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py index d58f2078b6f8996a31c2f830ef5adf79bc7948c3..de9b6c52c2cb4096cf60b920ec2e59d1e14795f6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/mode_adapter.py @@ -124,21 +124,21 @@ class ModeAdapter: """ other_dict = {} if self.is_md5_compare(): - precision_status_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1]) + precision_status_in = ModeAdapter._add_md5_compare_data(node.get_input_data(), compare_data_dict[0]) + precision_status_out = ModeAdapter._add_md5_compare_data(node.get_output_data(), compare_data_dict[1]) # 所有输入输出md5对比通过,这个节点才算通过 precision_status = precision_status_in and precision_status_out precision_index = 1 if precision_status else 0 other_result = CompareConst.PASS if precision_status else CompareConst.DIFF other_dict[GraphConst.JSON_MD5_KEY] = other_result elif self.is_summary_compare(): - precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0]) - precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1]) + precision_status_in, precision_index_in = ModeAdapter._add_summary_compare_data(node.get_input_data(), compare_data_dict[0]) + precision_status_out, precision_index_out = ModeAdapter._add_summary_compare_data(node.get_output_data(), compare_data_dict[1]) precision_status = precision_status_in and precision_status_out precision_index = min(precision_index_in, precision_index_out) else: - min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0]) - min_thousandth_out = 
ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0]) + min_thousandth_in = ModeAdapter._add_real_compare_data(node.get_input_data(), compare_data_dict[0]) + min_thousandth_out = ModeAdapter._add_real_compare_data(node.get_output_data(), compare_data_dict[0]) if min_thousandth_in and min_thousandth_out: change_percentage = abs(min_thousandth_in - min_thousandth_out) else: diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py index f04f367f591244a6d1ed48529d1fb4aae7cb2453..0f504fe3cc1ebe3c2bb898198c4c0cd86f0bcb47 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/base_node.py @@ -14,6 +14,7 @@ # limitations under the License. from .node_op import NodeOp +from .multi_data import MultiData from ..utils import Suggestions, GraphConst from ..builder.msprobe_adapter import format_node_data, compare_data @@ -22,14 +23,11 @@ class BaseNode: def __init__(self, node_op, node_id, up_node=None): self.op = node_op self.id = node_id - self.data = {} - self.output_data = {} - self.input_data = {} self.upnode = None self.add_upnode(up_node) self.subnodes = [] self.matched_node_link = [] - self.suggestions = {} + self.multi_data = MultiData() def __str__(self): info = f'id:\t{self.id}' @@ -39,26 +37,27 @@ class BaseNode: """ 用来判断两个节点是否可以被匹配上,认为结构上是否一致 """ - if not compare_data(self.input_data, other.input_data): + if not compare_data(self.get_input_data(), other.get_input_data()): return False - if not compare_data(self.output_data, other.output_data): + if not compare_data(self.get_output_data(), other.get_output_data()): return False return True - def get_suggestions(self): + def add_suggestions(self): """ 精度疑似有问题时,提供一些建议 """ + suggestions = self.get_suggestions() if self.op == NodeOp.module: - self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module - 
self.suggestions[Suggestions.PTDBG] = Suggestions.PTDBG_URL + suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module + suggestions[Suggestions.PTDBG] = Suggestions.PTDBG_URL elif self.op == NodeOp.function_api: - self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API - self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL + suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API + suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL def set_input_output(self, input_data, output_data): - self.input_data = input_data - self.output_data = output_data + self.set_input_data(input_data) + self.set_output_data(output_data) def add_upnode(self, node): """ @@ -86,13 +85,15 @@ class BaseNode: result = {} result['id'] = self.id result['node_type'] = self.op.value - result['data'] = self.data - result['output_data'] = format_node_data(self.output_data) - result['input_data'] = format_node_data(self.input_data) + result['data'] = self.get_data_dict() + output_data = self.get_output_data_dict() + result['output_data'] = [format_node_data(data) for data in output_data] + input_data = self.get_input_data_dict() + result['input_data'] = [format_node_data(data) for data in input_data] result['upnode'] = self.upnode.id if self.upnode else 'None' result['subnodes'] = [node.id for node in self.subnodes] result['matched_node_link'] = self.matched_node_link - result['suggestions'] = self.suggestions + result['suggestions'] = self.get_suggestions_dict() return result def get_ancestors(self): @@ -105,3 +106,38 @@ class BaseNode: ancestors.append(current_node.id) current_node = current_node.upnode return list(reversed(ancestors)) + + # 提供封装好的四个数据获取,方便外部调用 + def get_input_data(self): + return self.multi_data.get_data(GraphConst.MULTI_INPUT_KEY) + + def get_output_data(self): + return self.multi_data.get_data(GraphConst.MULTI_OUTPUT_KEY) + + def get_suggestions(self): + return 
self.multi_data.get_data(GraphConst.MULTI_SUGGEST_KEY)
+
+    def get_data(self):
+        return self.multi_data.get_data(GraphConst.MULTI_DATA_KEY)
+
+    def set_input_data(self, value):
+        return self.multi_data.set_data(GraphConst.MULTI_INPUT_KEY, value)
+
+    def set_output_data(self, value):
+        return self.multi_data.set_data(GraphConst.MULTI_OUTPUT_KEY, value)
+
+    def update_data(self, new_dict):
+        data = self.get_data()
+        data.update(new_dict)
+
+    def get_input_data_dict(self):
+        return self.multi_data.to_dict(GraphConst.MULTI_INPUT_KEY)
+
+    def get_output_data_dict(self):
+        return self.multi_data.to_dict(GraphConst.MULTI_OUTPUT_KEY)
+
+    def get_data_dict(self):
+        return self.multi_data.to_dict(GraphConst.MULTI_DATA_KEY)
+
+    def get_suggestions_dict(self):
+        return self.multi_data.to_dict(GraphConst.MULTI_SUGGEST_KEY)
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe177468b877ad3361c3265180555e0baaf7518f
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/multi_data.py
@@ -0,0 +1,94 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from ..utils import GraphConst
+
+
+class MultiData:
+    """
+    Hold the data dictionaries of a node for one or more repeated records.
+
+    Every record contains four dictionaries, stored under
+    GraphConst.MULTI_DATA_KEY, GraphConst.MULTI_OUTPUT_KEY,
+    GraphConst.MULTI_INPUT_KEY and GraphConst.MULTI_SUGGEST_KEY. get_data and
+    set_data operate on the record selected by set_index; to_dict collects
+    one key across all records.
+    """
+
+    def __init__(self):
+        self._data = []
+        self._repeat_time = 0
+        self._index = 0
+        self.set_repeat_time(1)
+
+    def set_repeat_time(self, repeat_time):
+        """
+        Rebuild the storage as repeat_time empty records and select record 0.
+
+        A non-positive repeat_time is ignored and the existing records are kept.
+        """
+        if repeat_time <= 0:
+            return
+        self._repeat_time = repeat_time
+        # Each record gets its own dictionaries so records never share state.
+        self._data = [{
+            GraphConst.MULTI_DATA_KEY: {},
+            GraphConst.MULTI_OUTPUT_KEY: {},
+            GraphConst.MULTI_INPUT_KEY: {},
+            GraphConst.MULTI_SUGGEST_KEY: {},
+        } for _ in range(repeat_time)]
+        self.set_index(0)
+
+    def get_repeat_time(self):
+        """Return the number of records currently held."""
+        return self._repeat_time
+
+    def set_index(self, index):
+        """
+        Select the record that get_data and set_data operate on.
+
+        Raises IndexError if index is outside [0, repeat_time).
+        """
+        if index < 0 or index >= self._repeat_time:
+            raise IndexError('The index is out of range.')
+        self._index = index
+
+    def get_data(self, key):
+        """
+        Return the dictionary stored under key in the selected record.
+
+        The stored object itself is returned, so callers can update it in place.
+        """
+        self._check_key(key)
+        return self._data[self._index][key]
+
+    def set_data(self, key, value):
+        """Replace the value stored under key in the selected record."""
+        self._check_key(key)
+        self._data[self._index][key] = value
+
+    def to_dict(self, key):
+        """
+        Return a list with the value stored under key in every record.
+
+        Despite the name, the result is a list with one entry per record.
+        """
+        self._check_key(key)
+        return [data[key] for data in self._data]
+
+    def _check_key(self, key):
+        """Raise ValueError if key is not one of the keys of a record."""
+        if key not in self._data[self._index]:
+            raise ValueError(f'{key} is not in the key of MultiData')
diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py
index fb046f9758686fe810a05b1a23d76880b86bb994..403699c87ce2f02f7223918d38bf115ea5724c61 100644
--- a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py
+++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py
@@ -116,3 +116,7 @@ class GraphConst:
     OUTPUT_INDEX = -2
     STR_MAX_LEN = 50
     SMALL_VALUE = 1e-3
+    MULTI_DATA_KEY = 'data'
+    MULTI_OUTPUT_KEY = 'output_data'
+    MULTI_INPUT_KEY = 'input_data'
+    MULTI_SUGGEST_KEY = 'suggestions'