From cc6cd394951d32ac5b6831b7a6b44388b1230be9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E6=99=93=E6=B3=A2?= Date: Wed, 27 Mar 2024 21:06:27 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=86=E7=BA=A7=E5=8F=AF=E8=A7=86=E5=8C=96ut?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/visualization/data/ptdbg_parser.py | 4 +- .../{test => example}/test_net.py | 29 +---- .../{test => example}/test_script.py | 12 +- debug/visualization/graph/base_node.py | 2 +- debug/visualization/graph/edge_manager.py | 2 +- debug/visualization/test/resources/test.pkl | 9 ++ debug/visualization/test/run_test.sh | 36 ++++++ debug/visualization/test/run_ut.py | 58 +++++++++ debug/visualization/test/ut/test_base_node.py | 67 ++++++++++ .../test/ut/test_edge_manager.py | 61 +++++++++ .../test/ut/test_graph_builder.py | 64 ++++++++++ .../visualization/test/ut/test_id_manager.py | 58 +++++++++ .../test/ut/test_ptdbg_parser.py | 53 ++++++++ debug/visualization/tool/file_check_util.py | 117 ------------------ debug/visualization/tool/graph_viewer.py | 2 +- debug/visualization/tool/id_manager.py | 6 + 16 files changed, 422 insertions(+), 158 deletions(-) rename debug/visualization/{test => example}/test_net.py (78%) rename debug/visualization/{test => example}/test_script.py (70%) create mode 100644 debug/visualization/test/resources/test.pkl create mode 100644 debug/visualization/test/run_test.sh create mode 100644 debug/visualization/test/run_ut.py create mode 100644 debug/visualization/test/ut/test_base_node.py create mode 100644 debug/visualization/test/ut/test_edge_manager.py create mode 100644 debug/visualization/test/ut/test_graph_builder.py create mode 100644 debug/visualization/test/ut/test_id_manager.py create mode 100644 debug/visualization/test/ut/test_ptdbg_parser.py diff --git a/debug/visualization/data/ptdbg_parser.py b/debug/visualization/data/ptdbg_parser.py index 8a39157f1..ca7057af4 100644 --- a/debug/visualization/data/ptdbg_parser.py +++ b/debug/visualization/data/ptdbg_parser.py @@ -138,8 +138,8 @@ class PtdbgCell: name = "_".join(words[1:-3]) # 合并Function/Tensor以后,编号以前的所有名字,因为可能会有下划线 return [words[0]] + [name] + words[-3:] - @classmethod - def _name_map(cls, api_name, api_cnt): + @staticmethod + def _name_map(api_name, api_cnt): # 这里会删除apiname的下划线,从而匹配的更多 api_name.replace('_', '') return f'{api_name}_{api_cnt}' diff --git a/debug/visualization/test/test_net.py b/debug/visualization/example/test_net.py similarity index 78% rename from debug/visualization/test/test_net.py rename to debug/visualization/example/test_net.py index 9fb867f2b..f699155de 100644 --- a/debug/visualization/test/test_net.py +++ b/debug/visualization/example/test_net.py @@ -16,30 +16,7 @@ import torch import torch.nn as nn import torch.nn.functional as F - - -class LeNet(nn.Module): - def __init__(self): - super(LeNet, self).__init__() - self.conv1 = nn.Conv2d(3, 16, 5) - self.pool2 = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(16, 32, 5) - self.pool2 = nn.MaxPool2d(2, 2) - self.fc1 = nn.Linear(32 * 5 * 5, 120) - self.fc2 = nn.Linear(120, 84) - self.fc3 = nn.Linear(84, 10) - - def forward(self, x): - x = F.relu(self.conv1(x)) - x = self.pool1(x) - x = F.relu(self.conv2(x)) - x = self.pool2(x) - x = x.view(-1, 32 * 5 * 5) - x = F.relu(self.fc1(x)) - x = F.relu(self.fc2(x)) - x = self.fc3(x) - return x - + class ConvRelu(nn.Module): def __init__(self, x, y, z): @@ -63,9 +40,9 @@ class FcRelu(nn.Module): return x -class LeNet2(nn.Module): +class LeNet(nn.Module): def __init__(self): - super(LeNet2, self).__init__() + super(LeNet, self).__init__() self.conv_relu1 = ConvRelu(3, 16, 5) self.pool1 = nn.MaxPool2d(2, 2) self.conv_relu2 = ConvRelu(16, 32, 5) diff --git a/debug/visualization/test/test_script.py b/debug/visualization/example/test_script.py similarity index 70% rename from debug/visualization/test/test_script.py rename to debug/visualization/example/test_script.py index b5918367b..f9ad5775a 100644 --- a/debug/visualization/test/test_script.py +++ b/debug/visualization/example/test_script.py @@ -15,7 +15,7 @@ import torch -from test_net import LeNet2, AddOne, AddThree +from example.test_net import AddThree from graph.graph import Graph from builder.graph_builder import GraphBuilder from tool.graph_viewer import GraphViewer @@ -25,16 +25,8 @@ from data.ptdbg_parser import PtdbgParser if __name__ == '__main__': net = AddThree() x = torch.randn((4, 4), requires_grad=True) - - # 普通构建 - # graph = GraphBuilder().build(net, x, need_backward=False) - - # 基于ptdbg数据构建 - # PtdbgParser().run(net, x, './output/') - pkl_path = './output/ptdbg_dump_v4.0.T1/rank779468/api_stack_dump.pkl' - graph = GraphBuilder().build(net, x, need_backward=False, data_path=pkl_path) + graph = GraphBuilder().build(net, x, need_backward=False) GraphBuilder.export_to_yaml('./output/export.yaml', graph) GraphViewer.save_full_level_pdf(graph, './output/') GraphViewer.save_tree(graph, './output/tree') - # data = GraphBuilder.build_dict_from_yaml('./output/export.yaml') diff --git a/debug/visualization/graph/base_node.py b/debug/visualization/graph/base_node.py index fdcd97ce2..a84b6c10a 100644 --- a/debug/visualization/graph/base_node.py +++ b/debug/visualization/graph/base_node.py @@ -83,7 +83,7 @@ class BaseNode: return edge_id @staticmethod - def add_direciton_pair(node1, node2): + def add_direction_pair(node1, node2): if not node1 or not node2: return if node1.type != node2.type or node1.is_forward == node2.is_forward: diff --git a/debug/visualization/graph/edge_manager.py b/debug/visualization/graph/edge_manager.py index 7cc58005d..0fd59dc1a 100644 --- a/debug/visualization/graph/edge_manager.py +++ b/debug/visualization/graph/edge_manager.py @@ -40,7 +40,7 @@ class EdgeManager: if edge_id >= len(cls.data_list): print_error_log("Index out of range in EdgeManger") return {} - return str(cls.data_list[edge_id]) + return cls.data_list[edge_id] @classmethod def set_edge(cls, edge_id, data): diff --git a/debug/visualization/test/resources/test.pkl b/debug/visualization/test/resources/test.pkl new file mode 100644 index 000000000..6413115f4 --- /dev/null +++ b/debug/visualization/test/resources/test.pkl @@ -0,0 +1,9 @@ +["Tensor___add___0_forward_input.0", 0, [], "torch.float32", [4,4], [0.88, -1.40, -0.25, 2.55]] +["Tensor___add___0_forward_input.1", 0, [], "torch.float32", [4,4], [0.88, -1.40, -0.25, 2.55]] +["Tensor___add___0_forward_output", 0, [], "torch.float32", [4,4], [0.76, -2.80, -0.50, 5.10]] +["Tensor___add___1_forward_input.0", 0, [], "torch.float32", [4,4], [0.76, -2.80, -0.50, 5.10]] +["Tensor___add___1_forward_input.1", 0, [], "torch.float32", [4,4], [0.76, -2.80, -0.50, 5.10]] +["Tensor___add___1_forward_output", 0, [], "torch.float32", [4,4], [-3.53, -5.60, -1.00, 10.20]] +["Tensor___add___2_forward_input.0", 0, [], "torch.float32", [4,4], [-3.53, -5.60, -1.00, 10.20]] +["Tensor___add___2_forward_input.1", 0, [], "torch.float32", [4,4], [-3.53, -5.60, -1.00, 10.20]] +["Tensor___add___2_forward_output", 0, [], "torch.float32", [4,4], [7.07, -11.21, -2.01, 20.40]] \ No newline at end of file diff --git a/debug/visualization/test/run_test.sh b/debug/visualization/test/run_test.sh new file mode 100644 index 000000000..b4275d983 --- /dev/null +++ b/debug/visualization/test/run_test.sh @@ -0,0 +1,36 @@ +#!/bin/bash +CUR_DIR=$(dirname $(readlink -f $0)) +TOP_DIR=${CUR_DIR}/.. +TEST_DIR=${TOP_DIR}/"test" +SRC_DIR=${TOP_DIR}/../ + +clean() { + cd ${TEST_DIR} + + if [ -e ${TEST_DIR}/"report" ]; then + rm -r ${TEST_DIR}/"report" + echo "remove last ut_report successfully." + fi + + if [ -e ${TEST_DIR}/"output" ]; then + rm -r ${TEST_DIR}/"output" + echo "remove last output successfully." + fi + +} + +run_ut() { + export PYTHONPATH=${SRC_DIR}:${PYTHONPATH} + python3 run_ut.py +} + +main() { + clean + if [ "$1"x == "clean"x ]; then + return 0 + fi + + cd ${TEST_DIR} && run_ut +} + +main $@ \ No newline at end of file diff --git a/debug/visualization/test/run_ut.py b/debug/visualization/test/run_ut.py new file mode 100644 index 000000000..3230d23ea --- /dev/null +++ b/debug/visualization/test/run_ut.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import shutil +import subprocess +import sys + +def run_ut(): + cur_dir = os.path.realpath(os.path.dirname(__file__)) + top_dir = os.path.realpath(os.path.dirname(cur_dir)) + ut_path = os.path.join(cur_dir, "ut/") + src_dir = top_dir + report_dir = os.path.join(cur_dir, "report") + + os.chdir(os.path.abspath('..')) + + if os.path.exists(report_dir): + shutil.rmtree(report_dir) + + os.makedirs(report_dir) + + cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + report_dir + "/final.xml", + "--cov=" + src_dir, "--cov-branch", "--cov-report=xml:" + report_dir + "/coverage.xml"] + + result_ut = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + + while result_ut.poll() is None: + line = result_ut.stdout.readline().strip() + if line: + print(line) + + ut_flag = False + if result_ut.returncode == 0: + ut_flag = True + print("run ut successfully.") + else: + print("run ut failed.") + + return ut_flag + +if __name__=="__main__": + if run_ut(): + sys.exit(0) + else: + sys.exit(1) \ No newline at end of file diff --git a/debug/visualization/test/ut/test_base_node.py b/debug/visualization/test/ut/test_base_node.py new file mode 100644 index 000000000..7d5ef9e81 --- /dev/null +++ b/debug/visualization/test/ut/test_base_node.py @@ -0,0 +1,67 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from graph.base_node import BaseNode +from graph.node_op import NodeOp +from tool.id_manager import IdManager + + +class TestBaseNodeCase(unittest.TestCase): + def setUp(self): + IdManager.clear() + self.test_node1 = BaseNode(NodeOp.module, "module_test") + self.test_node2 = BaseNode(NodeOp.function_api, "api_test") + self.test_node3 = BaseNode(NodeOp.function_api, "api_test", up_node=self.test_node1) + self.test_node4 = BaseNode(NodeOp.function_api, "api_test", up_node=None, is_forward=False) + + def test_init_id(self): + self.assertEqual(self.test_node2.id, "api_test_0") + self.assertEqual(self.test_node3.id, "api_test_1") + + def test_update_subnode(self): + self.assertEqual(len(self.test_node1.subnodes), 1) + node = self.test_node1.subnodes[0] + self.assertEqual(node.id, self.test_node3.id) + + def test_get_info(self): + info1 = self.test_node1.get_info() + self.assertIn("module", info1) + info2 = self.test_node4.get_info() + self.assertIn("(b)", info2) + + def test_add_data_flow(self): + edge_id = BaseNode.add_data_flow(None, None) + self.assertEqual(edge_id, -1) + edge_id = BaseNode.add_data_flow(self.test_node2, self.test_node2) + self.assertEqual(edge_id, -1) + edge_id = BaseNode.add_data_flow(self.test_node2, self.test_node3) + self.assertEqual(edge_id, 0) + edge_id = BaseNode.add_data_flow(self.test_node2, self.test_node3) + self.assertEqual(edge_id, 1) + edge_id = BaseNode.add_data_flow(self.test_node2, self.test_node3, edge_id) + self.assertEqual(edge_id, 1) + + def test_add_direction_pair(self): + BaseNode.add_direction_pair(None, None) + BaseNode.add_direction_pair(self.test_node2, self.test_node4) + self.assertNotEqual(self.test_node2.pair, None) + + def tearDown(self): + del self.test_node1 + del self.test_node2 + del self.test_node3 + del self.test_node4 diff --git a/debug/visualization/test/ut/test_edge_manager.py b/debug/visualization/test/ut/test_edge_manager.py new file mode 100644 index 000000000..feb89e4e0 --- /dev/null +++ b/debug/visualization/test/ut/test_edge_manager.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from graph.edge_manager import EdgeManager +from graph.base_node import BaseNode +from graph.node_op import NodeOp + + +class TestEdgeManagerCase(unittest.TestCase): + def setUp(self): + EdgeManager.clear() + self.id1 = EdgeManager.init_data_flow() + self.id2 = EdgeManager.init_data_flow() + EdgeManager.set_edge(self.id2, {"data1":1, "data2":2}) + + def test_init_data_flow(self): + self.assertEqual(self.id1, 0) + self.assertEqual(self.id2, 1) + + def test_get_edge(self): + data1 = EdgeManager.get_edge(self.id2) + self.assertEqual(len(data1), 2) + data2 = EdgeManager.get_edge(999) + self.assertEqual(len(data2), 0) + data3 = EdgeManager.get_edge(self.id1) + self.assertEqual(len(data3), 0) + + def test_get_edge_pairs(self): + node1_0 = BaseNode(NodeOp.function_api, "add") + node1_1 = BaseNode(NodeOp.function_api, "add", up_node=node1_0) + node2_0 = BaseNode(NodeOp.function_api, "add") + node2_1 = BaseNode(NodeOp.function_api, "add", up_node=node2_0) + node2_2 = BaseNode(NodeOp.function_api, "add", up_node=node2_1) + pairs = EdgeManager.get_edge_pairs(node1_1, node2_2) + self.assertEqual(len(pairs), 3) + # 顶层节点连线 + self.assertEqual(pairs[0][0].id, node1_0.id) + self.assertEqual(pairs[0][1].id, node2_0.id) + # 向下一层 + self.assertEqual(pairs[1][0].id, node1_1.id) + self.assertEqual(pairs[1][1].id, node2_1.id) + # 最底层,其中node1_1因为无法继续向下,因此暂停 + self.assertEqual(pairs[2][0].id, node1_1.id) + self.assertEqual(pairs[2][1].id, node2_2.id) + + def tearDown(self): + pass diff --git a/debug/visualization/test/ut/test_graph_builder.py b/debug/visualization/test/ut/test_graph_builder.py new file mode 100644 index 000000000..44dc4ca32 --- /dev/null +++ b/debug/visualization/test/ut/test_graph_builder.py @@ -0,0 +1,64 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import torch + +from example.test_net import AddThree, LeNet +from graph.graph import Graph +from builder.graph_builder import GraphBuilder +from tool.graph_viewer import GraphViewer +from data.ptdbg_parser import PtdbgParser + + +class TestGraphBuilderCase(unittest.TestCase): + def setUp(self): + pass + + def test_addthree_forward(self): + net = AddThree() + x = torch.randn((4, 4), requires_grad=True) + graph = GraphBuilder().build(net, x, need_backward=False) + if not os.path.exists('./output'): + os.mkdir('./output') + GraphBuilder.export_to_yaml('./output/export.yaml', graph) + GraphViewer.save_full_level_pdf(graph, './output/') + GraphViewer.save_tree(graph, './output/tree') + self.assertEqual(graph.depth, 4) + self.assertEqual(len(graph.node_map), 11) + + def test_addthree_backward(self): + net1 = LeNet() + net2 = AddThree() + x1 = torch.randn((3, 32, 32), requires_grad=False) + output1 = net1(x1) + output2 = net2(x1) + net = AddThree() + x = torch.rand((3, 32, 32), requires_grad=True) + graph = GraphBuilder().build(net, x, need_backward=True) + self.assertEqual(graph.depth, 4) + self.assertEqual(len(graph.node_map), 19) + + def test_ptdbg_parser(self): + net = AddThree() + x = torch.randn((4, 4), requires_grad=False) + pkl_path = './test/resources/test.pkl' + graph = GraphBuilder().build(net, x, need_backward=False, data_path=pkl_path) + self.assertEqual(graph.depth, 4) + self.assertEqual(len(graph.node_map), 11) + + def tearDown(self): + pass diff --git a/debug/visualization/test/ut/test_id_manager.py b/debug/visualization/test/ut/test_id_manager.py new file mode 100644 index 000000000..c036d63fb --- /dev/null +++ b/debug/visualization/test/ut/test_id_manager.py @@ -0,0 +1,58 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tool.id_manager import IdManager + + +class TestIdManagerCase(unittest.TestCase): + def setUp(self): + IdManager.clear() + + def test_get_id(self): + id1 = IdManager.get_id('add') + self.assertEqual(id1, 'add_0') + id2 = IdManager.get_id('mul') + self.assertEqual(id2, 'mul_0') + id3 = IdManager.get_id('add') + self.assertEqual(id3, 'add_1') + + def test_get_next_id(self): + id1 = IdManager.get_next_id('add') + self.assertEqual(id1, 'add_0') + id2 = IdManager.get_id('add') + self.assertEqual(id1, id2) + id3 = IdManager.get_id('add') + self.assertNotEqual(id2, id3) + id4 = IdManager.get_next_id('add') + id5 = IdManager.get_next_id('add') + self.assertEqual(id4, id5) + + def test_find_pair_id(self): + id1 = IdManager.get_id('add') + id2 = IdManager.get_id('add') + id3 = IdManager.get_id('add', is_forward=False) + pair_id1 = IdManager.find_pair_id(id2) + pair_id2 = IdManager.find_pair_id(id3) + self.assertEqual(pair_id1, id3) + self.assertEqual(pair_id2, id2) + pair_id3 = IdManager.find_pair_id('add_5') + self.assertEqual(pair_id3, '') + pair_id4 = IdManager.find_pair_id('12345') + self.assertEqual(pair_id4, '') + + def tearDown(self): + pass diff --git a/debug/visualization/test/ut/test_ptdbg_parser.py b/debug/visualization/test/ut/test_ptdbg_parser.py new file mode 100644 index 000000000..1b009555f --- /dev/null +++ b/debug/visualization/test/ut/test_ptdbg_parser.py @@ -0,0 +1,53 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from data.ptdbg_parser import PtdbgCell + + +class TestPtdbgCellCase(unittest.TestCase): + def setUp(self): + pass + + def test_equal(self): + data1 = [1, 2, 3, 4, 5] + data2 = [1, 2, 3, 4] + data3 = [1, 2, 3, 4, 6] + result1 = PtdbgCell._equal(data1, data2) + self.assertFalse(result1) + result2 = PtdbgCell._equal(data1, data3) + self.assertFalse(result2) + result3 = PtdbgCell._equal(data1, data1) + self.assertTrue(result3) + + def test_name_map(self): + name1 = PtdbgCell._name_map('add', 0) + self.assertEqual(name1, 'add_0') + name2 = PtdbgCell._name_map('add_test', 0) + self.assertEqual(name2, 'add_test_0') + name3 = PtdbgCell._name_map('conv_relu', 1) + self.assertEqual(name3, 'conv_relu_1') + + def test_process_name(self): + result1 = PtdbgCell._process_name('Functional_relu_0_forward_output') + self.assertEqual(result1, ['Functional', 'relu', '0', 'forward', 'output']) + result2 = PtdbgCell._process_name('Functional_relu_test_123_0_forward_output') + self.assertEqual(result2, ['Functional', 'relu_test_123', '0', 'forward', 'output']) + result3 = PtdbgCell._process_name('Functional___relu___0_forward_output') + self.assertEqual(result3, ['Functional', 'relu', '0', 'forward', 'output']) + + def tearDown(self): + pass diff --git a/debug/visualization/tool/file_check_util.py b/debug/visualization/tool/file_check_util.py index 34aa8dc4d..925f0b12e 100644 --- a/debug/visualization/tool/file_check_util.py +++ b/debug/visualization/tool/file_check_util.py @@ -24,9 +24,6 @@ class FileCheckConst: """ Class for file check const """ - READ_ABLE = "read" - WRITE_ABLE = "write" - READ_WRITE_ABLE = "read and write" DIRECTORY_LENGTH = 4096 FILE_NAME_LENGTH = 255 FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" @@ -75,58 +72,6 @@ class FileCheckException(Exception): return self.error_info -class FileChecker: - """ - The class for check file. - - Attributes: - file_path: The file or dictionary path to be verified. - path_type: file or dictionary - ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability - file_type(str): The correct file type for file - """ - def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): - self.file_path = file_path - self.path_type = self._check_path_type(path_type) - self.ability = ability - self.file_type = file_type - self.is_script = is_script - - @staticmethod - def _check_path_type(path_type): - if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - print_error_log(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') - raise FileCheckException(FileCheckException.INVALID_PARAM_ERROR) - return path_type - - def common_check(self): - """ - 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 - 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 - """ - check_path_exists(self.file_path) - check_link(self.file_path) - self.file_path = os.path.realpath(self.file_path) - check_path_length(self.file_path) - check_path_type(self.file_path, self.path_type) - self.check_path_ability() - if self.is_script: - check_path_owner_consistent(self.file_path) - check_path_pattern_vaild(self.file_path) - check_common_file_size(self.file_path) - check_file_suffix(self.file_path, self.file_type) - return self.file_path - - def check_path_ability(self): - if self.ability == FileCheckConst.WRITE_ABLE: - check_path_writability(self.file_path) - if self.ability == FileCheckConst.READ_ABLE: - check_path_readability(self.file_path) - if self.ability == FileCheckConst.READ_WRITE_ABLE: - check_path_readability(self.file_path) - check_path_writability(self.file_path) - - class FileOpen: """ The class for open file by a safe way. @@ -212,20 +157,6 @@ def check_path_writability(path): raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) -def check_path_executable(path): - if not os.access(path, os.X_OK): - print_error_log('The file path %s is not executable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) - - -def check_other_user_writable(path): - st = os.stat(path) - if st.st_mode & 0o002: - _user_interactive_confirm( - 'The file path %s may be insecure because other users have write permissions. ' - 'Do you want to continue?' % path) - - def _user_interactive_confirm(message): while True: check_message = input(message + " Enter 'c' to continue or enter 'e' to exit: ") @@ -264,51 +195,3 @@ def check_common_file_size(file_path): if file_path.endswith(suffix): check_file_size(file_path, max_size) break - - -def check_file_suffix(file_path, file_suffix): - if file_suffix: - if not file_path.endswith(file_suffix): - print_error_log(f"The {file_path} should be a {file_suffix} file!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) - - -def check_path_type(file_path, file_type): - if file_type == FileCheckConst.FILE: - if not os.path.isfile(file_path): - print_error_log(f"The {file_path} should be a file!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) - if file_type == FileCheckConst.DIR: - if not os.path.isdir(file_path): - print_error_log(f"The {file_path} should be a dictionary!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) - - -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - dir_path = os.path.realpath(dir_path) - if not os.path.exists(dir_path): - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) - except OSError as ex: - print_error_log( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) from ex - - -def change_mode(path, mode): - if not os.path.exists(path) or os.path.islink(path): - return - try: - os.chmod(path, mode) - except PermissionError as ex: - print_error_log('Failed to change {} authority. {}'.format(path, str(ex))) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) from ex - diff --git a/debug/visualization/tool/graph_viewer.py b/debug/visualization/tool/graph_viewer.py index ef3e6d8bd..e83eec92a 100644 --- a/debug/visualization/tool/graph_viewer.py +++ b/debug/visualization/tool/graph_viewer.py @@ -43,7 +43,7 @@ class GraphViewer(): if node.id in connected.get(input_node.id, set()): continue edge_info = EdgeManager.get_edge(edge_id) - dot.edge(input_node.id, node.id, edge_info) + dot.edge(input_node.id, node.id, str(edge_info)) if input_node.id not in connected: connected[input_node.id] = set() connected.get(input_node.id).add(node.id) diff --git a/debug/visualization/tool/id_manager.py b/debug/visualization/tool/id_manager.py index 3d4fe2603..528020f83 100644 --- a/debug/visualization/tool/id_manager.py +++ b/debug/visualization/tool/id_manager.py @@ -62,3 +62,9 @@ class IdManager: print_error_log('cnt is too big or small') return '' return f'{node_type}_{2 * base_cnt + 1 - type_count}' + + @classmethod + def clear(cls): + cls.type_count.clear() + cls.forward_count.clear() + cls.forward_end = False -- Gitee