From b053cd234be9dbd432f0c6145f0598ecf8c668d8 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Thu, 10 Oct 2024 16:14:23 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=80=9A=E8=BF=87?= =?UTF-8?q?=E6=98=A0=E5=B0=84=E6=96=87=E4=BB=B6=E8=BF=9B=E8=A1=8Cpytorch?= =?UTF-8?q?=E5=92=8Cmindspore=E6=AF=94=E5=AF=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../visualization/builder/msprobe_adapter.py | 4 +- .../visualization/compare/graph_comparator.py | 7 +++ .../pytorch/visualization/graph/graph.py | 52 +++++++++++++++++++ .../pytorch/visualization/mapping_config.py | 47 ++++++++++++++--- 4 files changed, 100 insertions(+), 10 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py index 92af6d6732..3a0c566cf6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py @@ -21,8 +21,8 @@ from msprobe.pytorch.visualization.utils import GraphConst # 用于将节点名字解析成对应的NodeOp的规则 op_patterns = [ - r'^(Module)', #NodeOp.module - r'^(Tensor|Torch|Functional|NPU|VF|Distributed|Aten)' #NodeOp.function_api + r'^(Module.|Cell.)', #NodeOp.module + r'^(Tensor.|Torch.|Functional.|NPU.|VF.|Distributed.|Aten.|Mint.|Primitive.|Jit.|MintFunctional.)' #NodeOp.function_api ] diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py index a5ec6c8db4..e7826b5c7d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/compare/graph_comparator.py @@ -118,6 +118,13 @@ class GraphComparator: if node_b: ancestors.append(node_b.id) node_n.add_link(node_b, ancestors) + if not node_b and node_n.op == NodeOp.function_api: + node_b, ancestors_n, ancestors_b = Graph.fuzzy_match_api(node_n, self.graph_b) + if node_b: + ancestors_n.append(node_n.id) + ancestors_b.append(node_b.id) + node_n.matched_node_link = ancestors_b + node_b.matched_node_link = ancestors_n if node_b: # 真实数据比对只会得到基本信息,并没有精度指标,需要调用多进程对比接口 compare_result_list = compare_node([node_n.id, node_b.id], diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py index 2ca5c6811b..15b3c7b970 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/graph/graph.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import re from msprobe.pytorch.visualization.graph.base_node import BaseNode from msprobe.pytorch.visualization.graph.node_op import NodeOp from msprobe.pytorch.visualization.utils import GraphConst @@ -21,6 +22,8 @@ from msprobe.core.common.const import Const class Graph: + fuzzy_mapping = {} + def __init__(self, model_name): self.node_map = {} self.node_id_map = {} @@ -58,10 +61,36 @@ class Graph: node_b = graph_b.node_map.get(mapping_config.get_mapping_string(node_n.id)) if not node_b or not node_n.compare_mapping_node(node_b): return None, [], [] + if node_n.op == NodeOp.function_api and node_b.upnode and node_n.upnode \ + and node_b.upnode.id != mapping_config.get_mapping_string(node_n.upnode.id): + return None, [], [] ancestors_n = node_n.get_ancestors() ancestors_b = node_b.get_ancestors() return node_b, ancestors_n, ancestors_b + @staticmethod + def fuzzy_match_api(node_n, graph_b): + """ + 对api节点进行模糊匹配,具体方式为同一module里的api,忽略工具采集赋予的序号,按照位置先后顺序进行匹配 + """ + up_node_n = node_n.upnode + if not up_node_n.matched_node_link: + return None, [], [] + up_node_b = graph_b.node_map.get(up_node_n.matched_node_link[-1]) + if not up_node_b: + return None, [], [] + + api_fuzzy_mapping = Graph.fuzzy_mapping.get(up_node_n.id) + if api_fuzzy_mapping is None: + api_fuzzy_mapping = Graph._make_api_fuzzy_mapping(up_node_n, up_node_b) + Graph.fuzzy_mapping[up_node_n.id] = api_fuzzy_mapping + + if node_n.id in api_fuzzy_mapping: + node_b = graph_b.node_map.get(api_fuzzy_mapping[node_n.id]) + return node_b, node_n.get_ancestors(), node_b.get_ancestors() + else: + return None, [], [] + @staticmethod def dfs(node, result): info = node.to_dict() @@ -95,6 +124,29 @@ class Graph: result[default_id].append(node) return result + @staticmethod + def _make_api_fuzzy_mapping(up_node_n, up_node_b): + api_fuzzy_mapping = {} + count_n = {} + origin_names_n = [item.id for item in up_node_n.subnodes] + origin_names_b = [item.id for item in up_node_b.subnodes] + names_n = [re.sub(r'\.\d+', '', item.id) for item in up_node_n.subnodes] + names_b = [re.sub(r'\.\d+', '', item.id) for item in up_node_b.subnodes] + + # 记录module_n中每个api名称出现的次数和位置 + for index, name in enumerate(names_n): + if name in count_n: + count_n[name].append(index) + else: + count_n[name] = [index] + + # 遍历module_b,为每个名称找到对应的module_n中的字符串 + for index, name in enumerate(names_b): + if name in count_n and count_n[name]: + n_index = count_n[name].pop(0) + api_fuzzy_mapping[origin_names_n[n_index]] = origin_names_b[index] + return api_fuzzy_mapping + def add_node(self, node_op, node_id, up_node=None, id_accumulation=False): """ 在graph中进行节点的添加 diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py index d986493513..fafba4d897 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/mapping_config.py @@ -15,7 +15,9 @@ class MappingConfig: self.config = {key: self.validate(key, value) for data in config for key, value in data.items()} except Exception as e: raise RuntimeError("Line of yaml contains content that is not '- key: value'.") from e - self.classify_config = self._classify_and_sort_keys() + self.is_exact_match = False + self._preprocess_mapping_data() + self.classify_config = self._classify_and_sort_keys() if self.is_exact_match else {} @staticmethod def validate(key, value): @@ -48,13 +50,16 @@ class MappingConfig: return origin_string.replace(mapping_key, mapping_value) def get_mapping_string(self, origin_string: str): - if len(origin_string) > MappingConfig.MAX_STRING_LEN: - return origin_string - for category, items in self.classify_config.items(): - if category in origin_string: - for key, value in items: - if re.match(MappingConfig.convert_to_regex(key), origin_string): - return MappingConfig._replace_parts(origin_string, key, value) + if self.is_exact_match and origin_string in self.config: + return self.config.get(origin_string) + else: + if len(origin_string) > MappingConfig.MAX_STRING_LEN: + return origin_string + for category, items in self.classify_config.items(): + if category in origin_string: + for key, value in items: + if re.match(MappingConfig.convert_to_regex(key), origin_string): + return MappingConfig._replace_parts(origin_string, key, value) return origin_string def _classify_and_sort_keys(self): @@ -75,3 +80,29 @@ class MappingConfig: categorized_dict[category].sort(key=lambda x: -x[0].count(Const.SEP)) return categorized_dict + + def _preprocess_mapping_data(self): + """ + 预处理以.数字为后缀的映射数据字符串,添加forward和backward,并对此格式的数据进行精确匹配 + Cell.model.GPTModel.0 ---> Cell.model.GPTModel.forward.0, Cell.model.GPTModel.backward.0 + Mint.cat.0 ---> Mint.cat.0.forward, Mint.cat.0.backward + """ + + def handle_mapping_string(s): + if re.match(r'^(Module|Cell)\.', s) and not re.search(r'\.(forward|backward)\.\d+$', s): + last_dot_index = s.rfind('.') + if last_dot_index != -1: + head = s[:last_dot_index] + tail = s[last_dot_index:] + return head + Const.SEP + Const.FORWARD + tail, head + Const.SEP + Const.BACKWARD + tail + elif not re.search(r'\.\d+\.(forward|backward)$', s): + return s + Const.SEP + Const.FORWARD, s + Const.SEP + Const.BACKWARD + return s, s + + if re.search(r'\.\d+$', next(iter(self.config))): + self.is_exact_match = True + config = {} + for key, value in self.config.items(): + config[handle_mapping_string(key)[0]] = handle_mapping_string(value)[0] + config[handle_mapping_string(key)[1]] = handle_mapping_string(value)[1] + self.config = config -- Gitee From 3cb99999afb1f357724a2f98c57ad6301bb31883 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Thu, 10 Oct 2024 16:19:39 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E6=94=AF=E6=8C=81=E9=80=9A=E8=BF=87?= =?UTF-8?q?=E6=98=A0=E5=B0=84=E6=96=87=E4=BB=B6=E8=BF=9B=E8=A1=8Cpytorch?= =?UTF-8?q?=E5=92=8Cmindspore=E6=AF=94=E5=AF=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../msprobe/pytorch/visualization/builder/msprobe_adapter.py | 2 +- debug/accuracy_tools/msprobe/pytorch/visualization/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py index 3a0c566cf6..b7a643daca 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/builder/msprobe_adapter.py @@ -85,7 +85,7 @@ def compare_data(data_dict_list1, data_dict_list2): if len(data_dict_list1) != len(data_dict_list2): return False # 用于比较两个节点是否相等的关键字段 - tag_keys = ['type', 'dtype', 'shape'] + tag_keys = ['type', 'shape'] for key1, key2 in zip(data_dict_list1, data_dict_list2): dict1 = data_dict_list1[key1] dict2 = data_dict_list2[key2] diff --git a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py index 8a81f4f209..6cd995093f 100644 --- a/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py +++ b/debug/accuracy_tools/msprobe/pytorch/visualization/utils.py @@ -65,7 +65,7 @@ def str2float(percentage_str): try: percentage_str = percentage_str.strip('%') return float(percentage_str) / 100 - except ValueError: + except (ValueError, AttributeError): return 0 -- Gitee