From f98f2a1a7119e18fab8b02f6a4219ff8e9a32541 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E6=99=93=E6=B3=A2?= Date: Mon, 8 Jan 2024 20:01:31 +0800 Subject: [PATCH] =?UTF-8?q?[feature]=E5=B0=8F=E6=A8=A1=E5=9E=8B=E4=B8=8B?= =?UTF-8?q?=E5=9F=BA=E4=BA=8Edynamo=E7=9A=84=E6=9E=84=E5=9B=BE=E6=95=B0?= =?UTF-8?q?=E6=8D=AE=E5=AF=B9=E8=B1=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/visualization/graph/base_node.py | 64 ++++++++++++ debug/visualization/graph/graph.py | 110 ++++++++++++++++++++ debug/visualization/graph/node_op.py | 24 +++++ debug/visualization/test/test_net.py | 116 ++++++++++++++++++++++ debug/visualization/test/test_script.py | 29 ++++++ debug/visualization/tool/dynamo_parser.py | 66 ++++++++++++ debug/visualization/tool/graph_viewer.py | 74 ++++++++++++++ debug/visualization/tool/id_manager.py | 26 +++++ 8 files changed, 509 insertions(+) create mode 100644 debug/visualization/graph/base_node.py create mode 100644 debug/visualization/graph/graph.py create mode 100644 debug/visualization/graph/node_op.py create mode 100644 debug/visualization/test/test_net.py create mode 100644 debug/visualization/test/test_script.py create mode 100644 debug/visualization/tool/dynamo_parser.py create mode 100644 debug/visualization/tool/graph_viewer.py create mode 100644 debug/visualization/tool/id_manager.py diff --git a/debug/visualization/graph/base_node.py b/debug/visualization/graph/base_node.py new file mode 100644 index 000000000..04a46816c --- /dev/null +++ b/debug/visualization/graph/base_node.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.fx as fx + +from tool.id_manager import IdManager +from tool.dynamo_parser import DynamoParser +from graph.node_op import NodeOp + + +class BaseNode: + def __init__(self, input_arg, up_node=None): + if isinstance(input_arg, fx.Node): + self.op = DynamoParser.get_op(input_arg) + self.type = DynamoParser.get_type(input_arg) + elif isinstance(input_arg, str): + self.op = NodeOp.module + self.type = input_arg.lower() + else: + self.op = NodeOp.default + self.type = "default" + self.id = IdManager.get_id(self.type) + self.data = {} + self.outputs = [] + self.inputs = [] + self.upnode = up_node + self.subnodes = [] + + def __str__(self): + info = f'id:\t{self.id}' + return info + + def get_info(self): + info = f'{self.id}\t{self.op}\n' + for key in self.data: + info += f'{key}:\t{self.data[key]}' + return info + + def add_output(self, node): + if node.id == self.id: + return + self.outputs.append(node) + + def add_input(self, node): + if node.id == self.id: + return + self.inputs.append(node) + + def add_subnode(self, node): + if node.id == self.id: + return + self.subnodes.append(node) \ No newline at end of file diff --git a/debug/visualization/graph/graph.py b/debug/visualization/graph/graph.py new file mode 100644 index 000000000..7a7cedfe3 --- /dev/null +++ b/debug/visualization/graph/graph.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from graph.base_node import BaseNode +from tool.dynamo_parser import DynamoParser + + +class Graph: + def __init__(self): + self.root = None + self.recent_node = None + self.depth = 0 + self.node_map = {} + self.rawid_map = {} + + def build_graph(self, net:torch.nn.Module, x:torch.Tensor): + self.root = BaseNode(net.__class__.__name__) + self.recent_node = self.root + + raw_graph = DynamoParser.get_fx_graph_by_export(net, x) + for raw_node in raw_graph.nodes: + self._update_stack(raw_node) + self._add_node(raw_node) + for raw_node in raw_graph.nodes: + node_id = DynamoParser.get_id(raw_node) + for arg in raw_node.args: + if isinstance(arg, tuple): + arg = arg[0] + arg_id = DynamoParser.get_id(arg) + if arg_id not in self.node_map: + continue + self._connect(arg_id, node_id) + + def update_info(self, info_map): + for key in info_map: + if key not in self.node_map: + print(f'{key} not found in node map') + continue + node = self.node_map[key] + for sub_key in info_map[key]: + node.data[sub_key] = info_map[key][sub_key] + + def _connect(self, input_id, output_id): + inode = self.node_map[input_id] + onode = self.node_map[output_id] + istack = [] + while not self._is_root(inode): + istack.append(inode) + inode = inode.upnode + ostack = [] + while not self._is_root(onode): + ostack.append(onode) + onode = onode.upnode + while len(istack) or len(ostack): + if len(istack): + inode = istack.pop() + if len(ostack): + onode = ostack.pop() + if (inode.id == onode.id): + continue + inode.add_output(onode) + onode.add_input(inode) + + def _update_stack(self, raw_node): + stack_info = DynamoParser.get_stack_info(raw_node) + self.depth = max(self.depth, len(stack_info) + 1) + self.recent_node = self.root + for key in stack_info: + if key in self.rawid_map: + node = self.rawid_map[key] + self.recent_node = node + else: + _, module_class = stack_info[key] + self._add_module(key, module_class) + + def _add_node(self, raw_node): + this_node = BaseNode(raw_node, self.recent_node) + self.node_map[this_node.id] = this_node + self.recent_node.add_subnode(this_node) + + def _add_module(self, raw_id, module_class): + type_name = module_class.__name__ + module_node = BaseNode(type_name, self.recent_node) + self.node_map[module_node.id] = module_node + self.rawid_map[raw_id] = module_node + self.recent_node.add_subnode(module_node) + self.recent_node = module_node + + def _is_root(self, node): + return node.id == self.root.id + + def __str__(self): + info = "" + for node_id in self.node_map: + info += f'{str(self.node_map[node_id])}\n' + return info \ No newline at end of file diff --git a/debug/visualization/graph/node_op.py b/debug/visualization/graph/node_op.py new file mode 100644 index 000000000..1d5a6b383 --- /dev/null +++ b/debug/visualization/graph/node_op.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class NodeOp(Enum): + module = 1 + function = 2 + param = 3 + constant = 4 + default = 5 \ No newline at end of file diff --git a/debug/visualization/test/test_net.py b/debug/visualization/test/test_net.py new file mode 100644 index 000000000..4705d1974 --- /dev/null +++ b/debug/visualization/test/test_net.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LeNet(nn.Module): + def __init__(self): + super(LeNet, self).__init__() + self.conv1 = nn.Conv2d(3, 16, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(16, 32, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.fc1 = nn.Linear(32*5*5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = self.pool1(x) + x = F.relu(self.conv2(x)) + x = self.pool2(x) + x = x.view(-1, 32*5*5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class ConvRelu(nn.Module): + def __init__(self, x, y, z): + super(ConvRelu, self).__init__() + self.conv = nn.Conv2d(x, y, z) + + def forward(self, x): + x = self.conv(x) + x = F.relu(x) + return x + + +class FcRelu(nn.Module): + def __init__(self, x, y): + super(FcRelu, self).__init__() + self.fc = nn.Linear(x, y) + + def forward(self, x): + x = self.fc(x) + x = F.relu(x) + return x + +class LeNet2(nn.Module): + def __init__(self): + super(LeNet2, self).__init__() + self.conv_relu1 = ConvRelu(3, 16, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.conv_relu2 = ConvRelu(16, 32, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.fc_relu = FcRelu(32*5*5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.conv_relu1(x) + x = self.pool1(x) + x = self.conv_relu2(x) + x = self.pool2(x) + x = x.view(-1, 32*5*5) + x = self.fc_relu(x) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class AddOne(nn.Module): + def __init__(self): + super(AddOne, self).__init__() + + def forward(self, x): + x = x + x + return x + +class AddTwo(nn.Module): + def __init__(self): + super(AddTwo, self).__init__() + self.add_one1 = AddOne() + self.add_one2 = AddOne() + + def forward(self, x): + x = self.add_one1(x) + x = self.add_one2(x) + return x + +class AddThree(nn.Module): + def __init__(self): + super(AddThree, self).__init__() + self.add_two = AddTwo() + self.add_one = AddOne() + + def forward(self, x): + x = self.add_two(x) + x = self.add_one(x) + return x \ No newline at end of file diff --git a/debug/visualization/test/test_script.py b/debug/visualization/test/test_script.py new file mode 100644 index 000000000..c37c0f8a4 --- /dev/null +++ b/debug/visualization/test/test_script.py @@ -0,0 +1,29 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from test_net import * +from graph.graph import Graph +from tool.graph_viewer import GraphViewer + + +if __name__ == '__main__': + net = LeNet2() + x = torch.randn(3, 32, 32) + + graph = Graph() + graph.build_graph(net, x) + + GraphViewer.save_full_level_pdf(graph, './output/') + GraphViewer.save_tree(graph, './output/tree') \ No newline at end of file diff --git a/debug/visualization/tool/dynamo_parser.py b/debug/visualization/tool/dynamo_parser.py new file mode 100644 index 000000000..c3675879a --- /dev/null +++ b/debug/visualization/tool/dynamo_parser.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.fx as fx +from torch.export import export + +from graph.node_op import NodeOp + + +class DynamoParser: + @staticmethod + def get_type(fx_node:fx.Node): + if fx_node.op == "placeholder": + return str(fx_node).lower() + elif fx_node.op == "call_function": + splits = str(fx_node).split('_') + if len(splits) >= 2 and splits[-1].isdigit(): + result = "_".join(splits[:-1]) + return result.lower() + else: + return str(fx_node).lower() + elif fx_node.op == "output": + return "output" + + @staticmethod + def get_id(fx_node:fx.Node): + fx_name = str(fx_node) + if "arg" in fx_name: + fx_name = fx_name.lower() + return f'{fx_name}_0' + splits = fx_name.split('_') + if len(splits) >= 2 and splits[-1].isdigit(): + return fx_name.lower() + else: + fx_name = fx_name.lower() + return f'{fx_name}_0' + + @staticmethod + def get_op(fx_node:fx.Node): + str_map = {"placeholder":NodeOp.param, "call_function":NodeOp.function, "output": NodeOp.default} + return str_map[fx_node.op] + + @staticmethod + def get_stack_info(fx_node:fx.Node): + if not fx_node.meta or "nn_module_stack" not in fx_node.meta: + return [] + return fx_node.meta["nn_module_stack"] + + @staticmethod + def get_fx_graph_by_export(net:torch.nn.Module, x:torch.Tensor): + export_program = export(net, args=(x,)) + fx_graph = export_program.graph + return fx_graph \ No newline at end of file diff --git a/debug/visualization/tool/graph_viewer.py b/debug/visualization/tool/graph_viewer.py new file mode 100644 index 000000000..77600e56f --- /dev/null +++ b/debug/visualization/tool/graph_viewer.py @@ -0,0 +1,74 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +from graphviz import Digraph + +from graph.graph import Graph + + +class GraphViewer(): + @staticmethod + def save_pdf(graph, level:int, output_file:str): + dot = Digraph() + dot.attr(rankdir='TB') + node_queue = queue.Queue() + node_queue.put(graph.root) + display_ids = set() + for _ in range(level+1): + display_ids.clear() + queue_size = node_queue.qsize() + for _ in range(queue_size): + node = node_queue.get() + if len(node.subnodes) == 0: + node_queue.put(node) + display_ids.add(node.id) + else: + for subnode in node.subnodes: + node_queue.put(subnode) + display_ids.add(subnode.id) + while not node_queue.empty(): + node = node_queue.get() + info = node.get_info() + dot.node(node.id, info, shape='rectangle') + for input_node in node.inputs: + if input_node.id not in display_ids: + continue + dot.edge(input_node.id, node.id) + dot.render(output_file) + + @staticmethod + def save_full_level_pdf(graph:Graph, output_path): + level = graph.depth + for i in range(level): + GraphViewer.save_pdf(graph, i, f'{output_path}/level_{i}') + + @staticmethod + def save_tree(graph, output_file): + dot = Digraph() + dot.attr(rankdir='TB') + node = graph.root + dot.node(node.id, node.get_info()) + for subnode in node.subnodes: + dot.edge(node.id, subnode.id, style='dashed') + for node_id in graph.node_map: + node = graph.node_map[node_id] + info = node.get_info() + dot.node(node.id, info, shape='rectangle') + for subnode in node.subnodes: + dot.edge(node.id, subnode.id, style='dashed') + for input_node in node.inputs: + dot.edge(input_node.id, node.id) + dot.render(output_file) diff --git a/debug/visualization/tool/id_manager.py b/debug/visualization/tool/id_manager.py new file mode 100644 index 000000000..29310cf3c --- /dev/null +++ b/debug/visualization/tool/id_manager.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +class IdManager: + type_count = {} + + @classmethod + def get_id(cls, type_name:str): + if type_name not in cls.type_count: + cls.type_count[type_name] = -1 + cls.type_count[type_name] += 1 + return f'{type_name}_{cls.type_count[type_name]}' + \ No newline at end of file -- Gitee