From e579284aca8f2cf5fc5c2ba5ee65196fbedd0ffd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E6=99=93=E6=B3=A2?= Date: Wed, 10 Jan 2024 11:26:21 +0800 Subject: [PATCH 1/2] =?UTF-8?q?[feature]=E5=B0=8F=E6=A8=A1=E5=9E=8B?= =?UTF-8?q?=E4=B8=8B=E5=9F=BA=E4=BA=8Edynamo=E7=9A=84=E6=9E=84=E5=9B=BE?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E5=AF=B9=E8=B1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/visualization/graph/base_node.py | 64 ++++++++++++ debug/visualization/graph/graph.py | 112 ++++++++++++++++++++ debug/visualization/graph/node_op.py | 24 +++++ debug/visualization/test/test_net.py | 119 ++++++++++++++++++++++ debug/visualization/test/test_script.py | 31 ++++++ debug/visualization/tool/dynamo_parser.py | 66 ++++++++++++ debug/visualization/tool/graph_viewer.py | 82 +++++++++++++++ debug/visualization/tool/id_manager.py | 25 +++++ debug/visualization/tool/log.py | 55 ++++++++++ 9 files changed, 578 insertions(+) create mode 100644 debug/visualization/graph/base_node.py create mode 100644 debug/visualization/graph/graph.py create mode 100644 debug/visualization/graph/node_op.py create mode 100644 debug/visualization/test/test_net.py create mode 100644 debug/visualization/test/test_script.py create mode 100644 debug/visualization/tool/dynamo_parser.py create mode 100644 debug/visualization/tool/graph_viewer.py create mode 100644 debug/visualization/tool/id_manager.py create mode 100644 debug/visualization/tool/log.py diff --git a/debug/visualization/graph/base_node.py b/debug/visualization/graph/base_node.py new file mode 100644 index 0000000000..3640bee786 --- /dev/null +++ b/debug/visualization/graph/base_node.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# File: debug/visualization/graph/base_node.py

import torch.fx as fx

from tool.id_manager import IdManager
from tool.dynamo_parser import DynamoParser
from graph.node_op import NodeOp


class BaseNode:
    """One node of the visualization graph.

    A node is built either from a ``torch.fx`` node, from a module class
    name (a ``str``), or from anything else, in which case it falls back to
    the ``default`` op and type.
    """

    def __init__(self, input_arg, up_node=None):
        """Create a node.

        Args:
            input_arg: an ``fx.Node`` (op and type are derived through
                ``DynamoParser``), a ``str`` (treated as a module whose type
                is the lower-cased string), or any other value.
            up_node: the enclosing node, stored as ``self.upnode``.
        """
        if isinstance(input_arg, fx.Node):
            self.op = DynamoParser.get_op(input_arg)
            self.type = DynamoParser.get_type(input_arg)
        elif isinstance(input_arg, str):
            self.op = NodeOp.module
            self.type = input_arg.lower()
        else:
            self.op = NodeOp.default
            self.type = "default"
        # The id is allocated per type by IdManager.
        self.id = IdManager.get_id(self.type)
        self.data = {}
        self.outputs = []
        self.inputs = []
        self.upnode = up_node
        self.subnodes = []

    def __str__(self):
        return f'id:\t{self.id}'

    def get_info(self):
        """Return the label text: ``id<TAB>op`` followed by one line per data entry.

        Entries are newline-separated; the earlier form appended them with no
        separator, so two entries ran together on one line.
        """
        lines = [f'{self.id}\t{self.op}']
        lines.extend(f'{key}:\t{value}' for key, value in self.data.items())
        return '\n'.join(lines)

    def add_output(self, node):
        """Record ``node`` as an output; a node with this node's id is ignored."""
        if node.id == self.id:
            return
        self.outputs.append(node)

    def add_input(self, node):
        """Record ``node`` as an input; a node with this node's id is ignored."""
        if node.id == self.id:
            return
        self.inputs.append(node)

    def add_subnode(self, node):
        """Record ``node`` as a child; a node with this node's id is ignored."""
        if node.id == self.id:
            return
        self.subnodes.append(node)
# File: debug/visualization/graph/graph.py

import torch

from graph.base_node import BaseNode
from tool.dynamo_parser import DynamoParser
from tool.log import print_error_log


class Graph:
    """Hierarchical graph of ``BaseNode`` objects built from an exported model."""

    def __init__(self):
        self.root = None
        self.recent_node = None
        self.depth = 0
        # node id -> node
        self.node_map = {}
        # nn_module_stack key -> module node
        self.rawid_map = {}

    def build_graph(self, net:torch.nn.Module, x:torch.Tensor):
        """Export ``net`` with input ``x`` and populate this graph from the result."""
        self.root = BaseNode(net.__class__.__name__)
        self.recent_node = self.root

        fx_graph = DynamoParser.get_fx_graph_by_export(net, x)
        # Pass 1: create every node under the module it was traced in.
        for fx_node in fx_graph.nodes:
            self._update_stack(fx_node)
            self._add_node(fx_node)
        # Pass 2: wire each node to the arguments that are known nodes.
        for fx_node in fx_graph.nodes:
            consumer_id = DynamoParser.get_id(fx_node)
            for arg in fx_node.args:
                # Only the first element of a tuple argument is considered.
                producer = arg[0] if isinstance(arg, tuple) else arg
                producer_id = DynamoParser.get_id(producer)
                if producer_id in self.node_map:
                    self._connect(producer_id, consumer_id)

    def update_info(self, info_map):
        """Copy per-node entries from ``info_map`` into each node's ``data``.

        Keys that are not node ids are reported through ``print_error_log``
        and skipped.
        """
        for node_id in info_map:
            if node_id not in self.node_map:
                print_error_log(f'{node_id} not found in node map')
                continue
            target = self.node_map.get(node_id)
            entries = info_map.get(node_id)
            for field in entries:
                target.data[field] = entries.get(field)

    def _connect(self, input_id, output_id):
        """Link two nodes, and their differing ancestors, as input/output."""
        if not (input_id in self.node_map and output_id in self.node_map):
            return
        src = self.node_map.get(input_id)
        dst = self.node_map.get(output_id)
        # Collect each node's ancestors up to (not including) the root.
        src_path = []
        while not self._is_root(src):
            src_path.append(src)
            src = src.upnode
        dst_path = []
        while not self._is_root(dst):
            dst_path.append(dst)
            dst = dst.upnode
        # Walk both paths from the outermost level inwards; when one path is
        # exhausted its last node keeps being used.
        while src_path or dst_path:
            if src_path:
                src = src_path.pop()
            if dst_path:
                dst = dst_path.pop()
            if src.id != dst.id:
                src.add_output(dst)
                dst.add_input(src)

    def _update_stack(self, raw_node):
        """Move ``recent_node`` to the module that ``raw_node`` was traced in."""
        module_stack = DynamoParser.get_stack_info(raw_node)
        self.depth = max(self.depth, len(module_stack) + 1)
        self.recent_node = self.root
        for key in module_stack:
            if key not in self.rawid_map:
                _, module_class = module_stack.get(key)
                self._add_module(key, module_class)
            else:
                self.recent_node = self.rawid_map.get(key)

    def _add_node(self, raw_node):
        """Create a node for ``raw_node`` under the current module."""
        created = BaseNode(raw_node, self.recent_node)
        self.node_map[created.id] = created
        self.recent_node.add_subnode(created)

    def _add_module(self, raw_id, module_class):
        """Create a module node under the current module and descend into it."""
        created = BaseNode(module_class.__name__, self.recent_node)
        self.node_map[created.id] = created
        self.rawid_map[raw_id] = created
        self.recent_node.add_subnode(created)
        self.recent_node = created

    def _is_root(self, node):
        return node.id == self.root.id

    def __str__(self):
        return "\n".join(str(node) for node in self.node_map.values())
# File: debug/visualization/graph/node_op.py

from enum import Enum


class NodeOp(Enum):
    """Kind of operation a graph node stands for."""

    module = 1
    function = 2
    param = 3
    constant = 4
    # Fallback for anything not covered above.
    default = 5
# File: debug/visualization/test/test_net.py

import torch
import torch.nn as nn
import torch.nn.functional as F


class LeNet(nn.Module):
    """LeNet-style network written as a flat list of layers."""

    def __init__(self):
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(3, 16, 5)
        # This was assigned to ``pool2`` twice, leaving ``pool1`` (used in
        # ``forward``) undefined.
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.pool1(x)
        x = F.relu(self.conv2(x))
        x = self.pool2(x)
        # Flatten to the 32 * 5 * 5 features expected by fc1.
        x = x.view(-1, 32 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class ConvRelu(nn.Module):
    """``Conv2d(x, y, z)`` followed by ReLU."""

    def __init__(self, x, y, z):
        super(ConvRelu, self).__init__()
        self.conv = nn.Conv2d(x, y, z)

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(x)
        return x


class FcRelu(nn.Module):
    """``Linear(x, y)`` followed by ReLU."""

    def __init__(self, x, y):
        super(FcRelu, self).__init__()
        self.fc = nn.Linear(x, y)

    def forward(self, x):
        x = self.fc(x)
        x = F.relu(x)
        return x


class LeNet2(nn.Module):
    """Same layer sequence as ``LeNet``, built from nested sub-modules."""

    def __init__(self):
        super(LeNet2, self).__init__()
        self.conv_relu1 = ConvRelu(3, 16, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv_relu2 = ConvRelu(16, 32, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc_relu = FcRelu(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.conv_relu1(x)
        x = self.pool1(x)
        x = self.conv_relu2(x)
        x = self.pool2(x)
        x = x.view(-1, 32 * 5 * 5)
        x = self.fc_relu(x)
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


class AddOne(nn.Module):
    """Parameter-free module returning ``x + x``."""

    def __init__(self):
        super(AddOne, self).__init__()

    def forward(self, x):
        x = x + x
        return x


class AddTwo(nn.Module):
    """Two ``AddOne`` modules applied in sequence."""

    def __init__(self):
        super(AddTwo, self).__init__()
        self.add_one1 = AddOne()
        self.add_one2 = AddOne()

    def forward(self, x):
        x = self.add_one1(x)
        x = self.add_one2(x)
        return x


class AddThree(nn.Module):
    """An ``AddTwo`` followed by an ``AddOne`` (three levels of nesting)."""

    def __init__(self):
        super(AddThree, self).__init__()
        self.add_two = AddTwo()
        self.add_one = AddOne()

    def forward(self, x):
        x = self.add_two(x)
        x = self.add_one(x)
        return x


# File: debug/visualization/test/test_script.py

if __name__ == '__main__':
    # Project-local imports are resolved only when run as a script, so the
    # model definitions above can be imported without the graph tooling.
    from test_net import LeNet2
    from graph.graph import Graph
    from tool.graph_viewer import GraphViewer

    net = LeNet2()
    x = torch.randn(3, 32, 32)

    graph = Graph()
    graph.build_graph(net, x)

    GraphViewer.save_full_level_pdf(graph, './output/')
    GraphViewer.save_tree(graph, './output/tree')
# File: debug/visualization/tool/dynamo_parser.py

import torch
import torch.fx as fx
from torch.export import export

from graph.node_op import NodeOp


class DynamoParser:
    """Static helpers that read names, ops and module stacks off ``fx`` nodes."""

    @staticmethod
    def get_type(fx_node:fx.Node):
        """Return the lower-cased type name derived from the node's string form.

        For ``call_function`` nodes a trailing ``_<digits>`` suffix is removed.
        """
        kind = fx_node.op
        if kind == "output":
            return "output"
        if kind == "placeholder":
            return str(fx_node).lower()
        if kind != "call_function":
            return "default"
        name = str(fx_node)
        head, sep, tail = name.rpartition('_')
        if sep and tail.isdigit():
            return head.lower()
        return name.lower()

    @staticmethod
    def get_id(fx_node:fx.Node):
        """Return the graph id for a node: its lower-cased name with a numeric suffix.

        Names containing ``arg`` and names without a trailing ``_<digits>``
        get ``_0`` appended; other names are returned as they are.
        """
        name = str(fx_node).lower()
        _, sep, tail = name.rpartition('_')
        already_numbered = bool(sep) and tail.isdigit()
        if "arg" not in name and already_numbered:
            return name
        return f'{name}_0'

    @staticmethod
    def get_op(fx_node:fx.Node):
        """Map the fx ``op`` string to a ``NodeOp``; unknown ops give ``NodeOp.default``."""
        op_by_kind = {
            "placeholder": NodeOp.param,
            "call_function": NodeOp.function,
            "output": NodeOp.default,
        }
        return op_by_kind.get(fx_node.op, NodeOp.default)

    @staticmethod
    def get_stack_info(fx_node:fx.Node):
        """Return the node's ``nn_module_stack`` metadata, or ``[]`` when absent."""
        meta = fx_node.meta
        if not meta:
            return []
        return meta.get("nn_module_stack", [])

    @staticmethod
    def get_fx_graph_by_export(net:torch.nn.Module, x:torch.Tensor):
        """Export ``net`` with the single positional input ``x`` and return its graph."""
        return export(net, args=(x,)).graph
# File: debug/visualization/tool/graph_viewer.py

import queue
from graphviz import Digraph

from graph.graph import Graph


class GraphViewer():
    """Renders a ``Graph`` with graphviz, per hierarchy level or as a tree."""

    @staticmethod
    def save_pdf(graph, level:int, output_file:str):
        """Render the nodes reached after expanding the root ``level + 1`` times.

        Edges are drawn only from inputs that are themselves displayed.
        """
        dot = Digraph()
        dot.attr(rankdir='TB')
        frontier = queue.Queue()
        frontier.put(graph.root)
        visible_ids = set()
        for _ in range(level + 1):
            GraphViewer._search(visible_ids, frontier)
        while not frontier.empty():
            current = frontier.get()
            label = current.get_info()
            dot.node(current.id, label, shape='rectangle')
            for source in current.inputs:
                if source.id in visible_ids:
                    dot.edge(source.id, current.id)
        dot.render(output_file)

    @staticmethod
    def save_full_level_pdf(graph:Graph, output_path):
        """Render one file per level, named ``level_<i>`` under ``output_path``."""
        for i in range(graph.depth):
            GraphViewer.save_pdf(graph, i, f'{output_path}/level_{i}')

    @staticmethod
    def save_tree(graph, output_file):
        """Render every node with dashed parent->child edges and solid input edges."""
        dot = Digraph()
        dot.attr(rankdir='TB')
        root = graph.root
        dot.node(root.id, root.get_info())
        for child in root.subnodes:
            dot.edge(root.id, child.id, style='dashed')
        for member in graph.node_map.values():
            label = member.get_info()
            dot.node(member.id, label, shape='rectangle')
            for child in member.subnodes:
                dot.edge(member.id, child.id, style='dashed')
            for source in member.inputs:
                dot.edge(source.id, member.id)
        dot.render(output_file)

    @staticmethod
    def _search(display_ids, node_queue):
        """Expand the queued frontier by one level.

        Nodes without subnodes are re-queued unchanged; other nodes are
        replaced by their subnodes. ``display_ids`` is reset to the ids of
        the nodes queued by this call.
        """
        display_ids.clear()
        pending = node_queue.qsize()
        for _ in range(pending):
            current = node_queue.get()
            if len(current.subnodes) != 0:
                GraphViewer._add_subnodes(display_ids, node_queue, current)
                continue
            node_queue.put(current)
            display_ids.add(current.id)

    @staticmethod
    def _add_subnodes(display_ids, node_queue, node):
        """Queue every subnode of ``node`` and mark it as displayed."""
        for child in node.subnodes:
            node_queue.put(child)
            display_ids.add(child.id)


# File: debug/visualization/tool/id_manager.py


class IdManager:
    """Hands out ids of the form ``<type>_<n>``, counting from 0 per type."""

    # Last number issued for each type name; shared by all callers.
    type_count = {}

    @classmethod
    def get_id(cls, type_name:str):
        """Return the next id for ``type_name`` and advance its counter."""
        number = cls.type_count.get(type_name, -1) + 1
        cls.type_count[type_name] = number
        return f'{type_name}_{number}'
# File: debug/visualization/tool/log.py

import os
import sys
import time


def _print_log(level, msg, end='\n'):
    """Print ``msg`` prefixed with a local timestamp, the process id and ``level``.

    Output format: ``YYYY-mm-dd HH:MM:SS(<pid>)-[<level>]<msg>``. Standard
    output is flushed after every message.
    """
    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    # This used os.getgid(), which is the group id (and is not available on
    # every platform); the value printed here is meant to be the pid.
    pid = os.getpid()
    print(f"{current_time}({pid})-[{level}]{msg}", end=end)
    sys.stdout.flush()


def print_info_log(info_msg, end='\n'):
    """
    Function Description:
        print info log.
    Parameter:
        info_msg: the info message.
        end: string written after the message (a newline by default).
    """
    _print_log("INFO", info_msg, end=end)


def print_error_log(error_msg):
    """
    Function Description:
        print error log.
    Parameter:
        error_msg: the error message.
    """
    _print_log("ERROR", error_msg)


def print_warn_log(warn_msg):
    """
    Function Description:
        print warn log.
    Parameter:
        warn_msg: the warning message.
    """
    _print_log("WARNING", warn_msg)
# File: debug/visualization/builder/dynamo_parser.py

import torch
import torch.fx as fx
from torch.export import export

from builder.graph_parser import GraphParser
from graph.graph import Graph
from graph.base_node import BaseNode
from graph.node_op import NodeOp
from tool.log import print_warn_log
from tool.id_manager import IdManager


class DynamoParser(GraphParser):
    """Builds a ``Graph`` from the forward (and optional backward) fx graphs."""

    def __init__(self):
        super(DynamoParser, self).__init__()

    def parse(self, trace_out, model_class):
        """Build and return the graph.

        Args:
            trace_out: pair ``(forward_graph, backward_graph)``; a falsy
                second element means only the forward graph is parsed.
            model_class: class of the traced model; its ``__name__`` becomes
                the type of the root node.
        """
        f_graph, b_graph = trace_out[0], trace_out[1]
        self.graph = Graph()
        self.graph.root = self._add_module('', model_class.__name__)
        self._parse_forward(f_graph)
        if not b_graph:
            return self.graph
        # Start parsing the backward data.
        self._parse_backward(b_graph)
        return self.graph

    def _parse_forward(self, f_graph):
        """Create a node for every forward fx node, then wire the data flow."""
        for raw_node in f_graph.nodes:
            stack_info = self.get_stack_info(raw_node)
            self._update_stack(stack_info)
            node_op = self.get_op(raw_node)
            node_type = self.get_type(raw_node)
            self._add_node(node_op, node_type, str(raw_node))
        self._connect_inputs(f_graph, is_forward=True)

    def _parse_backward(self, b_graph):
        """Create nodes for fx nodes not already known, then wire them."""
        # NOTE(review): the result is unused; the call is kept in case
        # get_output_nodes() has side effects - confirm and drop otherwise.
        last_nodes = self.graph.get_output_nodes()
        for raw_node in b_graph.nodes:
            raw_id = str(raw_node)
            # Nodes already registered under this name are not added again.
            if raw_id in self.graph.rawid_map:
                continue
            stack_info = self.get_stack_info(raw_node, False)
            self._update_stack(stack_info, is_forward=False)
            node_op = self.get_op(raw_node)
            node_type = self.get_type(raw_node)
            self._add_node(node_op, node_type, raw_id, is_forward=False)
        self._connect_inputs(b_graph, is_forward=False)

    def _connect_inputs(self, raw_graph, is_forward):
        """Connect every registered node of ``raw_graph`` to its registered args.

        Shared by the forward and backward passes. In the backward pass,
        nodes flagged ``is_forward`` are left untouched.
        """
        rawid_map = self.graph.rawid_map
        for raw_node in raw_graph.nodes:
            # TODO: how this id should be obtained is still to be discussed.
            raw_id = str(raw_node)
            if raw_id not in rawid_map:
                continue
            this_node = rawid_map.get(raw_id)
            if not is_forward and this_node.is_forward:
                continue
            for arg in raw_node.args:
                # TODO: only the first element of a tuple argument is used;
                # whether that loses edges still has to be checked.
                if isinstance(arg, tuple):
                    if not arg:
                        # Nothing to connect; indexing would raise IndexError.
                        continue
                    arg = arg[0]
                arg_id = str(arg)
                # TODO: clarify why unknown arguments are skipped.
                if arg_id not in rawid_map:
                    continue
                self._connect(rawid_map.get(arg_id).id, this_node.id)

    def get_stack_info(self, node:fx.Node, is_forward=True):
        """Return the module stack for ``node`` as a list of ``(key, type name)``."""
        if is_forward:
            return DynamoParser._get_forward_stack_info(node)
        else:
            return self._get_backward_stack_info(node)

    @staticmethod
    def get_type(fx_node:fx.Node):
        """Return the lower-cased type name derived from the node's string form.

        For ``call_function`` nodes a trailing ``_<digits>`` suffix is removed.
        """
        node_str = str(fx_node).lower()
        if fx_node.op == "placeholder":
            return node_str
        elif fx_node.op == "call_function":
            splits = node_str.split('_')
            if len(splits) >= 2 and splits[-1].isdigit():
                return "_".join(splits[:-1])
            return node_str
        elif fx_node.op == "output":
            return "output"
        else:
            return "default"

    @staticmethod
    def get_op(fx_node:fx.Node):
        """Map the fx ``op`` string to a ``NodeOp``; unknown ops give ``NodeOp.default``."""
        str_map = {
            "placeholder": NodeOp.param,
            "call_function": NodeOp.function,
            "output": NodeOp.default,
        }
        return str_map.get(fx_node.op, NodeOp.default)

    @staticmethod
    def _get_forward_stack_info(node:fx.Node):
        """Read ``nn_module_stack`` from the node metadata.

        Each entry becomes ``(key, name of the class stored at index 1)``.
        """
        if not node.meta:
            return []
        dynamo_stack_info = node.meta.get("nn_module_stack", {})
        stack_info = [
            (key, dynamo_stack_info.get(key)[1].__name__) for key in dynamo_stack_info
        ]
        return stack_info

    def _get_backward_stack_info(self, node:fx.Node):
        """Mirror the module stack of the paired forward node.

        Returns ``(forward module id + 'b', module type)`` pairs, outermost
        module first, or ``[]`` (after a warning) when no forward pair exists.
        """
        # The id this node is going to receive.
        op = DynamoParser.get_type(node)
        this_id = IdManager.get_next_id(op)
        # The corresponding forward operator.
        pair_id = IdManager.find_pair_id(this_id)
        if not pair_id or pair_id not in self.graph.node_map:
            print_warn_log(f'{this_id} pair id not found')
            return []
        f_node = self.graph.node_map.get(pair_id)
        # The call stack of the forward operator.
        stack = []
        if f_node:
            f_node = f_node.upnode
        # Collect the enclosing modules up to, but excluding, the root: the
        # root's backward counterpart is created by _update_stack. The
        # condition used to be `_is_root(f_node)`, which stopped at the first
        # non-root ancestor and so never collected nested modules.
        while f_node and not self.graph._is_root(f_node):
            stack.append((f_node.id + 'b', f_node.type))
            f_node = f_node.upnode
        stack.reverse()
        return stack
# File: debug/visualization/builder/dynamo_tracer.py

import torch
import torch.fx as fx
from torch.export import export
from torch._functorch.aot_autograd import aot_export_module

from builder.graph_tracer import GraphTracer


class LossWrapper(torch.nn.Module):
    """Wraps a network so that its forward pass returns ``(loss, output)``."""

    def __init__(self, net, loss, target):
        super(LossWrapper, self).__init__()
        self.net = net
        self.loss = loss
        self.target = target

    def forward(self, x):
        y = self.net(x)
        # NOTE: the sum of the outputs stands in as the loss; `self.loss` and
        # `self.target` are stored but `self.loss(y, self.target)` is not
        # applied yet.
        loss = y.sum()
        y = y.detach()
        return (loss, y)


class DynamoTracer(GraphTracer):
    """Traces a model into a forward graph and, when a loss is given, a joint graph."""

    def __init__(self):
        super(DynamoTracer, self).__init__()

    def trace(self, net:torch.nn.Module, x:torch.Tensor, loss=None, target=None):
        """Trace ``net`` with the single input ``x``.

        Returns:
            ``(f_graph, b_graph)``. ``b_graph`` is ``None`` when no ``loss``
            is supplied; otherwise it is the graph exported from
            ``LossWrapper`` with ``trace_joint=True``.
        """
        export_program = export(net, args=(x,))
        f_graph = export_program.graph
        if loss is None:
            # No backward trace requested. This used to return an unassigned
            # `b_graph`, which raised UnboundLocalError.
            return (f_graph, None)
        loss_wrapper = LossWrapper(net, loss, target)
        b_graph = aot_export_module(loss_wrapper, (x,), trace_joint=True, output_loss_index=0)[0].graph
        return (f_graph, b_graph)
# File: debug/visualization/builder/graph_builder.py

import torch

from builder.graph_tracer import GraphTracer
from builder.dynamo_tracer import DynamoTracer
from builder.graph_parser import GraphParser
from builder.dynamo_parser import DynamoParser


class GraphBuilder():
    """Pairs a tracer with a parser to turn a model into a graph."""

    def __init__(self):
        self.tracer = DynamoTracer()
        self.parser = DynamoParser()

    def build(self, net:torch.nn.Module, x:torch.Tensor, loss=None, target=None):
        """Trace ``net`` and return the graph parsed from the trace output."""
        traced = self.tracer.trace(net, x, loss, target)
        return self.parser.parse(traced, net.__class__)
# File: debug/visualization/builder/graph_parser.py

from abc import ABC, abstractmethod

from graph.graph import Graph
from graph.base_node import BaseNode
from graph.node_op import NodeOp
from tool.id_manager import IdManager
from tool.log import print_warn_log


class GraphParser(ABC):
    """Base class for parsers that fill ``self.graph`` from a trace output."""

    def __init__(self):
        self.graph = None

    @abstractmethod
    def parse(self, trace_out, model_class):
        raise NotImplementedError('Function parse need to be implemented.')

    def _update_stack(self, stack_info, is_forward=True):
        """Point ``graph.recent_node`` at the module described by ``stack_info``.

        Module nodes missing from ``rawid_map`` are created on the way. For
        the backward direction the counterpart of the root is created first
        if the root has no ``pair`` yet.
        """
        graph = self.graph
        graph.depth = max(graph.depth, len(stack_info) + 1)
        graph.recent_node = None
        if is_forward:
            graph.recent_node = graph.root
        elif graph.root.pair:
            graph.recent_node = graph.root.pair
        else:
            self._add_module(graph.root.id + 'b', graph.root.type, False)
        for key, module_class in stack_info:
            if key not in graph.rawid_map:
                self._add_module(key, module_class, is_forward)
            else:
                graph.recent_node = graph.rawid_map.get(key)

    def _add_module(self, raw_id, module_class, is_forward=True):
        """Add a module node and make it the current module."""
        created = self._add_node(NodeOp.module, module_class, raw_id, is_forward)
        self.graph.recent_node = created
        return created

    def _add_node(self, node_op, node_type, raw_id, is_forward=True):
        """Create a node under the current module and register it in both maps."""
        created = BaseNode(node_op, node_type, self.graph.recent_node, is_forward)
        self.graph.node_map[created.id] = created
        self.graph.rawid_map[raw_id] = created
        if is_forward:
            return created
        # Bind the backward node to its forward counterpart.
        pair_id = IdManager.find_pair_id(created.id)
        if pair_id in self.graph.node_map:
            BaseNode.add_direciton_pair(self.graph.node_map.get(pair_id), created)
        else:
            print_warn_log('_add_node not found')
        return created

    def _connect(self, input_id, output_id):
        """Add a data-flow link between two nodes and between their differing ancestors."""
        node_map = self.graph.node_map
        if not (input_id in node_map and output_id in node_map):
            return
        src = node_map.get(input_id)
        dst = node_map.get(output_id)
        # Collect each node's ancestors up to (not including) the root.
        src_path = []
        while not self.graph._is_root(src):
            src_path.append(src)
            src = src.upnode
        dst_path = []
        while not self.graph._is_root(dst):
            dst_path.append(dst)
            dst = dst.upnode
        # Walk both paths from the outermost level inwards; when one path is
        # exhausted its last node keeps being used.
        while src_path or dst_path:
            if src_path:
                src = src_path.pop()
            if dst_path:
                dst = dst_path.pop()
            if src.id != dst.id:
                BaseNode.add_data_flow(src, dst)
class BaseNode:
    """A node of the visualization graph.

    A node has an operation kind (``op``), a type name (``type``), a generated
    ``id``, a free-form ``data`` dict, data-flow neighbours (``inputs`` /
    ``outputs``), a hierarchy position (``upnode`` / ``subnodes``) and an
    optional ``pair`` linking a forward node with its backward counterpart.
    """

    def __init__(self, node_op, node_type, up_node=None, is_forward=True):
        """Create a node and register it as a sub-node of ``up_node``.

        Args:
            node_op: Operation kind of the node.
            node_type: Type name; also the prefix of the generated id.
            up_node: Parent node in the hierarchy, or None.
            is_forward: False for a node belonging to the backward pass.
        """
        self.op = node_op
        self.type = node_type
        self.id = IdManager.get_id(self.type, is_forward)
        self.data = {}
        self.outputs = []
        self.inputs = []
        self.upnode = up_node
        self.subnodes = []
        if up_node:
            up_node.add_subnode(self)
        self.is_forward = is_forward
        self.pair = None

    def __str__(self):
        return f'id:\t{self.id}'

    def get_info(self):
        """Return a multi-line description: id and op, then one line per data item.

        Backward nodes get a ``(b)`` suffix on the first line.
        """
        header = f'{self.id}\t{self.op}'
        if not self.is_forward:
            header += '(b)'
        lines = [header]
        lines.extend(f'{key}:\t{value}' for key, value in self.data.items())
        return '\n'.join(lines)

    def add_subnode(self, node):
        """Append ``node`` to ``subnodes`` unless it has this node's own id."""
        if node.id == self.id:
            return
        self.subnodes.append(node)

    @staticmethod
    def add_data_flow(inode, onode):
        """Record a data-flow edge ``inode -> onode`` on both endpoints.

        Nothing happens if either node is missing or both have the same id.
        """
        if not inode or not onode:
            return
        if inode.id == onode.id:
            return
        inode.outputs.append(onode)
        onode.inputs.append(inode)

    @staticmethod
    def add_direciton_pair(node1, node2):
        """Link two nodes as each other's forward/backward counterpart.

        The nodes must have the same ``type`` and opposite ``is_forward``
        values; otherwise an error is logged and nothing is linked. Nothing
        happens if either node is missing.

        The name is misspelled but kept because existing callers use it;
        ``add_direction_pair`` is the correctly spelled alias.
        """
        if not node1 or not node2:
            return
        if node1.type != node2.type or node1.is_forward == node2.is_forward:
            print_error_log("Error in add_direction_pair")
            return
        node1.pair = node2
        node2.pair = node1

    # Correctly spelled alias for the method above; the misspelled name is
    # kept so existing callers keep working.
    add_direction_pair = add_direciton_pair
istack.pop() - if ostack: - onode = ostack.pop() - if (inode.id == onode.id): - continue - inode.add_output(onode) - onode.add_input(inode) - - def _update_stack(self, raw_node): - stack_info = DynamoParser.get_stack_info(raw_node) - self.depth = max(self.depth, len(stack_info) + 1) - self.recent_node = self.root - for key in stack_info: - if key in self.rawid_map: - node = self.rawid_map.get(key) - self.recent_node = node - else: - _, module_class = stack_info.get(key) - self._add_module(key, module_class) - def _add_node(self, raw_node): - this_node = BaseNode(raw_node, self.recent_node) - self.node_map[this_node.id] = this_node - self.recent_node.add_subnode(this_node) - - def _add_module(self, raw_id, module_class): - type_name = module_class.__name__ - module_node = BaseNode(type_name, self.recent_node) - self.node_map[module_node.id] = module_node - self.rawid_map[raw_id] = module_node - self.recent_node.add_subnode(module_node) - self.recent_node = module_node - def _is_root(self, node): - return node.id == self.root.id + is_foward_root = node.id == self.root.id + is_backward_root = self.root.pair and node.id == self.root.pair.id + return is_foward_root or is_backward_root + + def get_output_nodes(self): + if 'output_0' not in self.node_map: + return [] + outputs = self.node_map.get('output_0').inputs + return outputs def __str__(self): infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map] diff --git a/debug/visualization/test/test_script.py b/debug/visualization/test/test_script.py index 4531b3c319..30ca84074c 100644 --- a/debug/visualization/test/test_script.py +++ b/debug/visualization/test/test_script.py @@ -15,17 +15,23 @@ import torch -from test_net import LeNet2 +from test_net import LeNet2, AddOne, AddThree from graph.graph import Graph +from builder.graph_builder import GraphBuilder from tool.graph_viewer import GraphViewer if __name__ == '__main__': - net = LeNet2() - x = torch.randn(3, 32, 32) + # net = LeNet2() + # x = 
torch.randn((3, 32, 32), requires_grad=True) + # target = torch.randn(1, 10) + # loss = torch.nn.CrossEntropyLoss() - graph = Graph() - graph.build_graph(net, x) + net = AddOne() + x = torch.randn((4, 4), requires_grad=True) + target = torch.randn(4, 4) + loss = torch.nn.CrossEntropyLoss() + graph = GraphBuilder().build(net, x, loss, target) GraphViewer.save_full_level_pdf(graph, './output/') GraphViewer.save_tree(graph, './output/tree') diff --git a/debug/visualization/tool/graph_viewer.py b/debug/visualization/tool/graph_viewer.py index 40ac36e731..80a7545e09 100644 --- a/debug/visualization/tool/graph_viewer.py +++ b/debug/visualization/tool/graph_viewer.py @@ -26,6 +26,8 @@ class GraphViewer(): dot.attr(rankdir='TB') node_queue = queue.Queue() node_queue.put(graph.root) + if graph.root.pair: + node_queue.put(graph.root.pair) display_ids = set() for _ in range(level + 1): GraphViewer._search(display_ids, node_queue) @@ -41,7 +43,7 @@ class GraphViewer(): @staticmethod def save_full_level_pdf(graph:Graph, output_path): - level = graph.depth + level = graph.depth - 1 for i in range(level): GraphViewer.save_pdf(graph, i, f'{output_path}/level_{i}') @@ -49,10 +51,6 @@ class GraphViewer(): def save_tree(graph, output_file): dot = Digraph() dot.attr(rankdir='TB') - node = graph.root - dot.node(node.id, node.get_info()) - for subnode in node.subnodes: - dot.edge(node.id, subnode.id, style='dashed') for node_id in graph.node_map: node = graph.node_map.get(node_id) info = node.get_info() diff --git a/debug/visualization/tool/id_manager.py b/debug/visualization/tool/id_manager.py index 5ea283f824..efe6ed662d 100644 --- a/debug/visualization/tool/id_manager.py +++ b/debug/visualization/tool/id_manager.py @@ -13,13 +13,54 @@ # See the License for the specific language governing permissions and # limitations under the License. 
class IdManager:
    """Class-level generator of node ids of the form ``<type>_<index>``.

    All state lives on the class, so ids are unique across every caller in
    the process. Indices count up per type name over both passes. When the
    first backward id is requested, the per-type counters reached so far are
    snapshotted into ``forward_count``; ``find_pair_id`` uses that snapshot to
    mirror an index onto its counterpart in the other pass.
    """

    # Last index handed out for each type name (forward and backward).
    type_count = {}
    # True once the first backward id has been requested.
    forward_end = False
    # Snapshot of type_count taken when the backward phase started.
    forward_count = {}

    @classmethod
    def get_id(cls, type_name: str, is_forward=True):
        """Allocate and return the next id for ``type_name``.

        The first call with ``is_forward=False`` starts the backward phase
        (see :meth:`start_backward`) before allocating.
        """
        if not is_forward and not cls.forward_end:
            cls.start_backward()
        index = cls.type_count.get(type_name, -1) + 1
        cls.type_count[type_name] = index
        return f'{type_name}_{index}'

    @classmethod
    def get_next_id(cls, type_name: str):
        """Return the id :meth:`get_id` would hand out next, without allocating it."""
        return f'{type_name}_{cls.type_count.get(type_name, -1) + 1}'

    @classmethod
    def start_backward(cls):
        """Mark the forward phase as finished and snapshot the counters.

        Only the first call has an effect.
        """
        if cls.forward_end:
            return
        cls.forward_end = True
        cls.forward_count.update(cls.type_count)

    @classmethod
    def reset(cls):
        """Clear all counters so a fresh graph starts again at index 0.

        The dicts are cleared in place so existing references to them stay
        valid.
        """
        cls.type_count.clear()
        cls.forward_count.clear()
        cls.forward_end = False

    @classmethod
    def find_pair_id(cls, node_id):
        """Return the id of the forward/backward counterpart of ``node_id``.

        With ``base`` the last forward index recorded for the type, index
        ``c`` is mirrored to ``2 * base + 1 - c``.

        Returns:
            The counterpart id, or ``''`` (after logging an error) when
            ``node_id`` has no ``_`` separator, its suffix is not an integer,
            its type has no recorded forward count, or its index is outside
            ``[0, 2 * (base + 1))``.
        """
        node_type, separator, suffix = node_id.rpartition('_')
        if not separator:
            print_error_log('error in find_forward_id')
            return ''
        try:
            type_count = int(suffix)
        except ValueError:
            # A non-numeric suffix is reported like every other malformed id
            # instead of propagating ValueError to the caller.
            print_error_log('id suffix is not a number')
            return ''
        if node_type not in cls.forward_count:
            print_error_log('type not found in forward_count')
            return ''
        base_cnt = cls.forward_count.get(node_type)
        if type_count >= 2 * (base_cnt + 1) or type_count < 0:
            print_error_log('cnt is too big or small')
            return ''
        return f'{node_type}_{2 * base_cnt + 1 - type_count}'