From c2103d2d7946c3d3e61aa451b03619e461a864c4 Mon Sep 17 00:00:00 2001 From: 18868120292 Date: Thu, 23 May 2024 00:36:32 +0800 Subject: [PATCH] common_part2 --- profiler/advisor/common/__init__.py | 0 profiler/advisor/common/analyzer_scopes.py | 9 + profiler/advisor/common/graph/__init__.py | 0 profiler/advisor/common/graph/graph.py | 135 ++++++ profiler/advisor/common/graph/graph_match.py | 355 +++++++++++++++ profiler/advisor/common/graph/graph_parser.py | 413 ++++++++++++++++++ profiler/advisor/common/version_control.py | 26 ++ 7 files changed, 938 insertions(+) create mode 100644 profiler/advisor/common/__init__.py create mode 100644 profiler/advisor/common/analyzer_scopes.py create mode 100644 profiler/advisor/common/graph/__init__.py create mode 100644 profiler/advisor/common/graph/graph.py create mode 100644 profiler/advisor/common/graph/graph_match.py create mode 100644 profiler/advisor/common/graph/graph_parser.py create mode 100644 profiler/advisor/common/version_control.py diff --git a/profiler/advisor/common/__init__.py b/profiler/advisor/common/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/profiler/advisor/common/analyzer_scopes.py b/profiler/advisor/common/analyzer_scopes.py new file mode 100644 index 000000000..0c6a2ac26 --- /dev/null +++ b/profiler/advisor/common/analyzer_scopes.py @@ -0,0 +1,9 @@ +class SupportedScopes: + + # used for specify fourth-level commands and define the key of the result dict + # the key defined bellow must be the same as value + TIMELINE_FUSION_OPS = "timeline_fusion_ops" + GRAPH = "graph" + SLOW_RANK = "slow_rank" + SLOW_LINK = "slow_link" + PORFILING_OPERATOR_ANALYSIS = "profiling_operator_analysis" diff --git a/profiler/advisor/common/graph/__init__.py b/profiler/advisor/common/graph/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/profiler/advisor/common/graph/graph.py b/profiler/advisor/common/graph/graph.py new file mode 100644 index 000000000..6bab2042d --- /dev/null +++ b/profiler/advisor/common/graph/graph.py @@ -0,0 +1,135 @@ +import logging +from typing import Dict, List, Tuple, Callable, Any, Optional, Union + +import networkx as nx + +from profiler.advisor.common.graph.graph_parser import HostGraphNode, QueryGraphNode + +logger = logging.getLogger() + + +class Graph: + """ + Graph Struct + """ + + # pylint: disable=too-many-instance-attributes + def __init__(self, + nodes: Dict[str, Optional[Union[HostGraphNode, QueryGraphNode]]] = None, + edges: List[Tuple[Optional[Union[HostGraphNode, QueryGraphNode]], + Optional[Union[HostGraphNode, QueryGraphNode]]]] = None, + name: str = None): + self.name = name + self.graph = nx.DiGraph(name=name) + self.nodes = nodes if nodes is not None else {} + self.edges = edges if edges is not None else list() + + def build(self): + for op_name, node in self.nodes.items(): + # add node and mark op_name as tag + self.add_node(node, + op_type=node.op_type + ) + for edge in self.edges: + self.add_edge(*edge) + return self.graph + + def get_size(self) -> Dict[str, int]: + if not hasattr(self.graph, "nodes"): + return {"edges": 0, "nodes": 0} + + return {"edges": len(self.graph.edges), + "nodes": len(self.graph.nodes)} + + def add_node(self, node: HostGraphNode, **kwargs): + if node is None: + return + self.graph.add_node(node, **kwargs) + + def add_edge(self, pre_node: HostGraphNode, next_node: HostGraphNode): + if pre_node is None or next_node is None: + return + + if pre_node not in self.graph or \ + next_node not in self.graph: + logging.error("Nodes between edge should be both exists.") + return + + self.graph.add_edge(pre_node, next_node) + + def add_node_with_edge(self, node, adj_nodes: List[HostGraphNode]): + self.add_node(node) + for adj in adj_nodes: + self.add_edge(node, adj) + + def remove_node(self, node: HostGraphNode = None) -> None: + if node is None: + return + + self.graph.remove_node(node) + + def remove_edge(self, pre_node: HostGraphNode = None, next_node: HostGraphNode = None) -> None: + if pre_node is None or next_node is None: + raise ValueError(f"Invalid edge from {pre_node} to {pre_node}.") + + self.remove_edge(pre_node, next_node) + + def get_subgraph(self, nodes: List[HostGraphNode]) -> nx.DiGraph: + nodes = list(set(nodes)) + for node in nodes: + if not self.is_node_exists(node): + raise ValueError(f"Failed to subtract subgraph because {node.op_name} is not in the graph.") + + return self.graph.subgraph(nodes) + + def highlight_subgraph(self, subgraph: nx.DiGraph = None) -> None: + pass + + def get_node(self, node: HostGraphNode): + if node not in self.graph: + return + + return self.graph[node] + + def get_node_by_name(self, node_name: str): + return self.nodes.get(node_name, None) + + def is_node_exists(self, node: HostGraphNode): + return node in self.graph + + def draw(self, + graph: nx.DiGraph = None, + with_labels: bool = False, + labels: Dict[HostGraphNode, Any] = None, + pos_func: Callable = None, + font_weight: str = "bold", + savefig: bool = False, + node_size: int = 50, + **kwargs + ): + try: + import matplotlib.pylab as plt + except ImportError: + logger.error('Please install matplotlib first by using `pip install matplotlib`.') + return + + if graph is None: + graph = self.graph + + pos = pos_func(graph) if pos_func is not None else None + + if with_labels: + if labels is None: + labels = {k: f"{k}\n({v['op_name']})" for k, v in graph.nodes.items()} + + nx.draw(graph, + with_labels=with_labels, + pos=pos, + node_size=node_size, + font_weight=font_weight, + labels=labels, + **kwargs + ) + if savefig: + plt.savefig(self.name + ".png") + plt.show() diff --git a/profiler/advisor/common/graph/graph_match.py b/profiler/advisor/common/graph/graph_match.py new file mode 100644 index 000000000..d0dfc1629 --- /dev/null +++ b/profiler/advisor/common/graph/graph_match.py @@ -0,0 +1,355 @@ +import itertools +import logging +from functools import lru_cache +from collections import deque +from typing import Dict, Generator, List, Callable, Hashable, Tuple + +import networkx as nx + + +@lru_cache() +def match_node_attr_fun(query_node: Hashable, + host_node: Hashable, + query_graph: nx.Graph, + host_graph: nx.Graph + ) -> bool: + """ + Check query node matches the attributes in host graph + + :param query_node: Query graph node + :param host_node: Host graph node + :param query_graph: Query Graph + :param host_graph: Host graph + :return: bool, match or not + """ + # get node attr + if query_node not in query_graph.nodes or host_node not in host_graph.nodes: + return False + + query_node = query_graph.nodes[query_node] + host_node = host_graph.nodes[host_node] + for attr, val in query_node.items(): + if attr not in host_node: + return False + if isinstance(host_node[attr], str) and isinstance(val, str): + if host_node[attr].lower() != val.lower(): + return False + else: + if host_node[attr] != val: + return False + return True + + +@lru_cache() +def match_node_struct_fun(query_node: Hashable, + host_node: Hashable, + query_graph: nx.Graph, + host_graph: nx.Graph + ) -> bool: + """ + Check query node matches the structure in host graph + + :param query_node: Query graph node + :param host_node: Host graph node + :param query_graph: Query Graph + :param host_graph: Host graph + :return: bool, match or not + """ + if query_node not in query_graph.nodes or host_node not in host_graph.nodes: + return False + + return host_graph.degree(host_node) >= query_graph.degree(query_node) + + +@lru_cache() +def match_edge_attr_fun(query_edge: Tuple[Hashable, Hashable], + host_edge: Tuple[Hashable, Hashable], + query_graph: nx.Graph, + host_graph: nx.Graph + ) -> bool: + """ + Check query edge matches the attr in host graph + + :param query_edge: Query graph edge + :param host_edge: Host graph edge + :param query_graph: Query Graph + :param host_graph: Host graph + :return: bool, match or not + """ + # get edge attr + if query_edge not in query_graph.edges or host_edge not in host_graph.edges: + return False + + query_edge = query_graph.edges[query_edge] + host_edge = host_graph.edges[host_edge] + for attr, val in query_edge.items(): + if attr not in host_edge: + return False + if isinstance(host_edge[attr], str) and isinstance(val, str): + if host_edge[attr].lower() != val.lower(): + return False + else: + if host_edge[attr] != val: + return False + return True + + +def find_isomorphisms(query_graph: nx.Graph, + host_graph: nx.Graph, + *args, + _node_attr_fun: Callable = match_node_attr_fun, + _node_struct_fun: Callable = match_node_struct_fun, + _edge_attr_fun: Callable = match_edge_attr_fun, + limit: int = None, + **kwargs) -> List[Dict[Hashable, Hashable]]: + """ + Find all the sub graphs that are isomorphic to query_graph in host_graph . + + :param query_graph: The graph object to query + :param host_graph: The graph object to be queried + :param args: Position args + :param _node_attr_fun: The function to match node attr + :param _node_struct_fun: The function to match node structural + :param _edge_attr_fun: The function to match edge attr + :param limit: The limitation for the number of returned mappings + :param kwargs: Keyword args + :return: Matched node mapping list + ``` + [{query_id: host_id, ...}, ...] + ``` + """ + candidates = [] + for query_result in find_isomorphisms_iter( + query_graph, + host_graph, + *args, + _node_attr_fun=_node_attr_fun, + _node_struct_fun=_node_struct_fun, + _edge_attr_fun=_edge_attr_fun, + **kwargs + ): + candidates.append(query_result) + if limit and len(candidates) >= limit: + return candidates + return candidates + + +def find_isomorphisms_iter(query_graph: nx.Graph, + host_graph: nx.Graph, + directed: bool = None, + _node_attr_fun: Callable = None, + _node_struct_fun: Callable = None, + _edge_attr_fun: Callable = None, + ) -> Generator[Dict[Hashable, Hashable], None, None]: + """ + A generation to find one isomorphic subgraph in host_graph for query_graph. + + :param query_graph: The graph object to query + :param host_graph: The graph object to be queried + :param directed: Whether direction should be considered during search + :param _node_attr_fun: The function to match node attr + :param _node_struct_fun: The function to match node structural + :param _edge_attr_fun: The function to match edge attr + :return: Yield mappings from query node IDs to host graph IDs: {query_id: host_id, ...} + + """ + if directed is None: + # query graph and host graph should consider directions. + if isinstance(query_graph, nx.DiGraph) and \ + isinstance(host_graph, nx.DiGraph): + directed = True + else: + directed = False + + # Initialize queue + dq = deque() + dq.appendleft({}) + + while len(dq) > 0: + backbone = dq.pop() + next_candidate_backbones = get_next_candidates(backbone=backbone, + query_graph=query_graph, + host_graph=host_graph, + directed=directed, + _node_attr_fun=_node_attr_fun, + _node_struct_fun=_node_struct_fun, + _edge_attr_fun=_edge_attr_fun, + ) + for candidate in next_candidate_backbones: + # find a legal isomorphism + if len(candidate) == len(query_graph): + yield candidate + else: + # continue to search + dq.appendleft(candidate) + + +def get_next_candidates( + backbone: Dict, + query_graph: nx.Graph, # noqa + host_graph: nx.Graph, # noqa + next_node: Hashable = None, + directed: bool = True, # noqa + _node_attr_fun: Callable = None, # noqa + _node_struct_fun: Callable = None, # noqa + _edge_attr_fun: Callable = None # noqa +) -> List[Dict[Hashable, Hashable]]: + """ + Get a list of candidate node assignments for the next "step" of this map. + + :param backbone: Mapping of query node IDs to one set of host graph IDs + :param next_node: Optional suggestion for the next node to assign + :return: List[Dict[Hashable, Hashable]]: A new list of node mappings with one additional element mapped + """ + node_priority = {n: 1 for n in query_graph.nodes} + candidate_nodes = [] + + if next_node is None and len(backbone) == 0: + # Start case + next_node = max(node_priority.keys(), + key=lambda x: node_priority.get(x, 0)) + + for node in host_graph.nodes: + if _node_attr_fun(next_node, node, query_graph, host_graph) and \ + _node_struct_fun(next_node, node, query_graph, host_graph): + candidate_nodes.append({next_node: node}) + return candidate_nodes + + nodes_with_maximum_backbone = [] + for query_node_id in query_graph.nodes: + if query_node_id in backbone: + continue + + backbone_neighbors = [] + if not directed: + backbone_neighbors = query_graph.adj[query_node_id] + else: + # nx.DiGraph.pred: A <- B: find previous node from B to A + # nx.DiGraph.adj: A -> B : find next node from A to B + backbone_neighbors = list(set(query_graph.adj[query_node_id]).union(set(query_graph.pred[query_node_id]))) + + query_backbone_node_count = sum([1 for _node in backbone_neighbors if _node in backbone]) + if query_backbone_node_count > 0: + # Find a longer backbone node + nodes_with_maximum_backbone.append(query_node_id) + + # next_node is connected to the current backbone. + next_node = max(nodes_with_maximum_backbone, key=lambda x: node_priority.get(x, 0)) + + # verify all edges between `next_node` and nodes in the backbone are exist in host graph + # Step1: find all edges between `next_node` and nodes in the backbone + next_edge_edges = [] + for _node in query_graph.adj[next_node]: + if _node in backbone: + # `next_node` -> `_node` + next_edge_edges.append((None, next_node, _node)) + + if directed: + for _node in query_graph.pred[next_node]: + if _node in backbone: + # `_node` -> `next_node` + next_edge_edges.append((_node, next_node, None)) + + if len(next_edge_edges) == 0: + logging.warning("Find node without any edge, which is invalid.") + return [] + # Step2: verify candidate nodes that have such edges in the host graph + candidate_nodes = [] + if len(next_edge_edges) == 1: + source, _, target = next_edge_edges[0] + if not directed: + candidate_nodes = list(host_graph.adj[backbone[target]]) + else: + if source is not None: + # means `source` is a `from` edge + candidate_nodes = list(host_graph.adj[backbone[source]]) + elif target is not None: + # means `target` is a `from` edge + candidate_nodes = list(host_graph.pred[backbone[target]]) + + elif len(next_edge_edges) > 1: + candidate_nodes_set = set() + for (source, _, target) in candidate_nodes: + if not directed: + candidate_nodes_from_this_edge = host_graph.adj[backbone[target]] + else: + if source is not None: + candidate_nodes_from_this_edge = host_graph.adj[backbone[source]] + else: # target is not None: + candidate_nodes_from_this_edge = host_graph.pred[backbone[target]] + + if len(candidate_nodes_set) > 0: + candidate_nodes_set = candidate_nodes_set.intersection(candidate_nodes_from_this_edge) + else: + # Initialize candidate_nodes_set + candidate_nodes_set.update(candidate_nodes_from_this_edge) + candidate_nodes = list(candidate_nodes_set) + + tentative_results = [] + for _node in candidate_nodes: + if all([_node not in backbone.values(), + _node_attr_fun(next_node, _node, query_graph, host_graph), + _node_struct_fun(next_node, _node, query_graph, host_graph)] + ): + tentative_results.append({**backbone, + next_node: _node}) + + final_candidates = check_edges_mapping(tentative_results, + query_graph=query_graph, + host_graph=host_graph, + _edge_attr_fun=_edge_attr_fun) + return final_candidates + + +def check_edges_mapping(candidates: List[Dict[Hashable, Hashable]], + query_graph: nx.Graph, + host_graph: nx.Graph, + _edge_attr_fun: Callable = None + ) -> List[Dict[Hashable, Hashable]]: + """ + Check that all edges between the assigned nodes exist in the host graph. + + :param candidates: mapping nodes candidates + :param query_graph: The graph object to query + :param host_graph: The graph object to be queried + :param _edge_attr_fun: The function to match edge attr + :return: + """ + monomorphism_candidates = [] + + for candidate in candidates: + if len(candidate) != len(query_graph): + monomorphism_candidates.append(candidate) + continue + + all_pass_flag = True + for edge_start, edge_end in query_graph.edges: + # check edge in host graph + if not host_graph.has_edge(candidate[edge_start], candidate[edge_end]): + all_pass_flag = False + break + + # check edge attr + if _edge_attr_fun is None or not _edge_attr_fun( + (edge_start, edge_end), + (candidate[edge_start], candidate[edge_end]), + query_graph, + host_graph + ): + all_pass_flag = False + break + + if all_pass_flag: + monomorphism_candidates.append(candidate) + + # Isomorphisms check + final_candidates = [] + for candidate in monomorphism_candidates: + all_product = itertools.product(candidate.keys(), candidate.keys()) + for edge_start, edge_end in all_product: + if not query_graph.has_edge(edge_start, edge_end) and \ + host_graph.has_edge(candidate[edge_start], candidate[edge_end]): + break + else: + final_candidates.append(candidate) + return final_candidates diff --git a/profiler/advisor/common/graph/graph_parser.py b/profiler/advisor/common/graph/graph_parser.py new file mode 100644 index 000000000..d4c67fc19 --- /dev/null +++ b/profiler/advisor/common/graph/graph_parser.py @@ -0,0 +1,413 @@ +import os +import logging +import yaml +import itertools +from collections import deque +from dataclasses import dataclass +from typing import List, Tuple, Dict + +logger = logging.getLogger() + + +@dataclass +class Tensor: + def __init__(self): + super().__init__() + self.shape = [] + self.origin_shape = [] + self.shape_range = [] + self.origin_shape_range = [] + self.dtype = "" + self.origin_data_type = "" + self.format = "" + self.origin_format = [] + + +@dataclass +class Attr: + + def __init__(self): + super().__init__() + self.key = str() + self.value = [] + + +class HostGraphNode: + def __init__(self): + super().__init__() + self.graph_name = str() + self.op_name = str() + self.op_type = str() + self.inputs = [] + self.input = [] + self.outputs = [] + self.output = [] + self.strides = [] + self.pads = [] + self.groups = "" + self.dilations = [] + self.kernelname = "" + self._attrs = [] + + def __repr__(self): + return f"" + + +@dataclass +class HostGraph: + def __init__(self): + super().__init__() + self.name = "" + self.nodes = {} + self.inputs = [] + self.edges = [] + self.model_name = None + self.file_path = None + + def build(self): + """build a graph""" + for name, node in self.nodes.items(): + for input_node in node.inputs: + if input_node not in self.nodes: + continue + self.nodes[input_node].outputs.append(name) + + +class HostGraphParser: + """ + Parse graph metadata from text file + """ + def __init__(self, file_path): + self.buffer = deque(maxlen=100) + self.line_no = 0 + self._file_path = file_path + self.edges: List[Tuple[HostGraphNode, HostGraphNode]] = [] + self.nodes: Dict[str, HostGraphNode] = {} + self.graphs = self._parse(self._file_path) + self._get_node_dict() + self._get_edges_list() + del self.graphs[0] + + @staticmethod + def _get_key_value( line): + res = line.split(':', 1) + return res[0].strip(), res[1].strip().strip('"') + + @staticmethod + def _parse_attr(key, value, obj): + if not isinstance(obj, list) and not obj: + return + if key == "dim" and hasattr(obj, "shape"): + obj.shape.append(value) + elif key == "name" and hasattr(obj, "op_name"): + obj.op_name = value + elif key == "name" and hasattr(obj, "name"): + obj.name = value + elif key == "dtype" and hasattr(obj, "dtype"): + obj.dtype = value + elif key == "layout" and hasattr(obj, "format"): + obj.format = value + elif key == "type" and hasattr(obj, "op_type"): + obj.op_type = value + elif key == "input" and hasattr(obj, "input"): + obj.inputs.append(value.strip('"').split(':')[0]) + elif key == "key" and hasattr(obj, "key"): + obj.key = value + elif hasattr(obj, key): + setattr(obj, key, value) + elif isinstance(obj, list) and key != "val_type": + obj.append(value) + + def _parse_struct(self, in_file, key, in_obj): + + def parse_shape(file, obj): + obj = self._parse_line(file, obj) + + def parse_input_desc(file, obj): + tensor = self._parse_line(file, Tensor()) + if obj and hasattr(obj, "input"): + obj.input.append(tensor) + + def parse_out_desc(file, obj): + tensor = self._parse_line(file, Tensor()) + if obj and hasattr(obj, "output"): + obj.output.append(tensor) + + def parse_op(file, obj: HostGraph): + node = self._parse_line(file, HostGraphNode()) + if hasattr(obj, "name"): + node.graph_name = obj.name + if obj and hasattr(obj, "nodes") and node.op_name: + obj.nodes[node.op_name] = node + + def parse_graph(file, obj): + graph = self._parse_line(file, HostGraph()) + obj.append(graph) + + def parse_attr(file, obj): + attr = self._parse_line(file, Attr()) + if hasattr(obj, attr.key): + if attr.key not in ['format']: + setattr(obj, attr.key, attr.value) + elif attr.key.endswith("_kernelname"): + setattr(obj, "kernelname", attr.value) + if obj and hasattr(obj, "get_attrs"): + obj.get_attrs().append(attr) + + def parse_list(file, obj): + value = [] + self._parse_line(file, value) + if isinstance(obj, list): + obj.append(value) + else: + obj = value + + def parse_value(file, obj): + if hasattr(obj, "value"): + obj.value = self._parse_line(file, obj.value) + + def parse_default(file, _obj=None): + """function with unused argument""" + self._parse_line(file, None) + + parse_methods = { + "shape": parse_shape, + "input_desc": parse_input_desc, + "output_desc": parse_out_desc, + "op": parse_op, + "graph": parse_graph, + "attr": parse_attr, + "list_list_int": parse_list, + "list_list_i": parse_list, + "list": parse_list, + "value": parse_value, + } + parse_methods.get(key, parse_default)(in_file, in_obj) + + def _read_line(self, file): + self.line_no += 1 + line = file.readline() + if line.strip().endswith('}'): + end_line = "" + while self.buffer and not end_line.strip().endswith("{"): + end_line = self.buffer.pop() + else: + self.buffer.append(line) + return line.strip() + + def _parse_line(self, file, obj=None): + line = self._read_line(file) + try: + while line and not line.endswith("}"): + if line.endswith('{'): + key = line.rstrip('{').strip() + self._parse_struct(file, key, obj) + else: + key, value = self._get_key_value(line) + self._parse_attr(key, value, obj) + line = self._read_line(file) + except Exception as exception: + if self.buffer: + logger.debug("***********************graph content**************************") + while self.buffer: + line = self.buffer.popleft() + logger.debug(line) + logger.debug("***********************graph content**************************") + raise exception + return obj + + def _parse(self, graph_file): + # pylint:disable=broad-except + graph_list = [] + with open(graph_file, "r", encoding="gbk") as file: + try: + graph_list = self._parse_line(file, graph_list) + except Exception: + logger.error( + "Parse line %s of file %s failed, make sure the format is correct.", self.line_no, graph_file + ) + graphs = [] + for graph in graph_list: + if isinstance(graph, HostGraph): + graphs.append(graph) + for graph in graphs: + graph.model_name = graphs[0].name + graph.file_path = self._file_path + graph.build() + return graphs + + def _get_edges_list(self) -> None: + if len(self.graphs) <= 0: + return + + def is_repeat_edge(edge, edge_collector): + for _edge in edge_collector: + if edge[0].op_name == _edge[0].op_name and edge[1].op_name == _edge[1].op_name: + return True + return False + + for node in self.nodes.values(): + for input_node_name in node.inputs: + if input_node_name not in self.nodes: + continue + input_node = self.nodes[input_node_name] + if not is_repeat_edge((input_node, node), self.edges): + self.edges.append((input_node, node)) + for output_node_name in node.outputs: + if output_node_name not in self.nodes: + continue + output_node = self.nodes[output_node_name] + if not is_repeat_edge((node, output_node), self.edges): + self.edges.append((node, output_node)) + + def _get_node_dict(self) -> None: + if not self.graphs: + self.nodes = {} + return + self.nodes = {node.op_name: node for graph in self.graphs for node in graph.nodes.values()} + + +class QueryGraphNode: + """ + Graph Node + """ + _ID = 0 + + def __init__(self, op_type: str, op_pass: str): + self._op_type = op_type + self._id = QueryGraphNode._ID + self._op_pass = op_pass + QueryGraphNode._ID += 1 + + def get_property(self, name): + """ + get property + """ + return getattr(self, name, lambda: None) + + @property + def op_type(self): + return self._op_type + + @property + def op_name(self): + return self._op_type + "_id_" + str(self._id) + + @property + def op_pass(self): + return self._op_pass + + @op_type.setter + def op_type(self, op_type): + self._op_type = op_type + + def __eq__(self, other): + return self._op_type == other._op_type and \ + self._id == other._id + + def __hash__(self): + return hash(self._op_type + str(self._id)) + + @staticmethod + def trim_string(string: str, length: int = -1): + """ + + Trim string to target length + :param string: Original string + :param length: Target length of string, -1 indicates original string. + :return: Trimmed string + """ + if string is None or not isinstance(string, str): + raise TypeError(f"Param string must be a string type but got {type(string)}.") + + if length <= -1 or len(string) <= length: + return string + + return string[:length] + + +class QueryGraphParser: + def __init__(self, rule_database_path: str): + self._fusion_rules: Dict[str, List[Tuple]] = dict() + self.load_database(rule_database_path) + self.num_rules = sum([len(v) for v in self._fusion_rules.values()]) + + @property + def fusion_rules(self): + return self._fusion_rules + + def load_database(self, rule_database): + if not os.path.isabs(rule_database): + rule_database = os.path.join(os.path.dirname(__file__), + "../", "../", + rule_database) + + if not os.path.exists(rule_database): + raise FileNotFoundError(f"Path {rule_database} does not exist.") + with open(rule_database, 'r') as f: + database = yaml.safe_load(f) + self.parse_yaml(database) + + def parse_yaml(self, yaml_database): + fusion_strategy_list = yaml_database.get("GraphFusion", []) + if yaml_database.get("UBFusion", []): + fusion_strategy_list.extend(yaml_database.get("UBFusion", [])) + for fusion_strategy in fusion_strategy_list: + if not isinstance(fusion_strategy, dict): + continue + (fusion_name, strategy), = fusion_strategy.items() + version = strategy.get("version", 0) + if version == 0 or version == "0": + self._fusion_rules[fusion_name] = self.build_query_graph_v0(fusion_name, + strategy.get('struct', [])) + elif version == 1 or version == "1": + self._fusion_rules[fusion_name] = self.build_query_graph_v1(fusion_name, + strategy.get('nodes', []), + strategy.get('edges', [])) + + @staticmethod + def build_query_graph_v0(graph_name: str, graph_struct: List[str]) -> List[Tuple]: + nodes = dict() + graphs = [] + edges = [] + + pre_node, next_node = None, None + for node in graph_struct: + pre_node = next_node + next_node = QueryGraphNode(node, graph_name) + nodes[next_node.op_name] = next_node + if pre_node is None or next_node is None: + continue + edges.append((pre_node, next_node,)) + graphs.append((nodes, edges, graph_name,)) + return graphs + + @staticmethod + def build_query_graph_v1(graph_name: str, + nodes_list: List[Dict], + edges_list: List[List[str]]) -> List[Tuple]: + graphs = [] + node_index = dict() + multi_node_list = [] + for index, node in enumerate(nodes_list): + (node_name, op_type), = node.items() + if isinstance(op_type, str): + op_type = [op_type] + multi_node_list.append([QueryGraphNode(op, graph_name) for op in op_type]) + node_index[node_name] = index + + multi_node = list(itertools.product(*multi_node_list)) + + for index, sub_nodes in enumerate(multi_node): + sub_graph_name = graph_name if index == 0 else f"{graph_name}#{index}" + sub_edge = [] + sub_node = dict() + for node in sub_nodes: + sub_node[node.op_name] = node + for edge in edges_list: + pre_node, next_node = edge + pre_node_index, next_node_index = node_index.get(pre_node), node_index.get(next_node) + sub_edge.append((sub_nodes[pre_node_index], sub_nodes[next_node_index])) + sub_graph = (sub_node, sub_edge, sub_graph_name,) + graphs.append(sub_graph) + return graphs diff --git a/profiler/advisor/common/version_control.py b/profiler/advisor/common/version_control.py new file mode 100644 index 000000000..38b054543 --- /dev/null +++ b/profiler/advisor/common/version_control.py @@ -0,0 +1,26 @@ +import logging +from typing import List + +logger = logging.getLogger() + + +class VersionControl: + _SUPPORT_VERSIONS = [] + + @classmethod + def is_supported(cls, cann_version: str) -> bool: + """ + Check whether the CANN software version is supported, which can be viewed by executing the following command: + 'cat /usr/local/Ascend/ascend-toolkit/latest/aarch64-linux/ascend_toolkit_install.info' + """ + flag = (cls._SUPPORT_VERSIONS.__contains__(cann_version)) + if not flag: + logger.debug("class type is %s, which is not support current CANN version %s", cls.__name__, cann_version) + return flag + + def get_support_version(self) -> List[str]: + """ + Acquire the CANN software version + :return: supported CANN software version + """ + return self._SUPPORT_VERSIONS -- Gitee