diff --git a/debug/visualization/builder/dynamo_parser.py b/debug/visualization/builder/dynamo_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..1458098288c9f0af1ed3c926e4374f5f942a9d28 --- /dev/null +++ b/debug/visualization/builder/dynamo_parser.py @@ -0,0 +1,138 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.fx as fx +from torch.export import export + +from builder.graph_parser import GraphParser +from graph.graph import Graph +from graph.base_node import BaseNode +from graph.node_op import NodeOp +from tool.log import print_warn_log, print_error_log +from tool.id_manager import IdManager + + +class DynamoParser(GraphParser): + def __init__(self): + super(DynamoParser, self).__init__() + + def parse(self, trace_out, model_class): + f_graph = trace_out[0] + self.graph = Graph() + self.graph.root = self._add_module('', model_class.__name__) + self._parse_once(trace_out, True) + if not trace_out[1]: + return self.graph + self._parse_once(trace_out, False) + return self.graph + + def _parse_once(self, trace_out, is_forward=True): + if len(trace_out) != 2: + print_error_log("len of trace_out != 2") + return + graph = trace_out[0 if is_forward else 1] + self._generate_nodes(graph, is_forward) + self._connect_graph(graph, is_forward) + + def _generate_nodes(self, graph, is_forward): + for raw_node in graph.nodes: + if str(raw_node) in self.graph.rawid_map: + continue + stack_info = self.get_stack_info(raw_node, is_forward) + self._update_stack(stack_info, is_forward) + node_op = self.get_op(raw_node) + node_type = self.get_type(raw_node) + raw_id = str(raw_node) + self._add_node(node_op, node_type, raw_id, is_forward) + + def _connect_graph(self, graph, is_forward): + # 连线依赖fx_node.args数据,它记录了对应节点的所有输入节点 + # 基于args信息,通过稀疏的连接表对BaseNode及其相关父节点进行连接 + for raw_node in graph.nodes: + raw_id = str(raw_node) + if raw_id not in self.graph.rawid_map: + continue + this_node = self.graph.rawid_map.get(raw_id) + if not is_forward and this_node.is_forward : + continue + node_id = this_node.id + for arg in raw_node.args: + if isinstance(arg, tuple): + if len(arg) != 1: + print_error_log('len arg != 1 in _parse_once') + arg = arg[0] + arg_id = str(arg) + if arg_id not in self.graph.rawid_map: + continue + self._connect(self.graph.rawid_map.get(arg_id).id, node_id) + + def get_stack_info(self, node: fx.Node, is_forward=True): + if is_forward: + return DynamoParser._get_forward_stack_info(node) + else: + return self._get_backward_stack_info(node) + + @staticmethod + def get_type(fx_node: fx.Node): + node_str = str(fx_node).lower() + if fx_node.op == "placeholder": + return node_str + elif fx_node.op == "call_function": + splits = node_str.split('_') + if len(splits) >= 2 and splits[-1].isdigit(): + result = "_".join(splits[:-1]) + return result.lower() + else: + return node_str + elif fx_node.op == "output": + return "output" + else: + return "default" + + @staticmethod + def get_op(fx_node: fx.Node): + str_map = {"placeholder": NodeOp.tensor, "call_function": NodeOp.function, "output": NodeOp.default} + return str_map.get(fx_node.op, NodeOp.default) + + @staticmethod + def _get_forward_stack_info(node: fx.Node): + if not node.meta: + return [] + dynamo_stack_info = node.meta.get("nn_module_stack", {}) + stack_info = [ + (key, dynamo_stack_info.get(key)[1].__name__) for key in dynamo_stack_info + ] + return stack_info + + def _get_backward_stack_info(self, node: fx.Node): + # 获得他即将获得node_id + op = DynamoParser.get_type(node) + this_id = IdManager.get_next_id(op) + # 根据node_id找到对应的正向算子 + pair_id = IdManager.find_pair_id(this_id) + if not pair_id or pair_id not in self.graph.node_map: + print_warn_log(f'{this_id} pair id not found') + return [] + f_node = self.graph.node_map.get(pair_id) + # 根据正向算子的调用栈生成对应的反向调用栈 + stack = [] + if f_node: + f_node = f_node.upnode + while f_node and self.graph._is_root(f_node) : + stack.append((f_node.id + 'b', f_node.type)) + f_node = f_node.upnode + stack.reverse() + return stack diff --git a/debug/visualization/tool/dynamo_parser.py b/debug/visualization/builder/dynamo_tracer.py similarity index 33% rename from debug/visualization/tool/dynamo_parser.py rename to debug/visualization/builder/dynamo_tracer.py index f442ad0165b5ffc7c1409e1b285e3d4e773bff71..aec37d1bcb5b9fbeb48846cd66650c26bb5314f9 100644 --- a/debug/visualization/tool/dynamo_parser.py +++ b/debug/visualization/builder/dynamo_tracer.py @@ -16,51 +16,33 @@ import torch import torch.fx as fx from torch.export import export +from torch._functorch.aot_autograd import aot_export_module -from graph.node_op import NodeOp +from builder.graph_tracer import GraphTracer -class DynamoParser: - @staticmethod - def get_type(fx_node:fx.Node): - if fx_node.op == "placeholder": - return str(fx_node).lower() - elif fx_node.op == "call_function": - splits = str(fx_node).split('_') - if len(splits) >= 2 and splits[-1].isdigit(): - result = "_".join(splits[:-1]) - return result.lower() - else: - return str(fx_node).lower() - elif fx_node.op == "output": - return "output" - else: - return "default" - - @staticmethod - def get_id(fx_node:fx.Node): - fx_name = str(fx_node).lower() - if "arg" in fx_name: - return f'{fx_name}_0' - splits = fx_name.split('_') - if len(splits) >= 2 and splits[-1].isdigit(): - return fx_name - else: - return f'{fx_name}_0' - - @staticmethod - def get_op(fx_node:fx.Node): - str_map = {"placeholder":NodeOp.tensor, "call_function":NodeOp.function, "output": NodeOp.default} - return str_map.get(fx_node.op, NodeOp.default) - - @staticmethod - def get_stack_info(fx_node:fx.Node): - if not fx_node.meta: - return [] - return fx_node.meta.get("nn_module_stack", []) +class LossWrapper(torch.nn.Module): + def __init__(self, net): + super(LossWrapper, self).__init__() + self.net = net - @staticmethod - def get_fx_graph_by_export(net:torch.nn.Module, x:torch.Tensor): + def forward(self, x): + y = self.net(x) + loss = y.sum() + y = y.detach() + return (loss, y) + + +class DynamoTracer(GraphTracer): + def __init__(self): + super(DynamoTracer, self).__init__() + + def trace(self, net: torch.nn.Module, x: torch.Tensor, need_backward=True): export_program = export(net, args=(x,)) - fx_graph = export_program.graph - return fx_graph + f_graph = export_program.graph + b_graph = None + if not need_backward: + return (f_graph, b_graph) + loss_wrapper = LossWrapper(net) + b_graph = aot_export_module(loss_wrapper, (x,), trace_joint=True, output_loss_index=0)[0].graph + return (f_graph, b_graph) diff --git a/debug/visualization/builder/graph_builder.py b/debug/visualization/builder/graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..d1cdabd8dbebfafeac9fe8605a0fbe988cead721 --- /dev/null +++ b/debug/visualization/builder/graph_builder.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from builder.graph_tracer import GraphTracer +from builder.dynamo_tracer import DynamoTracer +from builder.graph_parser import GraphParser +from builder.dynamo_parser import DynamoParser + + +class GraphBuilder(): + def __init__(self): + self.tracer = DynamoTracer() + self.parser = DynamoParser() + + def build(self, net: torch.nn.Module, x: torch.Tensor, need_backward=True): + trace_out = self.tracer.trace(net, x, need_backward) + model_class = net.__class__ + graph = self.parser.parse(trace_out, model_class) + return graph diff --git a/debug/visualization/builder/graph_parser.py b/debug/visualization/builder/graph_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..4ee34720797fc08c354b1bf70b2032fed29a3c95 --- /dev/null +++ b/debug/visualization/builder/graph_parser.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from graph.graph import Graph +from graph.base_node import BaseNode +from graph.node_op import NodeOp +from tool.id_manager import IdManager +from tool.log import print_warn_log + + +class GraphParser(ABC): + def __init__(self): + self.graph = None + + @abstractmethod + def parse(self, trace_out, model_class): + raise NotImplementedError('Function parse need to be implemented.') + + def _update_stack(self, stack_info, is_forward=True): + self.graph.depth = max(self.graph.depth, len(stack_info) + 1) + self.graph.recent_node = None + if not is_forward and not self.graph.root.pair: + raw_id = self.graph.root.id + 'b' + self._add_module(raw_id, self.graph.root.type, False) + else: + self.graph.recent_node = self.graph.root if is_forward else self.graph.root.pair + for key, module_class in stack_info: + if key in self.graph.rawid_map: + node = self.graph.rawid_map.get(key) + self.graph.recent_node = node + else: + self._add_module(key, module_class, is_forward) + + def _add_module(self, raw_id, module_class, is_forward=True): + module_node = self._add_node(NodeOp.module, module_class, raw_id, is_forward) + self.graph.recent_node = module_node + return module_node + + def _add_node(self, node_op, node_type, raw_id, is_forward=True): + this_node = BaseNode(node_op, node_type, self.graph.recent_node, is_forward) + node_id = this_node.id + self.graph.node_map[node_id] = this_node + self.graph.rawid_map[raw_id] = this_node + # 绑定反向关系 + if not is_forward: + pair_id = IdManager.find_pair_id(node_id) + if pair_id not in self.graph.node_map: + print_warn_log(f'pair_id {pair_id} _add_node not found') + return this_node + f_node = self.graph.node_map.get(pair_id) + BaseNode.add_direciton_pair(f_node, this_node) + return this_node + + def _connect(self, input_id, output_id): + if input_id not in self.graph.node_map or output_id not in self.graph.node_map: + return + inode = self.graph.node_map.get(input_id) + onode = self.graph.node_map.get(output_id) + istack = [] + while not self.graph._is_root(inode): + istack.append(inode) + inode = inode.upnode + ostack = [] + while not self.graph._is_root(onode): + ostack.append(onode) + onode = onode.upnode + while istack or ostack: + if istack: + inode = istack.pop() + if ostack: + onode = ostack.pop() + if (inode.id == onode.id): + continue + BaseNode.add_data_flow(inode, onode) diff --git a/debug/visualization/builder/graph_tracer.py b/debug/visualization/builder/graph_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..abb7a23b8ca7f36b4bb739cb4e607d3b2f641b2e --- /dev/null +++ b/debug/visualization/builder/graph_tracer.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +import torch + + +class GraphTracer(ABC): + def __init__(self): + pass + + @abstractmethod + def trace(self, net: torch.nn.Module, x: torch.Tensor, need_backward=True): + raise NotImplementedError('Function trace need to be implemented.') diff --git a/debug/visualization/graph/base_node.py b/debug/visualization/graph/base_node.py index 3640bee7862c7862901f2855a3ab133e1c032144..09b30fb65872c896b032cb38037d5b1dec19852d 100644 --- a/debug/visualization/graph/base_node.py +++ b/debug/visualization/graph/base_node.py @@ -16,49 +16,57 @@ import torch.fx as fx from tool.id_manager import IdManager -from tool.dynamo_parser import DynamoParser +from tool.log import print_error_log from graph.node_op import NodeOp class BaseNode: - def __init__(self, input_arg, up_node=None): - if isinstance(input_arg, fx.Node): - self.op = DynamoParser.get_op(input_arg) - self.type = DynamoParser.get_type(input_arg) - elif isinstance(input_arg, str): - self.op = NodeOp.module - self.type = input_arg.lower() - else: - self.op = NodeOp.default - self.type = "default" - self.id = IdManager.get_id(self.type) + def __init__(self, node_op, node_type, up_node=None, is_forward=True): + self.op = node_op + self.type = node_type + self.id = IdManager.get_id(self.type, is_forward) self.data = {} self.outputs = [] self.inputs = [] self.upnode = up_node self.subnodes = [] + if up_node: + up_node.add_subnode(self) + self.is_forward = is_forward + self.pair = None def __str__(self): info = f'id:\t{self.id}' return info def get_info(self): - info = f'{self.id}\t{self.op}\n' + info = f'{self.id}\t{self.op}' + if not self.is_forward: + info += '(b)' for key in self.data: - info += f'{key}:\t{self.data.get(key)}' + info += f'\n{key}:\t{self.data.get(key)}' return info - def add_output(self, node): - if node.id == self.id: - return - self.outputs.append(node) - - def add_input(self, node): - if node.id == self.id: - return - self.inputs.append(node) - def add_subnode(self, node): if node.id == self.id: return self.subnodes.append(node) + + @staticmethod + def add_data_flow(inode, onode): + if not inode or not onode: + return + if inode.id == onode.id: + return + inode.outputs.append(onode) + onode.inputs.append(inode) + + @staticmethod + def add_direciton_pair(node1, node2): + if not node1 or not node2: + return + if node1.type != node2.type or node1.is_forward == node2.is_forward: + print_error_log("Error in add_direction_pair") + return + node1.pair = node2 + node2.pair = node1 diff --git a/debug/visualization/graph/graph.py b/debug/visualization/graph/graph.py index 61bdc00bd88c20938af116bd1ff06bedc75e570d..6474f38a6a122f3e0429816279736b759b453a47 100644 --- a/debug/visualization/graph/graph.py +++ b/debug/visualization/graph/graph.py @@ -16,7 +16,6 @@ import torch from graph.base_node import BaseNode -from tool.dynamo_parser import DynamoParser from tool.log import print_error_log @@ -27,25 +26,7 @@ class Graph: self.depth = 0 self.node_map = {} self.rawid_map = {} - - def build_graph(self, net:torch.nn.Module, x:torch.Tensor): - self.root = BaseNode(net.__class__.__name__) - self.recent_node = self.root - raw_graph = DynamoParser.get_fx_graph_by_export(net, x) - for raw_node in raw_graph.nodes: - self._update_stack(raw_node) - self._add_node(raw_node) - for raw_node in raw_graph.nodes: - node_id = DynamoParser.get_id(raw_node) - for arg in raw_node.args: - if isinstance(arg, tuple): - arg = arg[0] - arg_id = DynamoParser.get_id(arg) - if arg_id not in self.node_map: - continue - self._connect(arg_id, node_id) - def update_info(self, info_map): for key in info_map: if key not in self.node_map: @@ -54,57 +35,17 @@ class Graph: node = self.node_map.get(key) for sub_key in info_map.get(key): node.data[sub_key] = info_map.get(key).get(sub_key) - - def _connect(self, input_id, output_id): - if input_id not in self.node_map or output_id not in self.node_map: - return - inode = self.node_map.get(input_id) - onode = self.node_map.get(output_id) - istack = [] - while not self._is_root(inode): - istack.append(inode) - inode = inode.upnode - ostack = [] - while not self._is_root(onode): - ostack.append(onode) - onode = onode.upnode - while istack or ostack: - if istack: - inode = istack.pop() - if ostack: - onode = ostack.pop() - if (inode.id == onode.id): - continue - inode.add_output(onode) - onode.add_input(inode) - - def _update_stack(self, raw_node): - stack_info = DynamoParser.get_stack_info(raw_node) - self.depth = max(self.depth, len(stack_info) + 1) - self.recent_node = self.root - for key in stack_info: - if key in self.rawid_map: - node = self.rawid_map.get(key) - self.recent_node = node - else: - _, module_class = stack_info.get(key) - self._add_module(key, module_class) - def _add_node(self, raw_node): - this_node = BaseNode(raw_node, self.recent_node) - self.node_map[this_node.id] = this_node - self.recent_node.add_subnode(this_node) - - def _add_module(self, raw_id, module_class): - type_name = module_class.__name__ - module_node = BaseNode(type_name, self.recent_node) - self.node_map[module_node.id] = module_node - self.rawid_map[raw_id] = module_node - self.recent_node.add_subnode(module_node) - self.recent_node = module_node - def _is_root(self, node): - return node.id == self.root.id + is_foward_root = node.id == self.root.id + is_backward_root = self.root.pair and node.id == self.root.pair.id + return is_foward_root or is_backward_root + + def get_output_nodes(self): + if 'output_0' not in self.node_map: + return [] + outputs = self.node_map.get('output_0').inputs + return outputs def __str__(self): infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map] diff --git a/debug/visualization/test/test_script.py b/debug/visualization/test/test_script.py index 4531b3c31965463d4e90c86cc3f7703de841805c..cef80860f738d0847bcdc696bdc2c0abdca63745 100644 --- a/debug/visualization/test/test_script.py +++ b/debug/visualization/test/test_script.py @@ -15,17 +15,16 @@ import torch -from test_net import LeNet2 +from test_net import LeNet2, AddOne, AddThree from graph.graph import Graph +from builder.graph_builder import GraphBuilder from tool.graph_viewer import GraphViewer if __name__ == '__main__': - net = LeNet2() - x = torch.randn(3, 32, 32) - - graph = Graph() - graph.build_graph(net, x) + net = AddOne() + x = torch.randn((4, 4), requires_grad=True) + graph = GraphBuilder().build(net, x, need_backward=True) GraphViewer.save_full_level_pdf(graph, './output/') GraphViewer.save_tree(graph, './output/tree') diff --git a/debug/visualization/tool/graph_viewer.py b/debug/visualization/tool/graph_viewer.py index 40ac36e7316a8ebfeef7944b26b1de7ebbde9dbd..01d078999e66a5658e1df16a3179bb14adc05210 100644 --- a/debug/visualization/tool/graph_viewer.py +++ b/debug/visualization/tool/graph_viewer.py @@ -21,11 +21,13 @@ from graph.graph import Graph class GraphViewer(): @staticmethod - def save_pdf(graph, level:int, output_file:str): + def save_pdf(graph, level: int, output_file: str): dot = Digraph() dot.attr(rankdir='TB') node_queue = queue.Queue() node_queue.put(graph.root) + if graph.root.pair: + node_queue.put(graph.root.pair) display_ids = set() for _ in range(level + 1): GraphViewer._search(display_ids, node_queue) @@ -40,8 +42,8 @@ class GraphViewer(): dot.render(output_file) @staticmethod - def save_full_level_pdf(graph:Graph, output_path): - level = graph.depth + def save_full_level_pdf(graph: Graph, output_path): + level = graph.depth - 1 for i in range(level): GraphViewer.save_pdf(graph, i, f'{output_path}/level_{i}') @@ -49,10 +51,6 @@ class GraphViewer(): def save_tree(graph, output_file): dot = Digraph() dot.attr(rankdir='TB') - node = graph.root - dot.node(node.id, node.get_info()) - for subnode in node.subnodes: - dot.edge(node.id, subnode.id, style='dashed') for node_id in graph.node_map: node = graph.node_map.get(node_id) info = node.get_info() diff --git a/debug/visualization/tool/id_manager.py b/debug/visualization/tool/id_manager.py index 5ea283f824d3b1e6960ab9e96efc4bf80bf3d448..58b5151b3216857fb6a53c11d9a7db180df81055 100644 --- a/debug/visualization/tool/id_manager.py +++ b/debug/visualization/tool/id_manager.py @@ -13,13 +13,52 @@ # See the License for the specific language governing permissions and # limitations under the License. +from tool.log import print_error_log, print_warn_log + class IdManager: type_count = {} + forward_end = False + forward_count = {} @classmethod - def get_id(cls, type_name:str): + def get_id(cls, type_name: str, is_forward=True): + if not is_forward and not cls.forward_end: + cls.start_backward() if type_name not in cls.type_count: cls.type_count[type_name] = -1 cls.type_count[type_name] += 1 return f'{type_name}_{cls.type_count.get(type_name)}' + + @classmethod + def get_next_id(cls, type_name: str): + if type_name in cls.type_count: + next_cnt = cls.type_count.get(type_name) + 1 + else: + next_cnt = 0 + return f'{type_name}_{next_cnt}' + + @classmethod + def start_backward(cls): + if cls.forward_end: + return + cls.forward_end = True + for key in cls.type_count: + cls.forward_count[key] = cls.type_count.get(key) + + @classmethod + def find_pair_id(cls, node_id): + splits = node_id.split('_') + if len(splits) < 2: + print_error_log('error in find_forward_id') + return '' + type_count = int(splits[-1]) + node_type = "_".join(splits[:-1]) + if node_type not in cls.forward_count: + print_warn_log(f'{node_id} type not found in forward_count') + return '' + base_cnt = cls.forward_count.get(node_type) + if type_count >= 2 * (base_cnt + 1) or type_count < 0: + print_error_log('cnt is too big or small') + return '' + return f'{node_type}_{2 * base_cnt + 1 - type_count}'