diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/__init__.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..675a87a412cf078e940441e53512156e8082476e --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/graph_builder.py @@ -0,0 +1,152 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class GraphBuilder:
    """Builds a Graph from dumped JSON files and exports graphs to disk."""

    @staticmethod
    def build(construct_path, data_path, model_name='DefaultModel'):
        """
        Public entry point of GraphBuilder for building a graph.
        Args:
            construct_path: path of construct.json
            data_path: path of dump.json
            model_name: name of the model, supplied by the caller
        Returns: Graph, the data structure representing the model
        """
        hierarchy = load_json_file(construct_path)
        dump_data = load_data_json_file(data_path)
        built_graph = Graph(model_name)
        GraphBuilder._init_nodes(built_graph, hierarchy, dump_data)
        GraphBuilder._collect_apis_between_modules(built_graph)
        return built_graph

    @staticmethod
    def to_json(filename, config):
        """
        Export the graph(s) held by config to a .vis file.
        Args:
            filename: path of the file to write
            config: object exposing graph_n, graph_b, tool_tip, node_colors and micro_steps
        """
        if config.graph_b:
            # Comparison export: both graphs are stored under their own keys.
            payload = {}
            payload[GraphConst.JSON_NPU_KEY] = config.graph_n.to_dict()
            payload[GraphConst.JSON_BENCH_KEY] = config.graph_b.to_dict()
        else:
            # Single-graph export: the graph dict itself is the top-level object.
            payload = config.graph_n.to_dict()
        # Optional sections are written only when they carry a truthy value.
        if config.tool_tip:
            payload[GraphConst.JSON_TIP_KEY] = config.tool_tip
        if config.node_colors:
            payload[GraphConst.COLORS] = config.node_colors
        if config.micro_steps:
            payload[GraphConst.MICRO_STEPS] = config.micro_steps
        save_json_file(filename, payload)

    @staticmethod
    def _handle_backward_upnode_missing(construct_dict, subnode_id, upnode_id):
        """
        If the parent of a backward node is null, try to derive it from the parent of
        the forward node that carries the same name.
        Returns: the derived parent id when it exists in construct_dict, otherwise upnode_id unchanged
        """
        # Ids of interest end with ".backward." / ".forward." followed by one or more digits.
        backward_suffix = r"(\.backward\.)(\d+)$"
        forward_suffix = r"(\.forward\.)(\d+)$"
        if not re.search(backward_suffix, subnode_id) or upnode_id:
            return upnode_id
        forward_twin_id = re.sub(backward_suffix, r".forward.\2", subnode_id)
        forward_parent_id = construct_dict.get(forward_twin_id)
        if not forward_parent_id:
            return upnode_id
        candidate_id = re.sub(forward_suffix, r".backward.\2", forward_parent_id)
        return candidate_id if candidate_id in construct_dict else upnode_id

    @staticmethod
    def _init_nodes(graph, construct_dict, data_dict):
        """
        Create every node listed in construct_dict and hang it under its parent;
        nodes without a parent are attached to the graph root.
        """
        for child_id, declared_parent_id in construct_dict.items():
            parent_id = GraphBuilder._handle_backward_upnode_missing(construct_dict, child_id, declared_parent_id)
            parent = graph.root
            if parent_id:
                parent_op = NodeOp.get_node_op(parent_id)
                parent = GraphBuilder._create_or_get_node(graph, data_dict, parent_op, parent_id)
            child_op = NodeOp.get_node_op(child_id)
            GraphBuilder._create_or_get_node(graph, data_dict, child_op, child_id, parent)

    @staticmethod
    def _create_or_get_node(graph, data_dict, op, name, upnode=None):
        """
        Return the node called name. On first use the node is created and its dumped
        input/output data is attached; in every case upnode is then bound as parent.
        """
        if name not in graph.node_map:
            graph.add_node(op, name, upnode)
            fresh_node = graph.get_node(name)
            # Split the dump data of this node into inputs and outputs and store them.
            inputs, outputs = get_input_output(data_dict.get(name, {}), fresh_node.id)
            fresh_node.set_input_output(inputs, outputs)
        target = graph.get_node(name)
        # Bind the parent node.
        target.add_upnode(upnode)
        return target

    @staticmethod
    def _collect_apis_between_modules(graph):
        """
        When the graph is first expanded, the top-level nodes hold many modules and apis.
        Long runs of apis stretch the graph and make it hard to read, so every run of two
        or more consecutive apis between modules is gathered under one collection node.
        Args:
            graph: the model structure

        Returns: None
        """
        regrouped = []
        pending_apis = []

        def flush_pending():
            # Two or more consecutive apis are wrapped in a new collection node;
            # a single api is kept as it is.
            # NOTE(review): the collection node gets no parent and the wrapped apis keep
            # their previous upnode - confirm this is what the consumers expect.
            if len(pending_apis) >= 2:
                collection_id = graph.add_node(NodeOp.api_collection, GraphConst.APIS_BETWEEN_MODULES,
                                               id_accumulation=True)
                collection_node = graph.get_node(collection_id)
                collection_node.subnodes = list(pending_apis)
                regrouped.append(collection_node)
            else:
                regrouped.extend(pending_apis)
            pending_apis.clear()

        for top_node in graph.root.subnodes:
            if top_node.op == NodeOp.function_api:
                pending_apis.append(top_node)
                continue
            # Any other node ends the current run of apis and is kept in place.
            flush_pending()
            regrouped.append(top_node)
        flush_pending()

        graph.root.subnodes = regrouped


class GraphExportConfig:
    """Plain holder of the objects GraphBuilder.to_json reads when exporting."""

    def __init__(self, graph_n, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None):
        self.graph_n, self.graph_b = graph_n, graph_b
        self.tool_tip, self.node_colors = tool_tip, node_colors
        self.micro_steps = micro_steps
# Rules used to parse a node name into the corresponding NodeOp
op_patterns = [
    r'^(Module)',  # NodeOp.module
    r'^(Tensor|Torch|Functional|NPU|VF|Distributed|Aten)'  # NodeOp.function_api
]

# A decimal number with an optional trailing percent sign, e.g. "1.25" or "99.123456789%".
_DECIMAL_PATTERN = re.compile(r'\d{1,20}\.\d{1,20}%?')
# Text form of a float written in scientific notation, e.g. "1.123123123123e-11".
_SCIENTIFIC_PATTERN = re.compile(r'^[+-]?(\d+(\.\d*)?|\.\d+)([eE][+-]?\d+)$')


def get_compare_mode(dump_path_param):
    """
    Determine the compare mode: summary, MD5 or real data.
    Args:
        dump_path_param: parameters required by the acc_compare interface
    Returns: GraphConst.SUMMARY_COMPARE, GraphConst.MD5_COMPARE or GraphConst.REAL_DATA_COMPARE
    """
    summary_compare, md5_compare = task_dumppath_get(dump_path_param)
    if summary_compare:
        compare_mode = GraphConst.SUMMARY_COMPARE
    elif md5_compare:
        compare_mode = GraphConst.MD5_COMPARE
    else:
        compare_mode = GraphConst.REAL_DATA_COMPARE
    return compare_mode


def run_real_data(dump_path_param, csv_path):
    """
    Run the multi-process real-data comparison.
    Args:
        dump_path_param: parameters required by the acc_compare interface
        csv_path: path of the generated file
    Returns: whatever _do_multi_process returns
    """
    return _do_multi_process(dump_path_param, csv_path)


def get_input_output(node_data, node_id):
    """
    Split the raw dump data of one node into input data and output data.
    Args:
        node_data: dump data that belongs to a single node
        node_id: name of the node
    Returns: (input_data, output_data), two dicts keyed by full_op_name
    """
    input_data = {}
    output_data = {}
    op_parsed_list = read_op(node_data, node_id)
    for item in op_parsed_list:
        full_op_name = item.get('full_op_name', '')
        if not full_op_name:
            continue
        splits = full_op_name.split('.')
        # Names with too few dot-separated parts cannot be classified and are dropped.
        if len(splits) < GraphConst.OUTPUT_MIN_LEN:
            continue
        if GraphConst.OUTPUT in splits[GraphConst.OUTPUT_INDEX_TWO] and \
                GraphConst.INPUT not in splits[GraphConst.OUTPUT_INDEX_THREE]:
            output_data[full_op_name] = item
        else:
            input_data[full_op_name] = item
    return input_data, output_data


def compare_data(data_dict_list1, data_dict_list2):
    """
    Check whether two results of get_input_output have the same structure.
    Entries are paired by position; each pair must agree on type, dtype and shape.
    Returns: True when the structures match, otherwise False
    """
    if len(data_dict_list1) != len(data_dict_list2):
        return False
    # Key fields used to decide whether two nodes are equal
    tag_keys = ['type', 'dtype', 'shape']
    for key1, key2 in zip(data_dict_list1, data_dict_list2):
        dict1 = data_dict_list1[key1]
        dict2 = data_dict_list2[key2]
        for tag_key in tag_keys:
            if dict1.get(tag_key, None) != dict2.get(tag_key, None):
                return False
    return True


def compare_mapping_data(data_dict_list1, data_dict_list2):
    """
    Weak structural check used when node1 is mapped onto node2.
    node1 may have more or fewer parameters than node2, the dimension order of a shape
    may differ, and a null parameter of node1 may correspond to a real value in node2.
    To keep as many nodes comparable as possible only the shape dimension values are
    checked, ignoring their order; entries without a shape are skipped.
    Returns: True when every positional pair with shapes on both sides agrees, otherwise False
    """
    for x, y in zip(data_dict_list1.values(), data_dict_list2.values()):
        x_shape = x.get('shape')
        y_shape = y.get('shape')
        if x_shape is None or y_shape is None:
            continue
        x_shape = sorted(x_shape) if isinstance(x_shape, list) else x_shape
        y_shape = sorted(y_shape) if isinstance(y_shape, list) else y_shape
        if x_shape != y_shape:
            return False
    return True


def format_node_data(data_dict):
    """
    Prepare the data of a node for output.
    Every dict value loses its 'requires_grad', 'data_name' and 'full_op_name' keys and is
    formatted by _format_data. The dicts are modified in place.
    Returns: the same data_dict object
    """
    del_list = ['requires_grad', 'data_name', 'full_op_name']
    for value in data_dict.values():
        if not isinstance(value, dict):
            continue
        for item in del_list:
            if item in value:
                del value[item]
        _format_data(value)
    return data_dict


def compare_node(node_ids, data_dicts, stack_json_data, is_summary_compare, is_md5_compare):
    """
    Call get_accuracy from acc_compare.py to obtain the accuracy comparison metrics.
    The real-data compare mode yields no metrics here; the multi-process interface is needed for that.
    Args:
        node_ids: [npu node id, bench node id]
        data_dicts: [npu dump data, bench dump data]
        stack_json_data: stack information keyed by node id
        is_summary_compare: whether the summary mode is active
        is_md5_compare: whether the md5 mode is active
    Returns: list with parameter information and comparison metrics (metrics absent in real-data mode)
    """
    merge_n = _parse_node(node_ids[0], data_dicts[0], stack_json_data, is_summary_compare, is_md5_compare)
    merge_b = _parse_node(node_ids[1], data_dicts[1], stack_json_data, is_summary_compare, is_md5_compare)
    result = []
    get_accuracy(result, merge_n, merge_b, is_summary_compare, is_md5_compare)
    return result


def _parse_node(node_id, data_dict, stack_json_data, is_summary_compare, is_md5_compare):
    """
    Convert a node so that it can be passed to get_accuracy from acc_compare.py.
    """
    op_parsed_list = read_op(data_dict.get(node_id, {}), node_id)
    # The stack information travels as one extra entry; None when the node has none.
    if node_id in stack_json_data:
        op_parsed_list.append(
            {'full_op_name': node_id, 'full_info': stack_json_data[node_id]})
    else:
        op_parsed_list.append({'full_op_name': node_id, 'full_info': None})
    result = merge_tensor(op_parsed_list, is_summary_compare, is_md5_compare)
    if not result:
        result['op_name'] = []
    return result


def _format_decimal_string(s, round_th=None):
    """
    Shorten every decimal number (optionally followed by '%') embedded in a string.
    Numbers with more than round_th decimal places are rounded to round_th places;
    all other text is kept as it is.
    Args:
        s: the string to process
        round_th: maximum number of decimal places; defaults to GraphConst.ROUND_TH
    Returns: the processed string
    """
    if round_th is None:
        round_th = GraphConst.ROUND_TH

    def _shorten(match):
        text = match.group(0)
        number_str = text.rstrip('%')
        decimal_part = number_str.split('.')[1]
        if len(decimal_part) <= round_th:
            return text
        formatted_number = f"{float(number_str):.{round_th}f}"
        # Put the percent sign back when the original number was a percentage.
        return formatted_number + '%' if text.endswith('%') else formatted_number

    # Each match is rewritten exactly where it was found. Replacing by value over the
    # whole string would also rewrite longer numbers that merely contain the match
    # (e.g. "9.9999999" inside "19.9999999").
    return _DECIMAL_PATTERN.sub(_shorten, s)


def _format_data(data_dict):
    """
    Format the values of one dict in place: decimals are limited to six places and
    some abnormal values are normalised.
    """
    none_num = 0
    for key, value in data_dict.items():
        if isinstance(value, str):
            # Drop single quotes and turn None into null to avoid front-end parsing errors
            value = value.replace("'", "").replace(GraphConst.NONE, GraphConst.NULL)
            value = _format_decimal_string(value)
        elif value is None or value == ' ':
            value = GraphConst.NULL
        # Scientific notation such as 1.123123123123e-11 is formatted as 1.123123e-11
        elif isinstance(value, float) and len(str(value)) < GraphConst.STR_MAX_LEN \
                and _SCIENTIFIC_PATTERN.match(str(value)):
            value = "{:.6e}".format(value)
        elif isinstance(value, float):
            value = round(value, GraphConst.ROUND_TH)
        # Inf ends up here and is converted to its string form; this is also the
        # fallback for any other unexpected type
        if not isinstance(value, (list, tuple, dict, str)):
            value = str(value)
        if value == GraphConst.NULL or key == GraphConst.ERROR_KEY:
            none_num += 1
        data_dict[key] = value
    # When every value in the dict is null, keep a single null only
    if none_num == len(data_dict):
        data_dict.clear()
        data_dict[GraphConst.VALUE] = GraphConst.NULL
b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py new file mode 100644 index 0000000000000000000000000000000000000000..a5ec6c8db4be6431352570cfeb3644a7da816f35 --- /dev/null +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -0,0 +1,131 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class GraphComparator:
    """Compares an NPU graph with a bench graph and writes the results into the NPU graph."""

    def __init__(self, graphs, data_paths, stack_path, output_path, mapping_config=None):
        """
        Args:
            graphs: [npu graph, bench graph]
            data_paths: [npu dump.json path, bench dump.json path]
            stack_path: path of the stack json file
            output_path: output path, stored on the instance
            mapping_config: optional mapping configuration; when given, nodes are matched through it
        """
        self.graph_n = graphs[0]
        self.graph_b = graphs[1]
        self._parse_param(data_paths, stack_path, output_path)
        self.mapping_config = mapping_config

    def compare(self):
        """
        Run the comparison; call it separately once initialisation is done.
        The comparison results are written into graph_n.
        """
        self._compare_nodes(self.graph_n.root)
        self._postcompare()

    def add_compare_result_to_node(self, node, compare_result_list):
        """
        Add the comparison result to the input/output data of a node.
        Args:
            node: the node
            compare_result_list: list with parameter information and comparison metrics
                (metrics absent in real-data mode)
        """
        # Real-data comparison: park the node for now; the metrics are added once the
        # multi-process comparison has produced them
        if self.ma.prepare_real_data(node):
            return
        compare_in_dict = {}
        compare_out_dict = {}
        # Keep input and output comparison data apart
        for item in compare_result_list:
            # The row whose name contains the node id supplies the stack information
            # as its last element.
            if not node.stack_info and node.id in item[0]:
                node.stack_info = item[-1]
            if 'output' in item[0]:
                compare_out_dict[item[0]] = item
            else:
                compare_in_dict[item[0]] = item
        precision_index, other_dict = (
            self.ma.parse_result(node, [compare_in_dict, compare_out_dict]))
        node.data[GraphConst.JSON_INDEX_KEY] = precision_index
        node.data.update(other_dict)
        if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
            self.ma.add_error_key(node.output_data)
            node.get_suggestions()

    def _parse_param(self, data_paths, stack_path, output_path):
        """Store the paths, detect the compare mode and load the dump and stack data."""
        self.dump_path_param = {
            'npu_json_path': data_paths[0],
            'bench_json_path': data_paths[1],
            'stack_json_path': stack_path,
            'is_print_compare_log': True
        }
        self.output_path = output_path
        compare_mode = get_compare_mode(self.dump_path_param)
        self.ma = ModeAdapter(compare_mode)
        self.data_n_dict = load_data_json_file(data_paths[0])
        self.data_b_dict = load_data_json_file(data_paths[1])
        self.stack_json_data = load_json_file(stack_path)

    def _postcompare(self):
        """
        Finish the comparison: fill in the index of api collections and, in real-data
        mode, run the multi-process comparison and add its metrics to the parked nodes.
        """
        self._handle_api_collection_index()
        if not self.ma.is_real_data_compare():
            return
        df = get_csv_df(self.ma.is_md5_compare(), self.ma.is_summary_compare(), True, self.ma.csv_data)
        df = run_real_data(self.dump_path_param, df)
        # Key every row by its first column. .iloc states the positional access
        # explicitly; plain row[0] on a label-indexed Series is deprecated in pandas.
        compare_data_dict = {row.iloc[0]: row.tolist() for _, row in df.iterrows()}
        for node in self.ma.compare_nodes:
            precision_index, _ = self.ma.parse_result(node, [compare_data_dict])
            node.data[GraphConst.JSON_INDEX_KEY] = precision_index
            if NodeColors.get_node_error_status(self.ma.compare_mode, precision_index):
                self.ma.add_error_key(node.output_data)
                node.get_suggestions()

    def _handle_api_collection_index(self):
        """
        An api collection uses the smallest index found among all the apis it contains.
        """
        # NOTE(review): the summary and real-data indexes computed by ModeAdapter grow
        # with the error, so the minimum may hide the worst api - confirm the intent.
        for node in self.graph_n.root.subnodes:
            if node.op == NodeOp.api_collection:
                precision_index = 1
                for api in node.subnodes:
                    precision_index = min(precision_index, api.data.get(GraphConst.JSON_INDEX_KEY, 1))
                node.data[GraphConst.JSON_INDEX_KEY] = precision_index

    def _compare_nodes(self, node_n):
        """
        Recursively walk the nodes of the NPU tree. When a node with the same name is
        found in Bench, their ancestors and parameter information are checked; if they
        agree, the accuracy data is compared.
        Pre-order traversal is used: by the time a node is compared its predecessors are
        already matched, which can provide important information for later fuzzy matching.
        """
        if self.mapping_config:
            node_b, ancestors_n, ancestors_b = Graph.mapping_match(node_n, self.graph_b, self.mapping_config)
            if node_b:
                ancestors_n.append(node_n.id)
                ancestors_b.append(node_b.id)
                # Each side records the path of its counterpart.
                node_n.matched_node_link = ancestors_b
                node_b.matched_node_link = ancestors_n
        else:
            node_b, ancestors = Graph.match(self.graph_n, node_n, self.graph_b)
            if node_b:
                ancestors.append(node_b.id)
                node_n.add_link(node_b, ancestors)
        if node_b:
            # Real-data comparison only yields the basic information without metrics;
            # the multi-process comparison interface has to be called for those
            compare_result_list = compare_node([node_n.id, node_b.id],
                                               [self.data_n_dict, self.data_b_dict],
                                               self.stack_json_data, self.ma.is_summary_compare(),
                                               self.ma.is_md5_compare())
            if compare_result_list:
                self.ma.add_csv_data(compare_result_list)
                self.add_compare_result_to_node(node_n, compare_result_list)
        for subnode in node_n.subnodes:
            self._compare_nodes(subnode)
class ModeAdapter:
    """Adapts the comparison results to the active compare mode (summary, md5 or real data)."""

    def __init__(self, compare_mode):
        self.compare_mode = compare_mode
        # Rows collected for the real-data comparison
        self.csv_data = []
        # Nodes parked until the real-data comparison has produced its metrics
        self.compare_nodes = []

    @staticmethod
    def _add_md5_compare_data(node_data, compare_data_dict):
        """
        Copy the md5 comparison columns into node_data.
        Returns: GraphConst.MAX_INDEX_KEY when any compared entry did not pass,
            otherwise GraphConst.MIN_INDEX_KEY
        """
        precision_index = GraphConst.MIN_INDEX_KEY
        for key, value in node_data.items():
            if not isinstance(value, dict):
                continue
            compare_data = compare_data_dict.get(key)
            if compare_data:
                headers = CompareConst.MD5_COMPARE_RESULT_HEADER
                id_list = [headers.index(x) for x in GraphConst.MD5_INDEX_LIST]
                ModeAdapter._match_data(value, compare_data, GraphConst.MD5_INDEX_LIST, id_list)
                # Did the md5 comparison pass?
                if value.get(CompareConst.RESULT) != CompareConst.PASS:
                    precision_index = GraphConst.MAX_INDEX_KEY
                node_data[key] = value
        return precision_index

    @staticmethod
    def _add_real_compare_data(node_data, compare_data_dict):
        """
        Copy the real-data comparison columns into node_data.
        Returns: the smallest one-thousandth error ratio among the compared entries,
            capped at 1.0, or None when no entry carries a numeric ratio
        """
        min_thousandth = float(1)
        numbers = []
        for key, value in node_data.items():
            if not isinstance(value, dict):
                continue
            compare_data = compare_data_dict.get(key)
            if compare_data:
                headers = CompareConst.COMPARE_RESULT_HEADER
                id_list = [headers.index(x) for x in GraphConst.REAL_DATA_INDEX_LIST]
                ModeAdapter._match_data(value, compare_data, GraphConst.REAL_DATA_INDEX_LIST, id_list)
                # Collect the one-thousandth ratio of every input or output of the node
                thousandth = value.get(CompareConst.ONE_THOUSANDTH_ERR_RATIO)
                # It may be None or a non-numeric string
                try:
                    thousandth = float(thousandth)
                except (ValueError, TypeError):
                    thousandth = None
                if thousandth is not None:
                    numbers.append(thousandth)
                node_data[key] = value
        # Abnormal case: every ratio is None
        if not numbers:
            min_thousandth = None
        else:
            min_thousandth = min(numbers + [min_thousandth])
        return min_thousandth

    @staticmethod
    def _add_summary_compare_data(node_data, compare_data_dict):
        """
        Copy the summary comparison columns into node_data.
        Returns: the largest relative error among the compared entries, capped at 1
        """
        max_relative_err = 0
        for key, value in node_data.items():
            if not isinstance(value, dict):
                continue
            compare_data = compare_data_dict.get(key)
            if compare_data:
                # Columns of the comparison result csv
                key_list = GraphConst.SUMMARY_INDEX_LIST
                headers = CompareConst.SUMMARY_COMPARE_RESULT_HEADER
                id_list = [headers.index(x) for x in key_list]
                ModeAdapter._match_data(value, compare_data, key_list, id_list)
                # A relative error above 0.5 suggests an accuracy problem; for small
                # values (1e-3) the relative error is not compared.
                # key_list[index] is paired with key_list[4 + index] (item).
                for index, item in enumerate(key_list[4:]):
                    value_diff = value.get(key_list[index])
                    if isinstance(value_diff, float) and value_diff != 0 and abs(value_diff) < GraphConst.SMALL_VALUE:
                        value[item] = ToolTip.SMALL_VALUE_TIP.format(key_list[index])
                        continue
                    relative_err = str2float(value.get(item))
                    max_relative_err = max(max_relative_err, relative_err)
                node_data[key] = value
        max_relative_err = 1 if max_relative_err > 1 else max_relative_err
        return max_relative_err

    @staticmethod
    def _match_data(data_dict, compare_data, key_list, id_list):
        """
        Bind the accuracy metrics to the input_data or output_data of a node.
        For every (column index, key) pair the value compare_data[column index] is stored
        under key; None, a single space and anything whose text contains 'nan' become 'null'.
        Nothing is done when key_list and id_list differ in length.
        """
        if len(key_list) != len(id_list):
            return
        for column_index, key in zip(id_list, key_list):
            data = compare_data[column_index]
            if data is not None and 'nan' not in str(data) and str(data) != ' ':
                data_dict[key] = data
            else:
                data_dict[key] = 'null'

    def parse_result(self, node, compare_data_dict):
        """
        Evaluate the comparison data of a node according to the compare mode.
        Args:
            node: the node whose input_data and output_data receive the metrics
            compare_data_dict: list of dicts with the comparison rows; md5 and summary mode
                read element 0 for the inputs and element 1 for the outputs, real-data
                mode reads element 0 for both
        Returns: (precision_index, other_dict) where other_dict holds additional data
        """
        other_dict = {}
        if self.is_md5_compare():
            precision_index_in = ModeAdapter._add_md5_compare_data(node.input_data, compare_data_dict[0])
            precision_index_out = ModeAdapter._add_md5_compare_data(node.output_data, compare_data_dict[1])
            # The node only passes when the md5 comparison of every input and output passes
            precision_index = max(precision_index_in, precision_index_out)
            # _add_md5_compare_data yields MIN_INDEX_KEY only when nothing failed, so the
            # node is labelled as passed exactly in that case.
            other_result = CompareConst.PASS if precision_index == GraphConst.MIN_INDEX_KEY \
                else CompareConst.DIFF
            other_dict[CompareConst.RESULT] = other_result
        elif self.is_summary_compare():
            precision_index_in = ModeAdapter._add_summary_compare_data(node.input_data, compare_data_dict[0])
            precision_index_out = ModeAdapter._add_summary_compare_data(node.output_data, compare_data_dict[1])
            precision_index = max(precision_index_in, precision_index_out)
        else:
            min_thousandth_in = ModeAdapter._add_real_compare_data(node.input_data, compare_data_dict[0])
            min_thousandth_out = ModeAdapter._add_real_compare_data(node.output_data, compare_data_dict[0])
            # The index is the change of the ratio between the inputs and the outputs.
            if min_thousandth_in is not None and min_thousandth_out is not None:
                change_percentage = abs(min_thousandth_in - min_thousandth_out)
            else:
                change_percentage = 0
            precision_index = GraphConst.MAX_INDEX_KEY \
                if change_percentage > GraphConst.MAX_INDEX_KEY else change_percentage
        return precision_index, other_dict

    def prepare_real_data(self, node):
        """
        Park the node for the real-data comparison.
        Returns: True when the real-data mode is active and the node was stored, otherwise False
        """
        if self.is_real_data_compare():
            self.compare_nodes.append(node)
            return True
        return False

    def is_summary_compare(self):
        return self.compare_mode == GraphConst.SUMMARY_COMPARE

    def is_md5_compare(self):
        return self.compare_mode == GraphConst.MD5_COMPARE

    def is_real_data_compare(self):
        return self.compare_mode == GraphConst.REAL_DATA_COMPARE

    def add_csv_data(self, compare_result_list):
        """Collect comparison rows; only done in the real-data mode."""
        if not self.is_real_data_compare():
            return
        self.csv_data.extend(compare_result_list)

    def add_error_key(self, node_data):
        """
        Record, per compare mode, which metrics are reported as the error information.
        """
        for key, value in node_data.items():
            if not isinstance(value, dict):
                continue
            if self.is_summary_compare():
                message = [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
                           CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
            elif self.is_real_data_compare():
                message = [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
            else:
                # md5 mode: no metric is listed, to keep the output small
                message = []
            value[GraphConst.ERROR_KEY] = message
            node_data[key] = value

    def get_tool_tip(self):
        """
        Return, as a JSON string, the meaning of the fields shown by the front end.
        """
        if self.is_summary_compare():
            tips = {
                CompareConst.MAX_DIFF: ToolTip.MAX_DIFF,
                CompareConst.MIN_DIFF: ToolTip.MIN_DIFF,
                CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF,
                CompareConst.NORM_DIFF: ToolTip.NORM_DIFF}
        elif self.is_md5_compare():
            tips = {Const.MD5: ToolTip.MD5}
        else:
            tips = {
                CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO,
                CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO,
                CompareConst.COSINE: ToolTip.COSINE,
                CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR,
                CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR}
        return json.dumps(tips)
class BaseNode:
    """A node of the graph: a module, an api or an api collection."""

    def __init__(self, node_op, node_id, up_node=None):
        """
        Args:
            node_op: type of the node
            node_id: id of the node
            up_node: optional parent node; this node is registered among its subnodes
        """
        self.op = node_op
        self.id = node_id
        self.data = {}
        self.output_data = {}
        self.input_data = {}
        self.upnode = None
        self.add_upnode(up_node)
        self.subnodes = []
        self.matched_node_link = []
        self.suggestions = {}
        self.stack_info = []
        self.micro_step_id = None

    def __str__(self):
        info = f'id:\t{self.id}'
        return info

    def __eq__(self, other):
        """
        Decide whether two nodes can be matched, i.e. whether they are structurally the
        same: their input data and their output data must both pass compare_data.
        Objects that are not nodes are never equal to a node.
        """
        # Let Python fall back to its default handling instead of failing on a missing
        # attribute when a node is compared with something else (e.g. None).
        if not isinstance(other, BaseNode):
            return NotImplemented
        if not compare_data(self.input_data, other.input_data):
            return False
        if not compare_data(self.output_data, other.output_data):
            return False
        return True

    def compare_mapping_node(self, other):
        """
        Weak structural check used for mapped nodes: input data and output data must
        both pass compare_mapping_data.
        """
        if not compare_mapping_data(self.input_data, other.input_data):
            return False
        if not compare_mapping_data(self.output_data, other.output_data):
            return False
        return True

    def get_suggestions(self):
        """
        Provide some suggestions when the accuracy looks suspicious.
        Only module and function api nodes receive suggestions.
        """
        if self.op == NodeOp.module:
            self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.Module
            self.suggestions[Suggestions.DUMP] = Suggestions.DUMP_URL
        elif self.op == NodeOp.function_api:
            self.suggestions[GraphConst.SUGGEST_KEY] = Suggestions.API
            self.suggestions[Suggestions.API_ACCURACY_CHECKER] = Suggestions.API_ACCURACY_CHECKER_URL

    def set_input_output(self, input_data, output_data):
        """Replace the input and output data of the node."""
        self.input_data = input_data
        self.output_data = output_data

    def add_upnode(self, node):
        """
        Bind the parent node, linking the two nodes as parent and child.
        Nothing happens when node is empty, when it has the same id as this node or
        when a parent is already bound.
        """
        if not node or node.id == self.id or self.upnode:
            return
        self.upnode = node
        node.subnodes.append(self)

    def add_link(self, node, ancestors):
        """
        Record the matching data once two nodes have been matched.
        Args:
            node: the node matched with self
            ancestors: ancestor information of the counterpart; both nodes store this same list
        """
        self.matched_node_link = ancestors
        node.matched_node_link = ancestors

    def to_dict(self):
        """
        Produce the output data of the node.
        The input and output data are formatted in place by format_node_data.
        """
        result = {}
        result['id'] = self.id
        result['node_type'] = self.op.value
        result['data'] = self.data
        result['output_data'] = format_node_data(self.output_data)
        result['input_data'] = format_node_data(self.input_data)
        result['upnode'] = self.upnode.id if self.upnode else 'None'
        result['subnodes'] = [node.id for node in self.subnodes]
        result['matched_node_link'] = self.matched_node_link
        result['suggestions'] = self.suggestions
        result['stack_info'] = self.stack_info
        # The micro step is only written once it has been assigned.
        if self.micro_step_id is not None:
            result['micro_step_id'] = self.micro_step_id
        return result

    def get_ancestors(self):
        """
        Return the ids of all ancestors of the node, outermost ancestor first.
        """
        ancestors = []
        current_node = self.upnode
        while current_node:
            ancestors.append(current_node.id)
            current_node = current_node.upnode
        return list(reversed(ancestors))
class Graph:
    """Model graph: one root module node plus a table of every node keyed by id."""

    def __init__(self, model_name):
        # node id -> BaseNode
        self.node_map = {}
        # un-suffixed node id -> most recent numeric suffix issued by add_node
        self.node_id_map = {}
        self.add_node(NodeOp.module, model_name)
        self.root = self.get_node(model_name)

    def __str__(self):
        # One line per node, in insertion order of node_map.
        return "\n".join(f'{str(node)}' for node in self.node_map.values())

    @staticmethod
    def match(graph_n, node_n, graph_b):
        """
        Find the node of graph_b that corresponds to node_n.

        Matching is exact: the id must exist in graph_b, the two nodes must compare
        equal, and both must have the same list of ancestor ids.
        Returns: (matched node, ancestor ids) on success, otherwise (None, [])
        """
        no_match = (None, [])
        if not node_n or node_n.id not in graph_b.node_map:
            return no_match
        node_b = graph_b.node_map.get(node_n.id)
        if node_n != node_b:
            return no_match
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        if ancestors_n != ancestors_b:
            return no_match
        return node_b, ancestors_n

    @staticmethod
    def mapping_match(node_n, graph_b, mapping_config):
        """
        Match node_n against graph_b after translating its id through mapping_config.

        Returns: (matched node, ancestors of node_n, ancestors of the matched node),
                 or (None, [], []) when nothing matches
        """
        mapped_id = mapping_config.get_mapping_string(node_n.id)
        node_b = graph_b.node_map.get(mapped_id)
        if not node_b or not node_n.compare_mapping_node(node_b):
            return None, [], []
        ancestors_n = node_n.get_ancestors()
        ancestors_b = node_b.get_ancestors()
        return node_b, ancestors_n, ancestors_b

    @staticmethod
    def dfs(node, result):
        """Store to_dict() of node and of every descendant into result, pre-order, keyed by node id."""
        result[node.id] = node.to_dict()
        for child in node.subnodes:
            Graph.dfs(child, result)

    @staticmethod
    def split_nodes_by_micro_step(nodes):
        """
        Group nodes into micro steps using the numeric suffix of Module node ids.

        Module nodes whose ids end in the same number share a micro step. A non-Module
        node joins the micro step that was most recently opened (0 until a new one appears).
        Returns: dict mapping micro step id -> list of nodes
        """
        buckets = {0: []}
        current_id = 0

        for node in nodes:
            if node.op != NodeOp.module:
                buckets[current_id].append(node)
                continue
            micro_step_id = node.id.split(Const.SEP)[-1]
            try:
                micro_step_id = int(micro_step_id)
            except ValueError:
                logger.warning(f'The node id suffix {micro_step_id} is not a number, micro steps cannot be split.')
                micro_step_id = 0
            if micro_step_id not in buckets:
                # First node of a new micro step: later non-Module nodes follow it.
                current_id = micro_step_id
                buckets[micro_step_id] = []
            buckets[micro_step_id].append(node)
        return buckets

    def add_node(self, node_op, node_id, up_node=None, id_accumulation=False):
        """
        Add a node to the graph.
        Args:
            node_op: type of the node to add
            node_id: id of the node to add
            up_node: parent node of the new node
            id_accumulation: when True, append an increasing ".<n>" suffix to node_id
        Returns: the id under which the node is stored (the existing id when the node is
                 already present and id_accumulation is False)
        """
        already_present = node_id in self.node_map
        if already_present and not id_accumulation:
            return node_id
        if id_accumulation:
            if already_present:
                # The bare id is taken: restart the counter so this node gets suffix 1.
                self.node_id_map[node_id] = 0
            if node_id in self.node_id_map:
                self.node_id_map[node_id] += 1
            else:
                self.node_id_map[node_id] = 0
            node_id = f'{node_id}.{self.node_id_map[node_id]}'
        self.node_map[node_id] = BaseNode(node_op, node_id, up_node)
        return node_id

    def get_node(self, node_id):
        """Return the node stored under node_id, or None when there is none."""
        return self.node_map.get(node_id, None)

    def to_dict(self):
        """Serialize the graph: the root id plus to_dict() of every node keyed by id."""
        result = {}
        result[GraphConst.JSON_ROOT_KEY] = self.root.id if self.root else 'None'
        result[GraphConst.JSON_NODE_KEY] = {
            node_id: node.to_dict() for node_id, node in self.node_map.items()
        }
        return result

    def paging_by_micro_step(self, graph_other=None):
        """
        Tag the first-level nodes with a micro step id so the front end can page them.

        In the compare scenario the matched node in graph_other receives the same id, and
        any first-level node of graph_other still untagged derives one from its id suffix.
        Args:
            graph_other: optional second graph
        Returns: number of micro step batches found in this graph
        """
        batches_n = Graph.split_nodes_by_micro_step(self.root.subnodes)
        for batch_number, nodes in batches_n.items():
            for node in nodes:
                node.micro_step_id = batch_number
                if not (graph_other and node.matched_node_link):
                    continue
                # The last entry of matched_node_link is looked up as the peer node id.
                node_other = graph_other.get_node(node.matched_node_link[-1])
                if node_other:
                    node_other.micro_step_id = batch_number
        if graph_other:
            # Unmatched first-level nodes of the other graph still need a micro step id.
            for node in graph_other.root.subnodes:
                if node.micro_step_id is not None:
                    continue
                try:
                    micro_step_id = int(node.id.split(Const.SEP)[-1])
                except ValueError:
                    micro_step_id = 0
                node.micro_step_id = micro_step_id
        return len(batches_n)
# Legend texts shown next to each colour; they are emitted at runtime and must stay as written.
SUMMARY_DESCRIPTION = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|"
REAL_DATA_DESCRIPTION = (f"此节点所有输入的最小双千分之一和所有输出的最小双千分之一的差值的绝对值, 代表双千指标的变化情况, "
                         f"值越大代表测量值与标杆值的偏差越大, 双千分之一指标计算方式:{ToolTip.ONE_THOUSANDTH_ERR_RATIO}")
MD5_DESCRIPTION_N = "与标杆相比, 此节点任意输入输出的md5值不同"
MD5_DESCRIPTION_Y = "与标杆相比, 此节点所有输入输出的md5值相同"
NOT_MATCHED = "比对过程中节点未匹配上"


class NodeColors(Enum):
    """
    Node colours and, per compare mode, the value range each colour stands for.

    The smaller the numeric suffix of a member name, the lighter the colour.
    Ranges are left-closed and right-open; two equal bounds denote a fixed value.
    """
    YELLOW_1 = ("#FFFCF3", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0, 0.2], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0, 0.05], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
        GraphConst.MD5_COMPARE: {GraphConst.VALUE: [1, 1], GraphConst.DESCRIPTION: MD5_DESCRIPTION_Y},
    })
    YELLOW_2 = ("#FFEDBE", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.2, 0.4], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.05, 0.1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
    })
    ORANGE_1 = ("#FFDC7F", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.4, 0.6], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.1, 0.15], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
    })
    ORANGE_2 = ("#FFC62E", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.6, 0.8], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.15, 0.2], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION}
    })
    RED = ("#E32020", {
        GraphConst.SUMMARY_COMPARE: {GraphConst.VALUE: [0.8, 1], GraphConst.DESCRIPTION: SUMMARY_DESCRIPTION},
        GraphConst.REAL_DATA_COMPARE: {GraphConst.VALUE: [0.2, 1], GraphConst.DESCRIPTION: REAL_DATA_DESCRIPTION},
        GraphConst.MD5_COMPARE: {GraphConst.VALUE: [0, 0], GraphConst.DESCRIPTION: MD5_DESCRIPTION_N},
    })
    GREY = ("#C7C7C7", {
        GraphConst.VALUE: [], GraphConst.DESCRIPTION: NOT_MATCHED
    })

    def __init__(self, hex_value, mode_info):
        self.hex_value = hex_value
        self.mode_info = mode_info

    @staticmethod
    def get_node_colors(mode):
        """
        Build the colour legend for one compare mode.
        Args:
            mode: compare mode
        Returns: dict mapping hex colour -> info dict, omitting colours with no info for the mode
        """
        legend = {}
        for color in NodeColors:
            info = color.get_info_by_mode(mode)
            if info:
                legend[color.hex_value] = info
        return legend

    @staticmethod
    def get_node_error_status(mode, value):
        """
        Tell whether a precision index exceeds the baseline, i.e. the lower bound of ORANGE_1.
        Args:
            mode: compare mode
            value: precision compare index
        Returns: bool
        """
        info = NodeColors.ORANGE_1.get_info_by_mode(mode)
        if not info or GraphConst.VALUE not in info:
            return False
        lower_bound = info[GraphConst.VALUE][0]
        return value > lower_bound

    def get_info_by_mode(self, mode):
        """Return this colour's info for mode; a flat (non-nested) info dict applies to every mode."""
        if not isinstance(self.mode_info, dict):
            return {}
        first_entry = next(iter(self.mode_info.values()))
        if isinstance(first_entry, dict):
            # Nested layout: one info dict per compare mode.
            return self.mode_info.get(mode, {})
        # Flat layout: the same info is shared by all modes.
        return self.mode_info
class NodeOp(Enum):
    """Node kinds; each member value is used as an index into op_patterns."""
    module = 0
    function_api = 1
    api_collection = 9

    @staticmethod
    def get_node_op(node_name: str):
        """
        Resolve the node kind from the string that names the node.

        Members are tried in definition order; the first whose pattern matches the start
        of node_name wins. Raises Exception when a member value has no entry in
        op_patterns, or when no pattern matches.
        """
        pattern_count = len(op_patterns)
        for candidate in NodeOp:
            position = candidate.value
            if not 0 <= position < pattern_count:
                raise Exception("NodeOp and op_patterns in MsprobeAdapter do not match")
            if re.match(op_patterns[position], node_name):
                return candidate
        raise Exception(f"Cannot parse node_name {node_name} into NodeOp")
# Evaluated once at import time, so every file exported by this process shares the stamp.
current_time = time.strftime("%Y%m%d%H%M%S")


def compare_graph(dump_path_n, dump_path_b, out_path, model_name='Model', mapping_file=None):
    """
    Build the graphs of two dumps, compare them and export the result as a .vis file.
    Args:
        dump_path_n: dump directory of the first (NPU) run
        dump_path_b: dump directory of the second (bench) run
        out_path: directory that receives compare_<timestamp>.vis
        model_name: name given to the root node of both graphs
        mapping_file: optional yaml file with node-name mappings
    """
    logger.info('Start building model graphs...')
    # Build one graph per dump directory.
    data_path_n = os.path.join(dump_path_n, GraphConst.DUMP_FILE)
    data_path_b = os.path.join(dump_path_b, GraphConst.DUMP_FILE)
    graph_n = GraphBuilder.build(os.path.join(dump_path_n, GraphConst.CONSTRUCT_FILE), data_path_n, model_name)
    graph_b = GraphBuilder.build(os.path.join(dump_path_b, GraphConst.CONSTRUCT_FILE), data_path_b, model_name)
    logger.info('Model graphs built successfully, start Comparing graphs...')
    # Compare using the graphs, the stack file and the dump data.
    stack_path = os.path.join(dump_path_n, GraphConst.STACK_FILE)
    mapping_config = MappingConfig(mapping_file) if mapping_file else None
    graph_comparator = GraphComparator([graph_n, graph_b], [data_path_n, data_path_b], stack_path, out_path,
                                       mapping_config=mapping_config)
    graph_comparator.compare()
    micro_steps = graph_n.paging_by_micro_step(graph_b)
    output_path = os.path.join(out_path, f'compare_{current_time}.vis')
    tool_tip = graph_comparator.ma.get_tool_tip()
    node_colors = NodeColors.get_node_colors(graph_comparator.ma.compare_mode)
    export_config = GraphExportConfig(graph_n, graph_b, tool_tip, node_colors, micro_steps)
    GraphBuilder.to_json(output_path, export_config)
    logger.info(f'Model graphs compared successfully, the result file is saved in {output_path}')


def build_graph(dump_path, out_path, model_name='Model'):
    """
    Build the graph of a single dump and export it as build_<timestamp>.vis.
    Args:
        dump_path: dump directory holding the construct and dump files
        out_path: directory that receives the .vis file
        model_name: name given to the root node
    """
    logger.info('Start building model graph...')
    output_path = os.path.join(out_path, f'build_{current_time}.vis')
    graph = GraphBuilder.build(os.path.join(dump_path, GraphConst.CONSTRUCT_FILE),
                               os.path.join(dump_path, GraphConst.DUMP_FILE),
                               model_name)
    micro_steps = graph.paging_by_micro_step()
    GraphBuilder.to_json(output_path, GraphExportConfig(graph, micro_steps=micro_steps))
    logger.info(f'Model graph built successfully, the result file is saved in {output_path}')


class MappingConfig:
    """Node-name mapping rules loaded from a yaml list of single '- key: value' items."""

    # Node ids longer than this are returned unmapped by get_mapping_string.
    MAX_STRING_LEN = 10000

    def __init__(self, yaml_file):
        with FileOpen(yaml_file, 'r') as file:
            config = yaml.safe_load(file)
        try:
            rules = {}
            for entry in config:
                for key, value in entry.items():
                    rules[key] = self.validate(key, value)
            self.config = rules
        except Exception as e:
            raise RuntimeError("Line of yaml contains content that is not '- key: value'.") from e
        self.classify_config = self._classify_and_sort_keys()

    @staticmethod
    def validate(key, value):
        """Return value unchanged; raise ValueError unless both key and value are strings."""
        for item in (key, value):
            if not isinstance(item, str):
                raise ValueError(f"{item} must be a string.")
        return value

    @staticmethod
    def convert_to_regex(s):
        """
        Turn a mapping key into a regular expression.

        The key is escaped, every '{}' becomes \\d+ (one or more digits) and the result is
        wrapped in '.*' so any prefix and suffix are accepted.
        Args:
            s: mapping key
        Returns: regular expression string
        """
        with_digits = re.sub(r'\\\{\\\}', r'\\d+', re.escape(s))
        return f'.*{with_digits}.*'

    @staticmethod
    def _replace_parts(origin_string, mapping_key, mapping_value):
        """Rewrite origin_string: key text becomes value text, piece-wise around '{}' when the key has one."""
        if GraphConst.BRACE not in mapping_key:
            return origin_string.replace(mapping_key, mapping_value)
        # Only the text before and after the first '{}' is substituted.
        parts = mapping_key.split(GraphConst.BRACE)
        m_parts = mapping_value.split(GraphConst.BRACE)
        replaced_head = origin_string.replace(parts[0], m_parts[0])
        return replaced_head.replace(parts[1], m_parts[1])

    def get_mapping_string(self, origin_string: str):
        """Return origin_string rewritten by the first matching rule, or unchanged when none applies."""
        if len(origin_string) > MappingConfig.MAX_STRING_LEN:
            return origin_string
        for category, items in self.classify_config.items():
            if category not in origin_string:
                continue
            for key, value in items:
                if re.match(MappingConfig.convert_to_regex(key), origin_string):
                    return MappingConfig._replace_parts(origin_string, key, value)
        return origin_string

    def _classify_and_sort_keys(self):
        """Group rules by the first dot-separated part of their key, most-dotted keys first in each group."""
        categorized_dict = {}
        for key, value in self.config.items():
            # The first part of the key becomes the category.
            category_key = key.split(Const.SEP)[0]
            categorized_dict.setdefault(category_key, []).append((key, value))

        # More separators means a more specific key, which must be tried earlier.
        for rules in categorized_dict.values():
            rules.sort(key=lambda rule: -rule[0].count(Const.SEP))

        return categorized_dict
def load_json_file(file_path):
    """
    Load a json file.

    Returns the parsed dict, or {} when the content is not valid json or its top level
    is not a dict.
    """
    try:
        with FileOpen(file_path, 'r') as f:
            file_dict = json.load(f)
            if not isinstance(file_dict, dict):
                return {}
            return file_dict
    except json.JSONDecodeError:
        return {}


def load_data_json_file(file_path):
    """
    Load the 'data' field of a dump json file; {} when the field is absent.
    """
    return load_json_file(file_path).get(GraphConst.DATA_KEY, {})


def save_json_file(file_path, data):
    """
    Write data to file_path as json indented by 4 spaces.
    """
    with FileOpen(file_path, 'w') as f:
        f.write(json.dumps(data, indent=4))


def get_csv_df(md5_compare, summary_compare, stack, csv_data):
    """
    Delegate to the acc-compare result_to_csv interface and return its result.
    """
    return result_to_csv(md5_compare, summary_compare, stack, csv_data)


def str2float(percentage_str):
    """
    Convert a percentage string to a float.
    Args:
        percentage_str: '0.00%', '23.4%'
    Returns: float 0.00, 0.234; 0 when the text is not a number
    """
    try:
        percentage_str = percentage_str.strip('%')
        return float(percentage_str) / 100
    except ValueError:
        return 0


class ToolTip:
    # Tooltip texts for each compare index; emitted as-is into the exported file.
    MAX_DIFF = 'NPU与标杆API统计信息比对,最大值的差值'
    MIN_DIFF = 'NPU与标杆API统计信息比对,最小值的差值'
    MEAN_DIFF = 'NPU与标杆API统计信息比对,平均值的差值'
    NORM_DIFF = 'NPU与标杆API统计信息比对,2范数(平方根)的差值'
    MD5 = '数据MD5信息,用于比较两个数据信息是否完全一致'
    ONE_THOUSANDTH_ERR_RATIO = 'Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之一的比例占总元素个数的比例,比例越接近1越好'
    FIVE_THOUSANDTHS_ERR_RATIO = 'Tensor中的元素逐个与对应的标杆数据对比,相对误差小于千分之五的比例占总元素个数的比例,比例越接近1越好'
    COSINE = '通过计算两个向量的余弦值来判断其相似度,数值越接近于1说明计算出的两个张量越相似,实际可接受阈值为大于0.99。在计算中可能会存在nan,主要由于可能会出现其中一个向量为0'
    MAX_ABS_ERR = '当最大绝对误差越接近0表示其计算的误差越小,实际可接受阈值为小于0.001'
    MAX_RELATIVE_ERR = '当最大相对误差越接近0表示其计算的误差越小。当dump数据中存在0或Nan时,比对结果中最大相对误差则出现inf或Nan的情况,属于正常现象'
    SMALL_VALUE_TIP = '{} 小于1e-3,不计算相对误差'


class Suggestions:
    # Suggestion texts and the documentation links they point to.
    Module = '此模块精度比对结果疑似异常,请使用msprobe工具的数据采集功能对模块中的api进行dump比对'
    API = '此api精度比对结果疑似异常,请使用msprobe工具的预检功能对api进行精度检测'
    DUMP = 'msprobe工具的数据采集功能'
    DUMP_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/dump.md'
    API_ACCURACY_CHECKER = 'msprobe工具的预检功能'
    API_ACCURACY_CHECKER_URL = 'https://gitee.com/ascend/mstt/blob/master/debug/accuracy_tools/msprobe/pytorch/doc/api_accuracy_checker.md'


class GraphConst:
    # Constants shared by the graph builder, comparator and exporter.
    CONSTRUCT_FILE = 'construct.json'
    DUMP_FILE = 'dump.json'
    STACK_FILE = 'stack.json'
    GRAPH_FILE = 'graph.vis'
    ERROR_KEY = 'error_key'
    SUMMARY_COMPARE = 0
    MD5_COMPARE = 1
    REAL_DATA_COMPARE = 2
    JSON_NPU_KEY = 'NPU'
    JSON_BENCH_KEY = 'Bench'
    JSON_TIP_KEY = 'ToolTip'
    JSON_ROOT_KEY = 'root'
    JSON_NODE_KEY = 'node'
    DATA_KEY = 'data'
    REAL_DATA_TH = 0.1
    MAX_RELATIVE_ERR_TH = 0.5
    ROUND_TH = 6
    JSON_INDEX_KEY = 'precision_index'
    MAX_INDEX_KEY = 1
    MIN_INDEX_KEY = 0
    SUGGEST_KEY = 'text'
    TAG_NA = 'na'
    OUTPUT_INDEX_TWO = -2
    OUTPUT_INDEX_THREE = -3
    OUTPUT_MIN_LEN = 3
    INPUT = 'input'
    OUTPUT = 'output'
    STR_MAX_LEN = 50
    SMALL_VALUE = 1e-3
    MD5_INDEX_LIST = [CompareConst.RESULT]
    REAL_DATA_INDEX_LIST = [CompareConst.COSINE, CompareConst.MAX_ABS_ERR, CompareConst.MAX_RELATIVE_ERR,
                            CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]
    SUMMARY_INDEX_LIST = [CompareConst.MAX_DIFF, CompareConst.MIN_DIFF, CompareConst.MEAN_DIFF,
                          CompareConst.NORM_DIFF, CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR,
                          CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]
    APIS_BETWEEN_MODULES = 'Apis_Between_Modules'
    NULL = 'null'
    NONE = 'None'
    VALUE = 'value'
    BRACE = '{}'
    DESCRIPTION = 'description'
    COLORS = 'Colors'
    MICRO_STEPS = 'MicroSteps'


class TestGraphBuilder(unittest.TestCase):
    """Unit tests for GraphBuilder."""

    def setUp(self):
        self.construct_path = "step/rank/construct.json"
        self.data_path = "step/rank/dump.json"
        self.model_name = "TestModel"
        self.graph = Graph(self.model_name)
        self.construct_dict = {
            "Tensor1": "Module1",
            "Module1": None
        }
        self.data_dict = {
            "Module1": {"data": "data for Module1"},
            "Tensor1": {"data": "data for Tensor1"}
        }

    @patch('msprobe.pytorch.visualization.builder.graph_builder.load_json_file')
    @patch('msprobe.pytorch.visualization.builder.graph_builder.load_data_json_file')
    def test_build(self, mock_load_data_json_file, mock_load_json_file):
        mock_load_data_json_file.return_value = self.data_dict
        mock_load_json_file.return_value = self.construct_dict

        graph = GraphBuilder.build(self.construct_path, self.data_path, self.model_name)
        self.assertIsNotNone(graph)
        self.assertIsInstance(graph, Graph)
        self.assertEqual(len(graph.node_map), 3)

    @patch('msprobe.pytorch.visualization.builder.graph_builder.save_json_file')
    def test_to_json(self, mock_save_json_file):
        # to_json reads graph_n / graph_b / tool_tip / node_colors / micro_steps from an
        # export config. A bare Graph has none of these attributes, so pass a
        # config-shaped object describing the single-graph case.
        config = MagicMock(graph_n=self.graph, graph_b=None, tool_tip=None, node_colors=None, micro_steps=None)
        GraphBuilder.to_json("step/rank/output.vis", config)
        mock_save_json_file.assert_called_once_with("step/rank/output.vis", self.graph.to_dict())

    @patch('msprobe.pytorch.visualization.graph.node_op.NodeOp.get_node_op')
    @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.get_input_output', return_value=([], []))
    def test__init_nodes(self, mock_get_input_output, mock_get_node_op):
        GraphBuilder._init_nodes(self.graph, self.construct_dict, self.data_dict)
        mock_get_node_op.assert_any_call("Tensor1")
        mock_get_node_op.assert_any_call("Module1")
        self.assertIs(self.graph.root, self.graph.get_node("TestModel"))

    def test__create_or_get_node(self):
        node_op = MagicMock()
        data_dict = {"node1": {}}
        node = GraphBuilder._create_or_get_node(self.graph, data_dict, node_op, "node1")
        self.assertIn("node1", self.graph.node_map)
        self.assertEqual(node.input_data, {})
        self.assertEqual(node.output_data, {})

    def test__handle_backward_upnode_missing(self):
        construct_dict = {'Module.module.a.forward.0': 'Module.root.forward.0', 'Module.module.a.backward.0': None,
                          'Module.root.forward.0': None, 'Module.root.backward.0': None,
                          'Module.module.b.forward.0': 'Module.root.forward.0',
                          'Module.module.b.backward.0': 'Module.root.backward.0', 'Module.module.c.backward.0': None}
        node_id_a = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.a.backward.0', None)
        self.assertEqual(node_id_a, 'Module.root.backward.0')
        node_id_b = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.b.backward.0',
                                                                 'Module.root.backward.0')
        self.assertEqual(node_id_b, 'Module.root.backward.0')
        node_id_c = GraphBuilder._handle_backward_upnode_missing(construct_dict, 'Module.module.c.backward.0', None)
        self.assertIsNone(node_id_c)

    def test__collect_apis_between_modules_only_apis(self):
        graph = Graph('TestNet')
        graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0')]
        GraphBuilder._collect_apis_between_modules(graph)
        self.assertEqual(len(graph.root.subnodes), 1)
        self.assertEqual(graph.root.subnodes[0].op, NodeOp.api_collection)
        self.assertEqual(len(graph.root.subnodes[0].subnodes), 2)
        self.assertEqual(graph.root.subnodes[0].id, 'Apis_Between_Modules.0')

    def test__collect_apis_between_modules_mixed_nodes(self):
        graph = Graph('TestNet')
        graph.root.subnodes = [BaseNode(NodeOp.function_api, 'Tensor.a.0'), BaseNode(NodeOp.module, 'Module.a.0'),
                               BaseNode(NodeOp.module, 'Module.b.0'), BaseNode(NodeOp.function_api, 'Tensor.b.0'),
                               BaseNode(NodeOp.function_api, 'Tensor.c.0'), BaseNode(NodeOp.module, 'Module.a.1')]
        GraphBuilder._collect_apis_between_modules(graph)
        self.assertEqual(len(graph.root.subnodes), 5)
        self.assertEqual(graph.root.subnodes[0].op, NodeOp.function_api)
        self.assertEqual(graph.root.subnodes[1].op, NodeOp.module)
        self.assertEqual(graph.root.subnodes[3].op, NodeOp.api_collection)
        self.assertEqual(len(graph.root.subnodes[3].subnodes), 2)
        self.assertEqual(graph.root.subnodes[3].id, 'Apis_Between_Modules.0')

    def test__collect_apis_between_modules_only_modules(self):
        graph = Graph('TestNet')
        graph.root.subnodes = [BaseNode(NodeOp.module, 'Module.a.0'), BaseNode(NodeOp.module, 'Module.b.0'),
                               BaseNode(NodeOp.module, 'Module.a.1')]
        GraphBuilder._collect_apis_between_modules(graph)
        self.assertEqual(len(graph.root.subnodes), 3)
        self.assertEqual(graph.root.subnodes[0].op, NodeOp.module)
        self.assertEqual(graph.root.subnodes[1].op, NodeOp.module)
        self.assertEqual(graph.root.subnodes[2].op, NodeOp.module)
        self.assertEqual(len(graph.root.subnodes[0].subnodes), 0)
        self.assertEqual(graph.root.subnodes[0].id, 'Module.a.0')
class TestMsprobeAdapter(unittest.TestCase):
    """Unit tests for the msprobe_adapter helper functions."""

    @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.task_dumppath_get', return_value=(True, False))
    def test_get_compare_mode_summary(self, mock_task_dumppath_get):
        mode = get_compare_mode("dummy_param")
        self.assertEqual(mode, GraphConst.SUMMARY_COMPARE)

    @patch('msprobe.pytorch.visualization.builder.msprobe_adapter._do_multi_process')
    def test_run_real_data(self, mock_do_multi_process):
        run_real_data("dump_path", "csv_path")
        mock_do_multi_process.assert_called_once_with("dump_path", "csv_path")

    def test_get_input_output(self):
        node_data = {
            'input_args': [{'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [5],
                            'Max': 2049.0, 'Min': 0.0, 'Mean': 410.20001220703125, 'Norm': 2049.0009765625,
                            'requires_grad': False, 'full_op_name': 'Distributed.broadcast.0.forward_input.0'},
                           {'type': 'int', 'value': 0}],
            'input_kwargs': {'group': None},
            'output': [{'type': 'torch.Tensor', 'dtype': 'torch.int64', 'shape': [5],
                        'Max': 2049.0, 'Min': 0.0, 'Mean': 410.20001220703125, 'Norm': 2049.0009765625,
                        'requires_grad': False, 'full_op_name': 'Distributed.broadcast.0.forward_output.0'},
                       {'type': 'int', 'value': 0}, None]
        }
        node_id = "Distributed.broadcast.0.forward"
        input_data, output_data = get_input_output(node_data, node_id)
        self.assertIn("Distributed.broadcast.0.forward_output.0", output_data)
        self.assertIn("Distributed.broadcast.0.forward_input.0", input_data)

    def test_compare_data(self):
        data_dict_list1 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}}
        data_dict_list2 = {'key1': {'type': 'Type1', 'dtype': 'DType1', 'shape': 'Shape1'}}
        data_dict_list3 = {'key1': {'type': 'Type2', 'dtype': 'DType1', 'shape': 'Shape1'}}
        data_dict_list4 = {}
        self.assertTrue(compare_data(data_dict_list1, data_dict_list2))
        self.assertFalse(compare_data(data_dict_list1, data_dict_list3))
        self.assertFalse(compare_data(data_dict_list1, data_dict_list4))

    def test_format_node_data(self):
        data_dict = {'node1': {'data_name': 'data1', 'full_op_name': 'op1'}}
        result = format_node_data(data_dict)
        self.assertNotIn('data_name', result['node1'])
        self.assertNotIn('requires_grad', result['node1'])

    @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.get_accuracy')
    def test_compare_node(self, mock_get_accuracy):
        node_ids = ["node1", "node2"]
        data_dicts = [{'node1': {"input_args": [], "input_kwargs": {}, "output": {}}},
                      {'node2': {"input_args": [], "input_kwargs": {}, "output": {}}}]
        stack_json_data = {}
        result = compare_node(node_ids, data_dicts, stack_json_data, False, False)
        mock_get_accuracy.assert_called_once()
        self.assertIsInstance(result, list)

    def test__format_decimal_string(self):
        s = "0.123456789%"
        formatted_s = _format_decimal_string(s)
        self.assertIn("0.123457%", formatted_s)
        self.assertEqual('0.123457', _format_decimal_string('0.12345678'))
        self.assertEqual('-1', _format_decimal_string('-1'))
        self.assertEqual('0.0.25698548%', _format_decimal_string('0.0.25698548%'))

    def test__format_data(self):
        # float('inf') stands in for the infinite value: this file never imports torch,
        # so the former torch.inf reference raised NameError before any assertion ran.
        data_dict = {'value': 0.123456789, 'value1': None, 'value2': "", 'value3': 1.123123123123e-11,
                     'value4': float('inf'), 'value5': -1}
        _format_data(data_dict)
        self.assertEqual(data_dict['value'], '0.123457')
        self.assertEqual(data_dict['value1'], 'null')
        self.assertEqual(data_dict['value2'], '')
        self.assertEqual(data_dict['value3'], '1.123123e-11')
        self.assertEqual(data_dict['value4'], 'inf')
        self.assertEqual(data_dict['value5'], '-1')

        all_none_dict = {'a': None, 'b': None, 'c': None, 'd': None, 'e': None}
        _format_data(all_none_dict)
        self.assertEqual({'value': 'null'}, all_none_dict)

    def test_compare_mapping_data(self):
        dict1 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}}
        dict2 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}, 'c': {'shape': [1, 2, 3]}}
        dict3 = {'a': {'shape': [1, 2, 3]}, 'b': {'shape': [1, 2, 3]}}
        dict4 = {'a': {'shape': [2, 1, 3]}, 'b': {'shape': [1, 2, 3]}}
        dict5 = {'a': {'shape': [2, 2, 3]}, 'b': {'shape': [1, 2, 3]}}
        dict6 = {'a': {'type': 'str'}}
        self.assertTrue(compare_mapping_data(dict1, dict2))
        self.assertTrue(compare_mapping_data(dict1, dict3))
        self.assertTrue(compare_mapping_data(dict1, dict4))
        self.assertFalse(compare_mapping_data(dict1, dict5))
        self.assertTrue(compare_mapping_data(dict1, dict6))
@patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__parse_param(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + self.comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + self.comparator._parse_param(self.data_paths, self.stack_path, self.output_path) + + self.assertEqual(self.comparator.dump_path_param, { + 'npu_json_path': self.data_paths[0], + 'bench_json_path': self.data_paths[1], + 'stack_json_path': self.stack_path, + 'is_print_compare_log': True + }) + self.assertEqual(self.comparator.output_path, self.output_path) + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test_compare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator._compare_nodes = MagicMock() + comparator._postcompare = MagicMock() + + comparator.compare() + + comparator._compare_nodes.assert_called_once() + comparator._postcompare.assert_called_once() + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def 
test_add_compare_result_to_node(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + node = MagicMock() + compare_result_list = [("output1", "data1"), ("input1", "data2")] + + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.ma = MagicMock() + comparator.ma.prepare_real_data.return_value = True + + comparator.add_compare_result_to_node(node, compare_result_list) + comparator.ma.prepare_real_data.assert_called_once_with(node) + node.data.update.assert_not_called() + + @patch('msprobe.pytorch.visualization.graph.node_colors.NodeColors.get_node_error_status') + @patch('msprobe.pytorch.visualization.utils.get_csv_df') + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.run_real_data') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__postcompare(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode, + mock_run_real_data, mock_get_csv_df, mock_get_node_error_status): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + mock_df = MagicMock() + mock_df.iterrows = MagicMock(return_value=[(None, MagicMock())]) + mock_run_real_data.return_value = mock_df + mock_get_csv_df.return_value = mock_df + mock_get_node_error_status.return_value = True + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.ma = MagicMock() + comparator.ma.is_real_data_compare.return_value = True + 
comparator._handle_api_collection_index = MagicMock() + comparator.ma.compare_nodes = [MagicMock()] + comparator.ma.parse_result = MagicMock(return_value=(0.9, None)) + + comparator._postcompare() + + comparator._handle_api_collection_index.assert_called_once() + comparator.ma.add_error_key.assert_called() + + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__handle_api_collection_index(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode): + mock_load_data_json_file.return_value = "data_dict" + mock_load_json_file.return_value = "construct_dict" + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + apis = BaseNode(NodeOp.api_collection, 'Apis_Between_Modules.0') + api1 = BaseNode(NodeOp.function_api, 'Tensor.a.0') + api1.data = {GraphConst.JSON_INDEX_KEY: 0.9} + api2 = BaseNode(NodeOp.function_api, 'Tensor.b.0') + api2.data = {GraphConst.JSON_INDEX_KEY: 0.6} + apis.subnodes = [api1, api2] + sub_nodes = [BaseNode(NodeOp.module, 'Module.a.0'), apis, BaseNode(NodeOp.module, 'Module.a.1')] + comparator.graph_n.root.subnodes = sub_nodes + comparator._handle_api_collection_index() + self.assertEqual(comparator.graph_n.root.subnodes[1].data.get(GraphConst.JSON_INDEX_KEY), 0.6) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_node') + @patch('msprobe.pytorch.visualization.graph.graph.Graph.match') + @patch('msprobe.pytorch.visualization.graph.graph.Graph.mapping_match') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.get_compare_mode') + @patch('msprobe.pytorch.visualization.compare.graph_comparator.load_json_file') + 
@patch('msprobe.pytorch.visualization.compare.graph_comparator.load_data_json_file') + def test__compare_nodes(self, mock_load_data_json_file, mock_load_json_file, mock_get_compare_mode, + mock_mapping_match, mock_match, mock_compare_node): + node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') + node_b = BaseNode(NodeOp.function_api, 'Tensor.b.0') + mock_load_data_json_file.return_value = {} + mock_load_json_file.return_value = {} + mock_get_compare_mode.return_value = GraphConst.SUMMARY_COMPARE + mock_mapping_match.return_value = (node_b, [], []) + mock_compare_node.return_value = ['result'] + comparator = GraphComparator(self.graphs, self.data_paths, self.stack_path, self.output_path) + comparator.mapping_config = True + comparator._compare_nodes(node_n) + self.assertEqual(node_n.matched_node_link, ['Tensor.b.0']) + self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + comparator.mapping_config = False + node_n = BaseNode(NodeOp.function_api, 'Tensor.a.0') + node_b = BaseNode(NodeOp.function_api, 'Tensor.a.0') + mock_match.return_value = (node_b, []) + comparator._compare_nodes(node_n) + self.assertEqual(node_n.matched_node_link, ['Tensor.a.0']) + self.assertEqual(node_b.matched_node_link, ['Tensor.a.0']) + diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..da76d8e0d57dc700d45a16f4ea79df1ff7ff1707 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/compare/test_mode_adapter.py @@ -0,0 +1,99 @@ +import json +import unittest +from unittest.mock import patch, MagicMock +from msprobe.pytorch.visualization.compare.mode_adapter import ModeAdapter +from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.pytorch.visualization.utils import GraphConst, ToolTip +from msprobe.core.common.const import 
CompareConst + + +class TestModeAdapter(unittest.TestCase): + + def setUp(self): + self.node_op = NodeOp.module + self.node_id = "node_1" + self.node = BaseNode(self.node_op, self.node_id) + self.compare_mode = GraphConst.REAL_DATA_COMPARE + self.adapter = ModeAdapter(self.compare_mode) + self.compare_data_dict = [{}, {}] + + def test_add_md5_compare_data(self): + node_data = {'md5_key': 'some_md5_value'} + compare_data_dict = {'md5_key': 'expected_md5_value'} + precision_index = ModeAdapter._add_md5_compare_data(node_data, compare_data_dict) + self.assertEqual(precision_index, 0) + + @patch('msprobe.pytorch.visualization.compare.mode_adapter.ModeAdapter') + def test_parse_result(self, mock_mode_adapter): + mock_mode_adapter._add_summary_compare_data.return_value = 0.5 + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 0.5) + self.assertEqual(other_dict, {}) + + mock_mode_adapter._add_md5_compare_data.return_value = 1 + self.adapter.compare_mode = GraphConst.MD5_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 1) + self.assertEqual(other_dict, {'Result': 'pass'}) + + mock_mode_adapter._add_real_compare_data.return_value = 0.6 + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + precision_index, other_dict = self.adapter.parse_result(self.node, self.compare_data_dict) + self.assertEqual(precision_index, 0.0) + self.assertEqual(other_dict, {}) + + def test_prepare_real_data(self): + self.adapter.is_real_data_compare = MagicMock(return_value=True) + result = self.adapter.prepare_real_data(self.node) + self.assertTrue(result) + + self.adapter.is_real_data_compare = MagicMock(return_value=False) + result = self.adapter.prepare_real_data(self.node) + self.assertFalse(result) + + def test_compare_mode_methods(self): + self.adapter.compare_mode = 
GraphConst.SUMMARY_COMPARE + self.assertTrue(self.adapter.is_summary_compare()) + self.assertFalse(self.adapter.is_md5_compare()) + self.assertFalse(self.adapter.is_real_data_compare()) + + def test_add_csv_data(self): + compare_result_list = ['result1', 'result2'] + self.adapter.add_csv_data(compare_result_list) + self.assertEqual(self.adapter.csv_data, compare_result_list) + + def test_add_error_key(self): + node_data = {'key': {}} + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + self.adapter.add_error_key(node_data) + self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], + [CompareConst.ONE_THOUSANDTH_ERR_RATIO, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO]) + node_data = {'key': {}} + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + self.adapter.add_error_key(node_data) + self.assertEqual(node_data['key'][GraphConst.ERROR_KEY], + [CompareConst.MAX_RELATIVE_ERR, CompareConst.MIN_RELATIVE_ERR, + CompareConst.MEAN_RELATIVE_ERR, CompareConst.NORM_RELATIVE_ERR]) + + def test_get_tool_tip(self): + self.adapter.compare_mode = GraphConst.MD5_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({'md5': ToolTip.MD5})) + + self.adapter.compare_mode = GraphConst.SUMMARY_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({ + CompareConst.MAX_DIFF: ToolTip.MAX_DIFF, + CompareConst.MIN_DIFF: ToolTip.MIN_DIFF, + CompareConst.MEAN_DIFF: ToolTip.MEAN_DIFF, + CompareConst.NORM_DIFF: ToolTip.NORM_DIFF})) + + self.adapter.compare_mode = GraphConst.REAL_DATA_COMPARE + tips = self.adapter.get_tool_tip() + self.assertEqual(tips, json.dumps({ + CompareConst.ONE_THOUSANDTH_ERR_RATIO: ToolTip.ONE_THOUSANDTH_ERR_RATIO, + CompareConst.FIVE_THOUSANDTHS_ERR_RATIO: ToolTip.FIVE_THOUSANDTHS_ERR_RATIO, + CompareConst.COSINE: ToolTip.COSINE, + CompareConst.MAX_ABS_ERR: ToolTip.MAX_ABS_ERR, + CompareConst.MAX_RELATIVE_ERR: ToolTip.MAX_RELATIVE_ERR})) diff --git 
a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py new file mode 100644 index 0000000000000000000000000000000000000000..3f0f12582a36b21000416f2603143b87d0032c65 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_base_node.py @@ -0,0 +1,76 @@ +import unittest +from unittest.mock import patch +from msprobe.pytorch.visualization.graph.base_node import BaseNode, NodeOp +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestBaseNode(unittest.TestCase): + + def setUp(self): + self.node_op = NodeOp.module + self.node_id = "node_1" + self.up_node = BaseNode(self.node_op, "up_node_1") + self.node = BaseNode(self.node_op, self.node_id, self.up_node) + + def test_init_and_str(self): + self.assertEqual(self.node.op, self.node_op) + self.assertEqual(self.node.id, self.node_id) + self.assertEqual(str(self.node), 'id:\tnode_1') + + def test_eq(self): + other_node = BaseNode(self.node_op, self.node_id, self.up_node) + self.assertEqual(self.node, other_node) + + def test_get_suggestions(self): + self.node.get_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, self.node.suggestions) + + node = BaseNode(NodeOp.function_api, "up_node_1") + node.get_suggestions() + self.assertIn(GraphConst.SUGGEST_KEY, node.suggestions) + + def test_set_input_output(self): + input_data = {'input1': 'value1'} + output_data = {'output1': 'value2'} + self.node.set_input_output(input_data, output_data) + self.assertEqual(self.node.input_data, input_data) + self.assertEqual(self.node.output_data, output_data) + + def test_add_upnode(self): + self.node = BaseNode(self.node_op, self.node_id) + new_up_node = BaseNode(self.node_op, "new_up_node_1") + self.node.add_upnode(new_up_node) + self.assertEqual(self.node.upnode, new_up_node) + self.assertIn(self.node, new_up_node.subnodes) + + def test_add_link(self): + other_node = 
BaseNode(self.node_op, "other_node_1") + ancestors = ['a1', 'a2'] + self.node.add_link(other_node, ancestors) + self.assertEqual(self.node.matched_node_link, ancestors) + self.assertEqual(other_node.matched_node_link, ancestors) + + def test_to_dict(self): + expected_result = { + 'id': self.node_id, + 'node_type': self.node_op.value, + 'data': {}, + 'output_data': {}, + 'input_data': {}, + 'upnode': self.up_node.id, + 'subnodes': [], + 'matched_node_link': [], + 'suggestions': {}, + 'stack_info': [] + } + self.assertEqual(self.node.to_dict(), expected_result) + + def test_get_ancestors(self): + expected_ancestors = ['up_node_1'] + self.assertEqual(self.node.get_ancestors(), expected_ancestors) + + @patch('msprobe.pytorch.visualization.builder.msprobe_adapter.compare_mapping_data') + def test_compare_mapping_node(self, mock_compare_mapping_data): + mock_compare_mapping_data.return_value = True + result = self.node.compare_mapping_node(BaseNode(NodeOp.function_api, "up_node_1")) + self.assertTrue(result) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py new file mode 100644 index 0000000000000000000000000000000000000000..e7a55bb36e2291caf14a5e27052c5600f4492d0c --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_graph.py @@ -0,0 +1,102 @@ +import unittest +from unittest.mock import MagicMock +from msprobe.pytorch.visualization.graph.graph import Graph, NodeOp +from msprobe.pytorch.visualization.graph.base_node import BaseNode +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestGraph(unittest.TestCase): + + def setUp(self): + self.graph = Graph("model_name") + self.node_id = "node_id" + self.node_op = NodeOp.module + + def test_add_node_and_get_node(self): + self.graph.add_node(self.node_op, self.node_id) + node = self.graph.get_node(self.node_id) + self.assertIsNotNone(node) + 
self.assertIn(self.node_id, self.graph.node_map) + + node_id = "api" + graph = Graph("model_name") + for i in range(0, 9): + graph.add_node(NodeOp.function_api, node_id, id_accumulation=True) + self.assertEqual(len(graph.node_map), 10) + self.assertIn("api.0", graph.node_map) + self.assertIn("api.8", graph.node_map) + self.assertNotIn("api", graph.node_map) + + def test_to_dict(self): + self.graph.add_node(self.node_op, self.node_id) + result = self.graph.to_dict() + self.assertEqual(result[GraphConst.JSON_ROOT_KEY], "model_name") + self.assertIn(self.node_id, result[GraphConst.JSON_NODE_KEY]) + + def test_str(self): + self.graph.add_node(self.node_op, self.node_id) + expected_str = f'{self.node_id}' + self.assertIn(expected_str, str(self.graph)) + + def test_match(self): + graph_a = Graph("model_name_a") + graph_b = Graph("model_name_b") + node_a = BaseNode(self.node_op, self.node_id) + graph_a.add_node(NodeOp.module, "node_id_a") + graph_b.add_node(NodeOp.module, "node_id_b") + matched_node, ancestors = Graph.match(graph_a, node_a, graph_b) + self.assertIsNone(matched_node) + self.assertEqual(ancestors, []) + + graph_b.add_node(NodeOp.module, "node_id_a") + graph_a.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a")) + graph_b.add_node(NodeOp.module, "node_id_a_1", graph_a.get_node("node_id_a")) + matched_node, ancestors = Graph.match(graph_a, graph_a.get_node("node_id_a_1"), graph_b) + self.assertIsNotNone(matched_node) + self.assertEqual(ancestors, ['node_id_a']) + + def test_dfs(self): + graph = Graph("model_name") + graph.add_node(NodeOp.module, "node_a") + graph.add_node(NodeOp.module, "node_b") + node_a = BaseNode(self.node_op, self.node_id) + result = {} + graph.dfs(node_a, result) + self.assertEqual(result, {'node_id': {'id': 'node_id', 'node_type': 0, 'data': {}, + 'output_data': {}, 'input_data': {}, 'upnode': 'None', 'subnodes': [], + 'matched_node_link': [], 'suggestions': {}}}) + + def test_split_nodes_by_micro_step(self): + nodes = 
[BaseNode(NodeOp.module, 'a.0'), BaseNode(NodeOp.module, 'b.0'), + BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.1'), + BaseNode(NodeOp.module, 'b.1'), BaseNode(NodeOp.api_collection, 'apis.1')] + result = Graph.split_nodes_by_micro_step(nodes) + self.assertEqual(len(result), 2) + self.assertEqual(len(result[0]), 3) + + def test_paging_by_micro_step(self): + nodes = [BaseNode(NodeOp.module, 'a.0'), BaseNode(NodeOp.module, 'b.0'), + BaseNode(NodeOp.api_collection, 'apis.0'), BaseNode(NodeOp.module, 'a.1'), + BaseNode(NodeOp.module, 'b.1'), BaseNode(NodeOp.api_collection, 'apis.1')] + + graph = Graph('Model1') + graph.root.subnodes = nodes + graph_other = Graph('Model2') + graph_other.root.subnodes = nodes + + result = graph.paging_by_micro_step(graph_other) + self.assertEqual(result, 2) + self.assertEqual(graph.root.subnodes[0].micro_step_id, 0) + self.assertEqual(graph_other.root.subnodes[0].micro_step_id, 0) + + def test_mapping_match(self): + mapping_config = MagicMock() + graph_a = Graph("model_name_a") + graph_b = Graph("model_name_b") + graph_a.add_node(NodeOp.module, "a1", BaseNode(NodeOp.module, "root")) + graph_b.add_node(NodeOp.module, "b1", BaseNode(NodeOp.module, "root")) + mapping_config.get_mapping_string.return_value = "b1" + node_b, ancestors_n, ancestors_b = Graph.mapping_match(graph_a.get_node("a1"), graph_b, mapping_config) + self.assertIsNotNone(node_b) + self.assertEqual(ancestors_n, ["root"]) + self.assertEqual(ancestors_b, ["root"]) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py new file mode 100644 index 0000000000000000000000000000000000000000..0be05586b150cfd39670535cc3015925d9e2f44e --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_colors.py @@ -0,0 +1,67 @@ +import unittest +from msprobe.pytorch.visualization.graph.node_colors import 
NodeColors, SUMMARY_DESCRIPTION, REAL_DATA_DESCRIPTION, \ + NOT_MATCHED +from msprobe.pytorch.visualization.utils import GraphConst + + +class TestNodeColors(unittest.TestCase): + + def test_get_info_by_mode(self): + node_yellow = NodeColors.YELLOW_1 + summary_info = node_yellow.get_info_by_mode(GraphConst.SUMMARY_COMPARE) + self.assertEqual(summary_info[GraphConst.VALUE], [0, 0.2]) + self.assertEqual(summary_info[GraphConst.DESCRIPTION], SUMMARY_DESCRIPTION) + node_grey = NodeColors.GREY + md5_info = node_grey.get_info_by_mode(GraphConst.MD5_COMPARE) + self.assertEqual(md5_info[GraphConst.VALUE], []) + self.assertEqual(md5_info[GraphConst.DESCRIPTION], NOT_MATCHED) + node_red = NodeColors.RED + real_info = node_red.get_info_by_mode(GraphConst.REAL_DATA_COMPARE) + self.assertEqual(real_info[GraphConst.VALUE], [0.2, 1]) + self.assertEqual(real_info[GraphConst.DESCRIPTION], REAL_DATA_DESCRIPTION) + none_info = node_yellow.get_info_by_mode("non_existent_mode") + self.assertEqual(none_info, {}) + + def test_get_node_colors(self): + # 测试获取所有颜色信息的函数 + mode = GraphConst.SUMMARY_COMPARE + colors_info = NodeColors.get_node_colors(mode) + self.assertIn("#FFFCF3", colors_info) + self.assertIn("#FFEDBE", colors_info) + self.assertIn("#FFDC7F", colors_info) + self.assertIn("#FFC62E", colors_info) + self.assertIn("#E32020", colors_info) + self.assertIn("#C7C7C7", colors_info) + + # 确保返回的字典具有正确的描述和值范围 + expected_value_range = [0, 0.2] + expected_description = "此节点所有输入输出的统计量相对误差, 值越大代表测量值与标杆值的偏差越大, 相对误差计算方式:|(测量值-标杆值)/标杆值|" + self.assertEqual(colors_info["#FFFCF3"][GraphConst.VALUE], expected_value_range) + self.assertEqual(colors_info["#FFFCF3"][GraphConst.DESCRIPTION], expected_description) + + mode = GraphConst.MD5_COMPARE + colors_info = NodeColors.get_node_colors(mode) + self.assertIn("#FFFCF3", colors_info) + self.assertIn("#C7C7C7", colors_info) + self.assertNotIn("#FFDC7F", colors_info) + + expected_value_range = [1, 1] + expected_description = "与标杆相比, 此节点所有输入输出的md5值相同" + 
self.assertEqual(colors_info["#FFFCF3"][GraphConst.VALUE], expected_value_range) + self.assertEqual(colors_info["#FFFCF3"][GraphConst.DESCRIPTION], expected_description) + + def test_get_node_error_status(self): + # 测试错误状态判断功能 + mode = GraphConst.SUMMARY_COMPARE + value0 = 0 + value1 = 0.25 + value2 = 0.55 + value3 = 111 + self.assertFalse(NodeColors.get_node_error_status(mode, value0)) + self.assertFalse(NodeColors.get_node_error_status(mode, value1)) + self.assertTrue(NodeColors.get_node_error_status(mode, value2)) + self.assertTrue(NodeColors.get_node_error_status(mode, value3)) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py new file mode 100644 index 0000000000000000000000000000000000000000..1a340ac8b3c7144a9e07485c93e289a950eee8c7 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/graph/test_node_op.py @@ -0,0 +1,28 @@ +import unittest +from msprobe.pytorch.visualization.graph.node_op import NodeOp + + +class TestNodeOp(unittest.TestCase): + + def test_get_node_op_valid(self): + node_name = "ModuleTest" + self.assertEqual(NodeOp.get_node_op(node_name), NodeOp.module) + + def test_get_node_op_invalid(self): + node_name = "InvalidNodeName" + with self.assertRaises(Exception): + NodeOp.get_node_op(node_name) + + def test_get_node_op_all(self): + test_cases = [ + ("ModuleTest", NodeOp.module), + ("TensorTest", NodeOp.function_api), + ("TorchTest", NodeOp.function_api), + ("FunctionalTest", NodeOp.function_api), + ("NPUTest", NodeOp.function_api), + ("VFTest", NodeOp.function_api), + ("DistributedTest", NodeOp.function_api), + ("AtenTest", NodeOp.function_api) + ] + for node_name, expected_op in test_cases: + self.assertEqual(NodeOp.get_node_op(node_name), expected_op) diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml 
b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml new file mode 100644 index 0000000000000000000000000000000000000000..8b2f85ebf872aae4b3377842ac899824da5877f9 --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/mapping.yaml @@ -0,0 +1,2 @@ +- vision_model: "language_model.vision_encoder" +- vision_projection: "language_model.projection" \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py new file mode 100644 index 0000000000000000000000000000000000000000..010a4f686198ef1127299fcad1b9a5abf6505d6b --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_mapping_config.py @@ -0,0 +1,52 @@ +import os +import unittest +from msprobe.pytorch.visualization.mapping_config import MappingConfig + + +class TestMappingConfig(unittest.TestCase): + + def setUp(self): + self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") + + def test_validate(self): + with self.assertRaises(ValueError): + MappingConfig.validate(123, "some value") + with self.assertRaises(ValueError): + MappingConfig.validate("some key", 456) + self.assertEqual(MappingConfig.validate("key", "value"), "value") + + def test_convert_to_regex(self): + regex = MappingConfig.convert_to_regex("hello{world}") + self.assertEqual(regex, ".*hello\\{world\\}.*") + + def test_replace_parts(self): + result = MappingConfig._replace_parts('hello world', 'world', 'everyone') + self.assertEqual(result, 'hello everyone') + result = MappingConfig._replace_parts('radio_model.layers.0.input_norm', 'radio_model.layers.{}.input_norm', + 'radio_model.transformer.layers.{}.input_layernorm') + self.assertEqual(result, 'radio_model.transformer.layers.0.input_layernorm') + + def test_get_mapping_string(self): + mc = MappingConfig(self.yaml_path) + mc.classify_config = { + 
'category1': [('category1.key1', 'replacement1')], + 'category2': [('category2.key1', 'replacement2')] + } + result = mc.get_mapping_string("some category1.key1 text") + self.assertEqual(result, "some replacement1 text") + + def test_long_string(self): + long_string = "x" * (MappingConfig.MAX_STRING_LEN + 1) + mc = MappingConfig(self.yaml_path) + result = mc.get_mapping_string(long_string) + self.assertEqual(result, long_string) + + def test__classify_and_sort_keys(self): + mc = MappingConfig(self.yaml_path) + result = mc._classify_and_sort_keys() + self.assertEqual(result, {'vision_model': [('vision_model', 'language_model.vision_encoder')], + 'vision_projection': [('vision_projection', 'language_model.projection')]}) + + +if __name__ == '__main__': + unittest.main() diff --git a/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..56493235cd63d97b8a2beca0358c59ed16a154cf --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/pytorch_ut/visualization/test_utils.py @@ -0,0 +1,27 @@ +import os +import unittest +from msprobe.pytorch.visualization.utils import (load_json_file, load_data_json_file, str2float) + + +class TestMappingConfig(unittest.TestCase): + + def setUp(self): + self.yaml_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "mapping.yaml") + + def test_load_json_file(self): + result = load_json_file(self.yaml_path) + self.assertEqual(result, {}) + + def test_load_data_json_file(self): + result = load_data_json_file(self.yaml_path) + self.assertEqual(result, {}) + + def test_str2float(self): + result = str2float('23.4%') + self.assertAlmostEqual(result, 0.234) + result = str2float('2.3.4%') + self.assertAlmostEqual(result, 0) + + +if __name__ == '__main__': + unittest.main()