diff --git a/debug/visualization/builder/dynamo_parser.py b/debug/visualization/builder/dynamo_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..cfccc1dd23d0790d35e729e6b690597031b92647 --- /dev/null +++ b/debug/visualization/builder/dynamo_parser.py @@ -0,0 +1,160 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.fx as fx +from torch.export import export + +from builder.graph_parser import GraphParser +from graph.graph import Graph +from graph.base_node import BaseNode +from graph.node_op import NodeOp +from tool.log import print_warn_log +from tool.id_manager import IdManager + + +class DynamoParser(GraphParser): + def __init__(self): + super(DynamoParser, self).__init__() + + def parse(self, trace_out, model_class): + f_graph = trace_out[0] + self.graph = Graph() + self.graph.root = self._add_module('', model_class.__name__) + self._parse_forward(trace_out[0]) + if not trace_out[1]: + return self.graph + # 開始解析我們的反向數據 + self._parse_backward(trace_out[1]) + return self.graph + + def _parse_forward(self, f_graph): + for raw_node in f_graph.nodes: + stack_info = self.get_stack_info(raw_node) + self._update_stack(stack_info) + node_op = self.get_op(raw_node) + node_type = self.get_type(raw_node) + raw_id = str(raw_node) + self._add_node(node_op, node_type, raw_id) + for raw_node in f_graph.nodes: + # todo + # 这个id怎么获取,讨论讨论 + raw_id = str(raw_node) + if raw_id not in self.graph.rawid_map: + continue + node_id = self.graph.rawid_map.get(raw_id).id + for arg in raw_node.args: + # todo + # 为啥呢,有没有影响 + if isinstance(arg, tuple): + arg = arg[0] + arg_id = str(arg) + # todo + # 爲啥要跳過 + if arg_id not in self.graph.rawid_map: + continue + self._connect(self.graph.rawid_map.get(arg_id).id, node_id) + + # todo + # 重複代碼整理 + def _parse_backward(self, b_graph): + last_nodes = self.graph.get_output_nodes() + for raw_node in b_graph.nodes: + if str(raw_node) in self.graph.rawid_map: + continue + stack_info = self.get_stack_info(raw_node, False) + self._update_stack(stack_info, is_forward=False) + node_op = self.get_op(raw_node) + node_type = self.get_type(raw_node) + raw_id = str(raw_node) + self._add_node(node_op, node_type, raw_id, is_forward=False) + for raw_node in b_graph.nodes: + # todo + # 这个id怎么获取,讨论讨论 + raw_id = str(raw_node) + if raw_id not in self.graph.rawid_map: + continue + this_node = self.graph.rawid_map.get(raw_id) + if this_node.is_forward: + continue + node_id = this_node.id + for arg in raw_node.args: + # todo + # 为啥呢,有没有影响 + if isinstance(arg, tuple): + arg = arg[0] + arg_id = str(arg) + if arg_id not in self.graph.rawid_map: + continue + self._connect(self.graph.rawid_map.get(arg_id).id, node_id) + + def get_stack_info(self, node:fx.Node, is_forward=True): + if is_forward: + return DynamoParser._get_forward_stack_info(node) + else: + return self._get_backward_stack_info(node) + + @staticmethod + def get_type(fx_node:fx.Node): + node_str = str(fx_node).lower() + if fx_node.op == "placeholder": + return node_str + elif fx_node.op == "call_function": + splits = node_str.split('_') + if len(splits) >= 2 and splits[-1].isdigit(): + result = "_".join(splits[:-1]) + return result.lower() + else: + return node_str + elif fx_node.op == "output": + return "output" + else: + return "default" + + @staticmethod + def get_op(fx_node:fx.Node): + str_map = {"placeholder":NodeOp.param, "call_function":NodeOp.function, "output": NodeOp.default} + return str_map.get(fx_node.op, NodeOp.default) + + @staticmethod + def _get_forward_stack_info(node:fx.Node): + if not node.meta: + return [] + dynamo_stack_info = node.meta.get("nn_module_stack", {}) + stack_info = [ + (key, dynamo_stack_info.get(key)[1].__name__) for key in dynamo_stack_info + ] + return stack_info + + def _get_backward_stack_info(self, node:fx.Node): + # 獲得他會獲得的id + op = DynamoParser.get_type(node) + this_id = IdManager.get_next_id(op) + # 拿到對應的正向算子 + pair_id = IdManager.find_pair_id(this_id) + if not pair_id or pair_id not in self.graph.node_map: + print_warn_log(f'{this_id} pair id not found') + return [] + f_node = self.graph.node_map.get(pair_id) + # 拿到正向算子的調用沾 + stack = [] + if f_node: + f_node = f_node.upnode + while f_node and self.graph._is_root(f_node) : + stack.append((f_node.id+'b', f_node.type)) + f_node = f_node.upnode + stack.reverse() + return stack + diff --git a/debug/visualization/builder/dynamo_tracer.py b/debug/visualization/builder/dynamo_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..8d5227efc8d4888f6d28e15bfe366c46872f1c1a --- /dev/null +++ b/debug/visualization/builder/dynamo_tracer.py @@ -0,0 +1,50 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.fx as fx +from torch.export import export +from torch._functorch.aot_autograd import aot_export_module + +from builder.graph_tracer import GraphTracer + + +class LossWrapper(torch.nn.Module): + def __init__(self, net, loss, target): + super(LossWrapper, self).__init__() + self.net = net + self.loss = loss + self.target = target + + def forward(self, x): + y = self.net(x) + loss =y.sum() + #loss = self.loss(y, self.target) + y = y.detach() + return (loss, y) + + +class DynamoTracer(GraphTracer): + def __init__(self): + super(DynamoTracer, self).__init__() + + def trace(self, net:torch.nn.Module, x:torch.Tensor, loss=None, target=None): + export_program = export(net, args=(x,)) + f_graph = export_program.graph + if not loss: + return (f_graph, b_graph) + loss_wrapper = LossWrapper(net, loss, target) + b_graph = aot_export_module(loss_wrapper, (x,), trace_joint=True, output_loss_index=0)[0].graph + return (f_graph, b_graph) diff --git a/debug/visualization/builder/graph_builder.py b/debug/visualization/builder/graph_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..f3a771163a86a4307a863b3ea180a0008d3f1d8a --- /dev/null +++ b/debug/visualization/builder/graph_builder.py @@ -0,0 +1,33 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from builder.graph_tracer import GraphTracer +from builder.dynamo_tracer import DynamoTracer +from builder.graph_parser import GraphParser +from builder.dynamo_parser import DynamoParser + + +class GraphBuilder(): + def __init__(self): + self.tracer = DynamoTracer() + self.parser = DynamoParser() + + def build(self, net:torch.nn.Module, x:torch.Tensor, loss=None, target=None): + trace_out = self.tracer.trace(net, x, loss, target) + model_class = net.__class__ + graph = self.parser.parse(trace_out, model_class) + return graph diff --git a/debug/visualization/builder/graph_parser.py b/debug/visualization/builder/graph_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..e17786d87cf8215d156b2be0cb0cae1dc21ef7ea --- /dev/null +++ b/debug/visualization/builder/graph_parser.py @@ -0,0 +1,88 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from abc import ABC, abstractmethod + +from graph.graph import Graph +from graph.base_node import BaseNode +from graph.node_op import NodeOp +from tool.id_manager import IdManager +from tool.log import print_warn_log + + +class GraphParser(ABC): + def __init__(self): + self.graph = None + + @abstractmethod + def parse(self, trace_out, model_class): + raise NotImplementedError('Function parse need to be implemented.') + + def _update_stack(self, stack_info, is_forward=True): + self.graph.depth = max(self.graph.depth, len(stack_info) + 1) + self.graph.recent_node = None + if not is_forward and not self.graph.root.pair: + raw_id = self.graph.root.id + 'b' + self._add_module(raw_id, self.graph.root.type, False) + else: + self.graph.recent_node = self.graph.root if is_forward else self.graph.root.pair + for key, module_class in stack_info: + if key in self.graph.rawid_map: + node = self.graph.rawid_map.get(key) + self.graph.recent_node = node + else: + self._add_module(key, module_class, is_forward) + + def _add_module(self, raw_id, module_class, is_forward=True): + module_node = self._add_node(NodeOp.module, module_class, raw_id, is_forward) + self.graph.recent_node = module_node + return module_node + + def _add_node(self, node_op, node_type, raw_id, is_forward=True): + this_node = BaseNode(node_op, node_type, self.graph.recent_node, is_forward) + node_id = this_node.id + self.graph.node_map[node_id] = this_node + self.graph.rawid_map[raw_id] = this_node + # 綁定反向關係 + if not is_forward: + pair_id = IdManager.find_pair_id(node_id) + if pair_id not in self.graph.node_map: + print_warn_log('_add_node not found') + return this_node + f_node = self.graph.node_map.get(pair_id) + BaseNode.add_direciton_pair(f_node, this_node) + return this_node + + def _connect(self, input_id, output_id): + if input_id not in self.graph.node_map or output_id not in self.graph.node_map: + return + inode = self.graph.node_map.get(input_id) + onode = self.graph.node_map.get(output_id) + istack = [] + while not self.graph._is_root(inode): + istack.append(inode) + inode = inode.upnode + ostack = [] + while not self.graph._is_root(onode): + ostack.append(onode) + onode = onode.upnode + while istack or ostack: + if istack: + inode = istack.pop() + if ostack: + onode = ostack.pop() + if (inode.id == onode.id): + continue + BaseNode.add_data_flow(inode, onode) diff --git a/debug/visualization/builder/graph_tracer.py b/debug/visualization/builder/graph_tracer.py new file mode 100644 index 0000000000000000000000000000000000000000..2997abdc8d53563f284c1a98793c4ff76d3a63e5 --- /dev/null +++ b/debug/visualization/builder/graph_tracer.py @@ -0,0 +1,26 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from abc import ABC, abstractmethod + + +class GraphTracer(ABC): + def __init__(self): + pass + + @abstractmethod + def trace(self, net:torch.nn.Module, x:torch.Tensor, loss=None, target=None): + raise NotImplementedError('Function trace need to be implemented.') diff --git a/debug/visualization/graph/base_node.py b/debug/visualization/graph/base_node.py new file mode 100644 index 0000000000000000000000000000000000000000..098fcecd09f69ec95b6ce882fbc279451c5dbf04 --- /dev/null +++ b/debug/visualization/graph/base_node.py @@ -0,0 +1,73 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch.fx as fx + +from tool.id_manager import IdManager +from tool.dynamo_parser import DynamoParser +from tool.log import print_error_log +from graph.node_op import NodeOp + + +class BaseNode: + def __init__(self, node_op, node_type, up_node=None, is_forward=True): + self.op = node_op + self.type = node_type + self.id = IdManager.get_id(self.type, is_forward) + self.data = {} + self.outputs = [] + self.inputs = [] + self.upnode = up_node + self.subnodes = [] + if up_node: + up_node.add_subnode(self) + self.is_forward = is_forward + self.pair = None + + def __str__(self): + info = f'id:\t{self.id}' + return info + + def get_info(self): + info = f'{self.id}\t{self.op}' + if not self.is_forward: + info += '(b)' + for key in self.data: + info += f'\n{key}:\t{self.data.get(key)}' + return info + + def add_subnode(self, node): + if node.id == self.id: + return + self.subnodes.append(node) + + @staticmethod + def add_data_flow(inode, onode): + if not inode or not onode: + return + if inode.id == onode.id: + return + inode.outputs.append(onode) + onode.inputs.append(inode) + + @staticmethod + def add_direciton_pair(node1, node2): + if not node1 or not node2: + return + if node1.type != node2.type or node1.is_forward == node2.is_forward: + print_error_log("Error in add_direction_pair") + return + node1.pair = node2 + node2.pair = node1 diff --git a/debug/visualization/graph/graph.py b/debug/visualization/graph/graph.py new file mode 100644 index 0000000000000000000000000000000000000000..bf1a96b9b435c6295a7c95809705423eaadad3bc --- /dev/null +++ b/debug/visualization/graph/graph.py @@ -0,0 +1,54 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from graph.base_node import BaseNode +from tool.dynamo_parser import DynamoParser +from tool.log import print_error_log + + +class Graph: + def __init__(self): + self.root = None + self.recent_node = None + self.depth = 0 + self.node_map = {} + self.rawid_map = {} + + def update_info(self, info_map): + for key in info_map: + if key not in self.node_map: + print_error_log(f'{key} not found in node map') + continue + node = self.node_map.get(key) + for sub_key in info_map.get(key): + node.data[sub_key] = info_map.get(key).get(sub_key) + + def _is_root(self, node): + is_foward_root = node.id == self.root.id + is_backward_root = self.root.pair and node.id == self.root.pair.id + return is_foward_root or is_backward_root + + def get_output_nodes(self): + if 'output_0' not in self.node_map: + return [] + outputs = self.node_map.get('output_0').inputs + return outputs + + def __str__(self): + infos = [f'{str(self.node_map.get(node_id))}' for node_id in self.node_map] + info = "\n".join(infos) + return info diff --git a/debug/visualization/graph/node_op.py b/debug/visualization/graph/node_op.py new file mode 100644 index 0000000000000000000000000000000000000000..938664a60c17442b75d0fc5cd11da80784407337 --- /dev/null +++ b/debug/visualization/graph/node_op.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from enum import Enum + + +class NodeOp(Enum): + module = 1 + function = 2 + param = 3 + constant = 4 + default = 5 diff --git a/debug/visualization/test/test_net.py b/debug/visualization/test/test_net.py new file mode 100644 index 0000000000000000000000000000000000000000..9fb867f2b1462601ded9518c0a7695fba540b0d2 --- /dev/null +++ b/debug/visualization/test/test_net.py @@ -0,0 +1,119 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class LeNet(nn.Module): + def __init__(self): + super(LeNet, self).__init__() + self.conv1 = nn.Conv2d(3, 16, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.conv2 = nn.Conv2d(16, 32, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.fc1 = nn.Linear(32 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = F.relu(self.conv1(x)) + x = self.pool1(x) + x = F.relu(self.conv2(x)) + x = self.pool2(x) + x = x.view(-1, 32 * 5 * 5) + x = F.relu(self.fc1(x)) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class ConvRelu(nn.Module): + def __init__(self, x, y, z): + super(ConvRelu, self).__init__() + self.conv = nn.Conv2d(x, y, z) + + def forward(self, x): + x = self.conv(x) + x = F.relu(x) + return x + + +class FcRelu(nn.Module): + def __init__(self, x, y): + super(FcRelu, self).__init__() + self.fc = nn.Linear(x, y) + + def forward(self, x): + x = self.fc(x) + x = F.relu(x) + return x + + +class LeNet2(nn.Module): + def __init__(self): + super(LeNet2, self).__init__() + self.conv_relu1 = ConvRelu(3, 16, 5) + self.pool1 = nn.MaxPool2d(2, 2) + self.conv_relu2 = ConvRelu(16, 32, 5) + self.pool2 = nn.MaxPool2d(2, 2) + self.fc_relu = FcRelu(32 * 5 * 5, 120) + self.fc2 = nn.Linear(120, 84) + self.fc3 = nn.Linear(84, 10) + + def forward(self, x): + x = self.conv_relu1(x) + x = self.pool1(x) + x = self.conv_relu2(x) + x = self.pool2(x) + x = x.view(-1, 32 * 5 * 5) + x = self.fc_relu(x) + x = F.relu(self.fc2(x)) + x = self.fc3(x) + return x + + +class AddOne(nn.Module): + def __init__(self): + super(AddOne, self).__init__() + + def forward(self, x): + x = x + x + return x + + +class AddTwo(nn.Module): + def __init__(self): + super(AddTwo, self).__init__() + self.add_one1 = AddOne() + self.add_one2 = AddOne() + + def forward(self, x): + x = self.add_one1(x) + x = self.add_one2(x) + return x + + +class AddThree(nn.Module): + def __init__(self): + super(AddThree, self).__init__() + self.add_two = AddTwo() + self.add_one = AddOne() + + def forward(self, x): + x = self.add_two(x) + x = self.add_one(x) + return x diff --git a/debug/visualization/test/test_script.py b/debug/visualization/test/test_script.py new file mode 100644 index 0000000000000000000000000000000000000000..30ca84074c6c0eb4ee1d7178916eb0dc4b73d7f6 --- /dev/null +++ b/debug/visualization/test/test_script.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from test_net import LeNet2, AddOne, AddThree +from graph.graph import Graph +from builder.graph_builder import GraphBuilder +from tool.graph_viewer import GraphViewer + + +if __name__ == '__main__': + # net = LeNet2() + # x = torch.randn((3, 32, 32), requires_grad=True) + # target = torch.randn(1, 10) + # loss = torch.nn.CrossEntropyLoss() + + net = AddOne() + x = torch.randn((4, 4), requires_grad=True) + target = torch.randn(4, 4) + loss = torch.nn.CrossEntropyLoss() + + graph = GraphBuilder().build(net, x, loss, target) + GraphViewer.save_full_level_pdf(graph, './output/') + GraphViewer.save_tree(graph, './output/tree') diff --git a/debug/visualization/tool/dynamo_parser.py b/debug/visualization/tool/dynamo_parser.py new file mode 100644 index 0000000000000000000000000000000000000000..2835e70a791bf316cce0f019d637be24b2659123 --- /dev/null +++ b/debug/visualization/tool/dynamo_parser.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.fx as fx +from torch.export import export + +from graph.node_op import NodeOp + + +class DynamoParser: + @staticmethod + def get_type(fx_node:fx.Node): + if fx_node.op == "placeholder": + return str(fx_node).lower() + elif fx_node.op == "call_function": + splits = str(fx_node).split('_') + if len(splits) >= 2 and splits[-1].isdigit(): + result = "_".join(splits[:-1]) + return result.lower() + else: + return str(fx_node).lower() + elif fx_node.op == "output": + return "output" + else: + return "default" + + @staticmethod + def get_id(fx_node:fx.Node): + fx_name = str(fx_node).lower() + if "arg" in fx_name: + return f'{fx_name}_0' + splits = fx_name.split('_') + if len(splits) >= 2 and splits[-1].isdigit(): + return fx_name + else: + return f'{fx_name}_0' + + @staticmethod + def get_op(fx_node:fx.Node): + str_map = {"placeholder":NodeOp.param, "call_function":NodeOp.function, "output": NodeOp.default} + return str_map.get(fx_node.op, NodeOp.default) + + @staticmethod + def get_stack_info(fx_node:fx.Node): + if not fx_node.meta: + return [] + return fx_node.meta.get("nn_module_stack", []) + + @staticmethod + def get_fx_graph_by_export(net:torch.nn.Module, x:torch.Tensor): + export_program = export(net, args=(x,)) + fx_graph = export_program.graph + return fx_graph diff --git a/debug/visualization/tool/graph_viewer.py b/debug/visualization/tool/graph_viewer.py new file mode 100644 index 0000000000000000000000000000000000000000..80a7545e0936d8798a79c3240106bc28d25fe17a --- /dev/null +++ b/debug/visualization/tool/graph_viewer.py @@ -0,0 +1,80 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import queue +from graphviz import Digraph + +from graph.graph import Graph + + +class GraphViewer(): + @staticmethod + def save_pdf(graph, level:int, output_file:str): + dot = Digraph() + dot.attr(rankdir='TB') + node_queue = queue.Queue() + node_queue.put(graph.root) + if graph.root.pair: + node_queue.put(graph.root.pair) + display_ids = set() + for _ in range(level + 1): + GraphViewer._search(display_ids, node_queue) + while not node_queue.empty(): + node = node_queue.get() + info = node.get_info() + dot.node(node.id, info, shape='rectangle') + for input_node in node.inputs: + if input_node.id not in display_ids: + continue + dot.edge(input_node.id, node.id) + dot.render(output_file) + + @staticmethod + def save_full_level_pdf(graph:Graph, output_path): + level = graph.depth - 1 + for i in range(level): + GraphViewer.save_pdf(graph, i, f'{output_path}/level_{i}') + + @staticmethod + def save_tree(graph, output_file): + dot = Digraph() + dot.attr(rankdir='TB') + for node_id in graph.node_map: + node = graph.node_map.get(node_id) + info = node.get_info() + dot.node(node.id, info, shape='rectangle') + for subnode in node.subnodes: + dot.edge(node.id, subnode.id, style='dashed') + for input_node in node.inputs: + dot.edge(input_node.id, node.id) + dot.render(output_file) + + @staticmethod + def _search(display_ids, node_queue): + display_ids.clear() + queue_size = node_queue.qsize() + for _ in range(queue_size): + node = node_queue.get() + if len(node.subnodes) == 0: + node_queue.put(node) + display_ids.add(node.id) + else: + GraphViewer._add_subnodes(display_ids, node_queue, node) + + @staticmethod + def _add_subnodes(display_ids, node_queue, node): + for subnode in node.subnodes: + node_queue.put(subnode) + display_ids.add(subnode.id) diff --git a/debug/visualization/tool/id_manager.py b/debug/visualization/tool/id_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..efe6ed662dce8e841eb4f7888fc7323a08e783be --- /dev/null +++ b/debug/visualization/tool/id_manager.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from tool.log import print_error_log + + +class IdManager: + type_count = {} + forward_end = False + forward_count = {} + + @classmethod + def get_id(cls, type_name:str, is_forward=True): + if not is_forward and not cls.forward_end: + cls.start_backward() + if type_name not in cls.type_count: + cls.type_count[type_name] = -1 + cls.type_count[type_name] += 1 + return f'{type_name}_{cls.type_count.get(type_name)}' + + @classmethod + def get_next_id(cls, type_name:str): + if type_name in cls.type_count: + next_cnt = cls.type_count.get(type_name) + 1 + else: + next_cnt = 0 + return f'{type_name}_{next_cnt}' + + @classmethod + def start_backward(cls): + if cls.forward_end: + return + cls.forward_end = True + for key in cls.type_count: + cls.forward_count[key] = cls.type_count.get(key) + + @classmethod + def find_pair_id(cls, node_id): + splits = node_id.split('_') + if len(splits) < 2: + print_error_log('error in find_forward_id') + return '' + type_count = int(splits[-1]) + node_type = "_".join(splits[:-1]) + if node_type not in cls.forward_count: + print_error_log('type not found in forward_count') + return '' + base_cnt = cls.forward_count.get(node_type) + if type_count >= 2* (base_cnt + 1) or type_count < 0: + print_error_log('cnt is too big or small') + return '' + return f'{node_type}_{2 * base_cnt + 1 - type_count}' + + \ No newline at end of file diff --git a/debug/visualization/tool/log.py b/debug/visualization/tool/log.py new file mode 100644 index 0000000000000000000000000000000000000000..fe21c997ec807a35c7ad997eec105e113af9ce88 --- /dev/null +++ b/debug/visualization/tool/log.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import sys + + +def _print_log(level, msg, end='\n'): + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) + pid = os.getgid() + print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) + sys.stdout.flush() + + +def print_info_log(info_msg, end='\n'): + """ + Function Description: + print info log. + Parameter: + info_msg: the info message. + """ + _print_log("INFO", info_msg, end=end) + + +def print_error_log(error_msg): + """ + Function Description: + print error log. + Parameter: + error_msg: the error message. + """ + _print_log("ERROR", error_msg) + + +def print_warn_log(warn_msg): + """ + Function Description: + print warn log. + Parameter: + warn_msg: the warning message. + """ + _print_log("WARNING", warn_msg)