From 7b2b30131245df6b12685b0e3dba1991dd34ee84 Mon Sep 17 00:00:00 2001 From: l30036321 Date: Sat, 16 Dec 2023 13:41:05 +0800 Subject: [PATCH] fix security issues --- .../ptdbg_ascend/common/file_check_util.py | 15 + .../parse_tool/lib/parse_exception.py | 107 ++-- .../ptdbg_ascend/parse_tool/lib/parse_tool.py | 280 +++++------ .../ptdbg_ascend/parse_tool/lib/utils.py | 475 +++++++++--------- .../parse_tool/lib/visualization.py | 174 +++---- 5 files changed, 540 insertions(+), 511 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py index 2323124f53b..ac46062f0fb 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py @@ -210,6 +210,21 @@ def check_path_writability(path): raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) +def check_path_executable(path): + if not os.access(path, os.X_OK): + print_error_log('The file path %s is not executable.' % path) + raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) + + +def check_other_user_writable(path): + st = os.stat(path) + if st.st_mode & 0o002: + _user_interactive_confirm( + 'The file path %s may be insecure because other users have write permissions. ' + 'Do you want to continue?' 
% path + ) + + def _user_interactive_confirm(message): while True: check_message = input(message + " Enter 'c' to continue or enter 'e' to exit: ") diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_exception.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_exception.py index 380d84cb2c5..45afa5f1617 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_exception.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_exception.py @@ -1,52 +1,55 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -import logging - - -class ParseException(Exception): - - PARSE_INVALID_PATH_ERROR = 0 - PARSE_NO_FILE_ERROR = 1 - PARSE_NO_MODULE_ERROR = 2 - PARSE_INVALID_DATA_ERROR = 3 - PARSE_INVALID_FILE_FORMAT_ERROR = 4 - PARSE_UNICODE_ERROR = 5 - PARSE_JSONDECODE_ERROR = 6 - PARSE_MSACCUCMP_ERROR = 7 - PARSE_LOAD_NPY_ERROR = 8 - - def __init__(self, code, error_info=""): - super(ParseException, self).__init__() - self.error_info = error_info - self.code = code - - -def catch_exception(func): - def inner(*args, **kwargs): - log = logging.getLogger() - line = args[-1] if len(args) == 2 else "" - result = None - try: - result = func(*args, **kwargs) - except OSError: - log.error("%s: command not found" % line) - except ParseException: - log.error("Command execution failed") - except SystemExit: - log.warning("Please enter the correct command") - return result - return inner +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import logging +from ...common.file_check_util import FileCheckException + + +class ParseException(Exception): + + PARSE_INVALID_PATH_ERROR = 0 + PARSE_NO_FILE_ERROR = 1 + PARSE_NO_MODULE_ERROR = 2 + PARSE_INVALID_DATA_ERROR = 3 + PARSE_INVALID_FILE_FORMAT_ERROR = 4 + PARSE_UNICODE_ERROR = 5 + PARSE_JSONDECODE_ERROR = 6 + PARSE_MSACCUCMP_ERROR = 7 + PARSE_LOAD_NPY_ERROR = 8 + + def __init__(self, code, error_info=""): + super(ParseException, self).__init__() + self.error_info = error_info + self.code = code + + +def catch_exception(func): + def inner(*args, **kwargs): + log = logging.getLogger() + line = args[-1] if len(args) == 2 else "" + result = None + try: + result = func(*args, **kwargs) + except OSError: + log.error("%s: command not found" % line) + except ParseException: + log.error("Command execution failed") + except FileCheckException: + log.error("Command execution failed") + except SystemExit: + log.warning("Please enter the correct command") + return result + return inner diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_tool.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_tool.py index b0b56100700..e6e2927c1d3 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_tool.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/parse_tool.py @@ -1,140 +1,140 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -import argparse -import os - -from .config import Const -from .utils import Util -from .compare import Compare -from .visualization import Visualization -from .parse_exception import catch_exception, ParseException - - -class ParseTool: - def __init__(self): - self.util = Util() - self.compare = Compare() - self.visual = Visualization() - - @catch_exception - def prepare(self): - self.util.create_dir(Const.DATA_ROOT_DIR) - - @catch_exception - def do_vector_compare(self, argv=None): - parser = argparse.ArgumentParser() - parser.add_argument( - "-m", "--my_dump_path", dest="my_dump_path", default=None, - help=" my dump path, the data compared with golden data", - required=True - ) - parser.add_argument( - "-g", "--golden_dump_path", dest="golden_dump_path", default=None, - help=" the golden dump data path", - required=True - ) - parser.add_argument( - "-out", "--output_path", dest="output_path", default=None, - help=" the output path", - required=False - ) - parser.add_argument( - "-asc", "--ascend_path", dest="ascend_path", default=None, - help=" the Ascend home path", - required=False - ) - args = parser.parse_args(argv) - if not args.output_path: - result_dir = os.path.join(Const.COMPARE_DIR) - else: - result_dir = args.output_path - my_dump_path = args.my_dump_path - golden_dump_path = args.golden_dump_path - self.util.check_path_valid(my_dump_path) - self.util.check_path_valid(golden_dump_path) - self.util.check_files_in_path(my_dump_path) - self.util.check_files_in_path(golden_dump_path) - if not os.path.isdir(my_dump_path) or not 
os.path.isdir(golden_dump_path): - self.util.log.error("Please enter a directory not a file") - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if args.ascend_path: - Const.MS_ACCU_CMP_PATH = self.util.path_strip(args.ascend_path) - self.util.check_path_valid(Const.MS_ACCU_CMP_PATH) - self.compare.npu_vs_npu_compare(my_dump_path, golden_dump_path, result_dir) - - @catch_exception - def do_convert_dump(self, argv=None): - parser = argparse.ArgumentParser() - parser.add_argument( - '-n', '--name', dest='path', default=None, required=True, help='dump file or dump file directory') - parser.add_argument( - '-f', '--format', dest='format', default=None, required=False, help='target format') - parser.add_argument( - '-out', '--output_path', dest='output_path', required=False, default=None, help='output path') - parser.add_argument( - "-asc", "--ascend_path", dest="ascend_path", default=None, help=" the Ascend home path", - required=False) - args = parser.parse_args(argv) - self.util.check_path_valid(args.path) - self.util.check_files_in_path(args.path) - if args.ascend_path: - Const.MS_ACCU_CMP_PATH = self.util.path_strip(args.ascend_path) - self.util.check_path_valid(Const.MS_ACCU_CMP_PATH) - self.compare.convert_dump_to_npy(args.path, args.format, args.output_path) - - @catch_exception - def do_print_data(self, argv=None): - """print tensor data""" - parser = argparse.ArgumentParser() - parser.add_argument('-n', '--name', dest='path', default=None, required=True, help='File name') - args = parser.parse_args(argv) - self.visual.print_npy_data(args.path) - - @catch_exception - def do_parse_pkl(self, argv=None): - parser = argparse.ArgumentParser() - parser.add_argument( - '-f', '--file', dest='file_name', default=None, required=True, help='PKL file path') - parser.add_argument( - '-n', '--name', dest='api_name', default=None, required=True, help='API name') - args = parser.parse_args(argv) - self.visual.parse_pkl(args.file_name, args.api_name) - - 
@catch_exception - def do_compare_data(self, argv): - """compare two tensor""" - parser = argparse.ArgumentParser() - parser.add_argument( - "-m", "--my_dump_path", dest="my_dump_path", default=None, - help=" my dump path, the data compared with golden data", - required=True - ) - parser.add_argument( - "-g", "--golden_dump_path", dest="golden_dump_path", default=None, - help=" the golden dump data path", - required=True - ) - parser.add_argument('-p', '--print', dest='count', default=20, type=int, help='print err data num') - parser.add_argument('-s', '--save', dest='save', action='store_true', help='save data in txt format') - parser.add_argument('-al', '--atol', dest='atol', default=0.001, type=float, help='set rtol') - parser.add_argument('-rl', '--rtol', dest='rtol', default=0.001, type=float, help='set atol') - args = parser.parse_args(argv) - self.util.check_path_valid(args.my_dump_path) - self.util.check_path_valid(args.golden_dump_path) - self.util.check_path_format(args.my_dump_path, Const.NPY_SUFFIX) - self.util.check_path_format(args.golden_dump_path, Const.NPY_SUFFIX) - self.compare.compare_data(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import argparse +import os + +from .config import Const +from .utils import Util +from .compare import Compare +from .visualization import Visualization +from .parse_exception import catch_exception, ParseException + + +class ParseTool: + def __init__(self): + self.util = Util() + self.compare = Compare() + self.visual = Visualization() + + @catch_exception + def prepare(self): + self.util.create_dir(Const.DATA_ROOT_DIR) + + @catch_exception + def do_vector_compare(self, argv=None): + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", "--my_dump_path", dest="my_dump_path", default=None, + help=" my dump path, the data compared with golden data", + required=True + ) + parser.add_argument( + "-g", "--golden_dump_path", dest="golden_dump_path", default=None, + help=" the golden dump data path", + required=True + ) + parser.add_argument( + "-out", "--output_path", dest="output_path", default=None, + help=" the output path", + required=False + ) + parser.add_argument( + "-cmp_path", "--msaccucmp_path", dest="msaccucmp_path", default=None, + help=" the msaccucmp.py file path", + required=False + ) + args = parser.parse_args(argv) + if not args.output_path: + result_dir = os.path.join(Const.COMPARE_DIR) + else: + result_dir = args.output_path + my_dump_path = args.my_dump_path + golden_dump_path = args.golden_dump_path + self.util.check_path_valid(my_dump_path) + self.util.check_path_valid(golden_dump_path) + self.util.check_files_in_path(my_dump_path) + self.util.check_files_in_path(golden_dump_path) + if not os.path.isdir(my_dump_path) or not os.path.isdir(golden_dump_path): + self.util.log.error("Please enter a directory not a file") + raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) + msaccucmp_path = self.util.path_strip(args.msaccucmp_path) if args.msaccucmp_path else Const.MS_ACCU_CMP_PATH + self.util.check_path_valid(msaccucmp_path) + self.util.check_executable_file(msaccucmp_path) + self.compare.npu_vs_npu_compare(my_dump_path, 
golden_dump_path, result_dir) + + @catch_exception + def do_convert_dump(self, argv=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '-n', '--name', dest='path', default=None, required=True, help='dump file or dump file directory') + parser.add_argument( + '-f', '--format', dest='format', default=None, required=False, help='target format') + parser.add_argument( + '-out', '--output_path', dest='output_path', required=False, default=None, help='output path') + parser.add_argument( + "-cmp_path", "--msaccucmp_path", dest="msaccucmp_path", default=None, + help=" the msaccucmp.py file path",required=False) + args = parser.parse_args(argv) + self.util.check_path_valid(args.path) + self.util.check_files_in_path(args.path) + msaccucmp_path = self.util.path_strip(args.msaccucmp_path) if args.msaccucmp_path else Const.MS_ACCU_CMP_PATH + self.util.check_path_valid(msaccucmp_path) + self.util.check_executable_file(msaccucmp_path) + self.compare.convert_dump_to_npy(args.path, args.format, args.output_path) + + @catch_exception + def do_print_data(self, argv=None): + """print tensor data""" + parser = argparse.ArgumentParser() + parser.add_argument('-n', '--name', dest='path', default=None, required=True, help='File name') + args = parser.parse_args(argv) + self.visual.print_npy_data(args.path) + + @catch_exception + def do_parse_pkl(self, argv=None): + parser = argparse.ArgumentParser() + parser.add_argument( + '-f', '--file', dest='file_name', default=None, required=True, help='PKL file path') + parser.add_argument( + '-n', '--name', dest='api_name', default=None, required=True, help='API name') + args = parser.parse_args(argv) + self.visual.parse_pkl(args.file_name, args.api_name) + + @catch_exception + def do_compare_data(self, argv): + """compare two tensor""" + parser = argparse.ArgumentParser() + parser.add_argument( + "-m", "--my_dump_path", dest="my_dump_path", default=None, + help=" my dump path, the data compared with golden data", + required=True 
+ ) + parser.add_argument( + "-g", "--golden_dump_path", dest="golden_dump_path", default=None, + help=" the golden dump data path", + required=True + ) + parser.add_argument('-p', '--print', dest='count', default=20, type=int, help='print err data num') + parser.add_argument('-s', '--save', dest='save', action='store_true', help='save data in txt format') + parser.add_argument('-al', '--atol', dest='atol', default=0.001, type=float, help='set rtol') + parser.add_argument('-rl', '--rtol', dest='rtol', default=0.001, type=float, help='set atol') + args = parser.parse_args(argv) + self.util.check_path_valid(args.my_dump_path) + self.util.check_path_valid(args.golden_dump_path) + self.util.check_path_format(args.my_dump_path, Const.NPY_SUFFIX) + self.util.check_path_format(args.golden_dump_path, Const.NPY_SUFFIX) + self.compare.compare_data(args.my_dump_path, args.golden_dump_path, args.save, args.rtol, args.atol, args.count) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py index 20c5f6c7499..6793b0e3184 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py @@ -1,232 +1,243 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-# See the License for the specific language governing permissions and -# limitations under the License. -""" -import logging -import os -import re -import sys -import subprocess -import numpy as np -from .config import Const -from .file_desc import DumpDecodeFileDesc, FileDesc -from .parse_exception import ParseException - -try: - from rich.traceback import install - from rich.panel import Panel - from rich.table import Table - from rich import print as rich_print - from rich.columns import Columns - install() -except ImportError as err: - install = None - Panel = None - Table = None - Columns = None - rich_print = None - print("[Warning] Failed to import rich, Some features may not be available. Please run 'pip install rich' to fix it.") - - -class Util: - def __init__(self): - self.ms_accu_cmp = None - logging.basicConfig( - level=Const.LOG_LEVEL, - format="%(asctime)s (%(process)d) -[%(levelname)s]%(message)s", - datefmt="%Y-%m-%d %H:%M:%S" - ) - self.log = logging.getLogger() - self.python = sys.executable - - @staticmethod - def print(content): - rich_print(content) - - @staticmethod - def path_strip(path): - return path.strip("'").strip('"') - - @staticmethod - def _gen_npu_dump_convert_file_info(name, match, dir_path): - return DumpDecodeFileDesc(name, dir_path, int(match.groups()[-4]), op_name=match.group(2), - op_type=match.group(1), task_id=int(match.group(3)), anchor_type=match.groups()[-3], - anchor_idx=int(match.groups()[-2])) - - @staticmethod - def _gen_numpy_file_info(name, math, dir_path): - return FileDesc(name, dir_path) - - def execute_command(self, cmd): - if not cmd: - self.log.error("Commond is None") - return -1 - self.log.debug("[RUN CMD]: %s", cmd) - cmd = cmd.split(" ") - complete_process = subprocess.run(cmd, shell=False) - return complete_process.returncode - - def print_panel(self, content, title='', fit=True): - if not Panel: - print(content) - return - if fit: - self.print(Panel.fit(content, title=title)) - else: - 
self.print(Panel(content, title=title)) - - def check_msaccucmp(self, target_file): - self.log.info("Try to auto detect file with name: %s.", target_file) - result = subprocess.run( - [self.python, target_file, "--help"], stdout=subprocess.PIPE) - if result.returncode == 0: - self.log.info("Check [%s] success.", target_file) - else: - self.log.error("Check msaccucmp failed in dir %s" % target_file) - self.log.error("Please specify a valid msaccucmp.py path or install the cann package") - raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR) - return target_file - - def create_dir(self, path): - path = self.path_strip(path) - if os.path.exists(path): - return - self.check_path_name(path) - try: - os.makedirs(path, mode=0o750) - except OSError as e: - self.log.error("Failed to create %s.", path) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from e - - def gen_npy_info_txt(self, source_data): - shape, dtype, max_data, min_data, mean = \ - self.npy_info(source_data) - return \ - '[Shape: %s] [Dtype: %s] [Max: %s] [Min: %s] [Mean: %s]' % (shape, dtype, max_data, min_data, mean) - - def save_npy_to_txt(self, data, dst_file='', align=0): - if os.path.exists(dst_file): - self.log.info("Dst file %s exists, will not save new one.", dst_file) - return - shape = data.shape - data = data.flatten() - if align == 0: - align = 1 if len(shape) == 0 else shape[-1] - elif data.size % align != 0: - pad_array = np.zeros((align - data.size % align,)) - data = np.append(data, pad_array) - np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g') - - def list_convert_files(self, path, external_pattern=""): - return self._list_file_with_pattern( - path, Const.OFFLINE_DUMP_CONVERT_PATTERN, external_pattern, self._gen_npu_dump_convert_file_info - ) - - def list_numpy_files(self, path, extern_pattern=''): - return self._list_file_with_pattern(path, Const.NUMPY_PATTERN, extern_pattern, - self._gen_numpy_file_info) - - def create_columns(self, content): - 
if not Columns: - self.log.error("No Module named rich, please install it") - raise ParseException(ParseException.PARSE_NO_MODULE_ERROR) - return Columns(content) - - def create_table(self, title, columns): - if not Table: - self.log.error("No Module named rich, please install it and restart parse tool") - raise ParseException(ParseException.PARSE_NO_MODULE_ERROR) - table = Table(title=title) - for column_name in columns: - table.add_column(column_name, overflow='fold') - return table - - def check_path_valid(self, path): - path = self.path_strip(path) - if not path or not os.path.exists(path): - self.log.error("The path %s does not exist." % path) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if os.path.islink(path): - self.log.error('The file path {} is a soft link.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \ - Const.FILE_NAME_LENGTH: - self.log.error('The file path length exceeds limit.') - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): - self.log.error('The file path {} contains special characters.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if os.path.isfile(path): - file_size = os.path.getsize(path) - if path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB: - self.log.error('The file {} size is greater than 1GB.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if path.endswith(Const.NPY_SUFFIX) and file_size > Const.TEN_GB: - self.log.error('The file {} size is greater than 10GB.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - - def check_files_in_path(self, path): - if os.path.isdir(path) and len(os.listdir(path)) == 0: - self.log.error("No files in %s." 
% path) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - - def npy_info(self, source_data): - if isinstance(source_data, np.ndarray): - data = source_data - else: - self.log.error("Invalid data, data is not ndarray") - raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR) - if data.dtype == 'object': - self.log.error("Invalid data, data is object.") - raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR) - if np.size(data) == 0: - self.log.error("Invalid data, data is empty") - raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR) - return data.shape, data.dtype, data.max(), data.min(), data.mean() - - def _list_file_with_pattern(self, path, pattern, extern_pattern, gen_info_func): - self.check_path_valid(path) - file_list = {} - re_pattern = re.compile(pattern) - for dir_path, dir_names, file_names in os.walk(path, followlinks=True): - for name in file_names: - match = re_pattern.match(name) - if not match: - continue - if extern_pattern != '' and not re.match(extern_pattern, name): - continue - file_list[name] = gen_info_func(name, match, dir_path) - return file_list - - def check_path_format(self, path, suffix): - if os.path.isfile(path): - if not path.endswith(suffix): - self.log.error("%s is not a %s file." 
% (path, suffix)) - raise ParseException(ParseException.PARSE_INVALID_FILE_FORMAT_ERROR) - elif os.path.isdir(path): - self.log.error("Please specify a single file path") - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - else: - self.log.error("The file path %s is invalid" % path) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - - def check_path_name(self, path): - if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \ - Const.FILE_NAME_LENGTH: - self.log.error('The file path length exceeds limit.') - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) - if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): - self.log.error('The file path {} contains special characters.'.format(path)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +import logging +import os +import re +import sys +import subprocess +import numpy as np +from .config import Const +from .file_desc import DumpDecodeFileDesc, FileDesc +from .parse_exception import ParseException +from ...common.file_check_util import change_mode, check_other_user_writable,\ + check_path_executable, check_path_owner_consistent +from ...common.file_check_util import FileCheckConst + +try: + from rich.traceback import install + from rich.panel import Panel + from rich.table import Table + from rich import print as rich_print + from rich.columns import Columns + install() +except ImportError as err: + install = None + Panel = None + Table = None + Columns = None + rich_print = None + print("[Warning] Failed to import rich, Some features may not be available. " + "Please run 'pip install rich' to fix it.") + + +class Util: + def __init__(self): + self.ms_accu_cmp = None + logging.basicConfig( + level=Const.LOG_LEVEL, + format="%(asctime)s (%(process)d) -[%(levelname)s]%(message)s", + datefmt="%Y-%m-%d %H:%M:%S" + ) + self.log = logging.getLogger() + self.python = sys.executable + + @staticmethod + def print(content): + rich_print(content) + + @staticmethod + def path_strip(path): + return path.strip("'").strip('"') + + @staticmethod + def _gen_npu_dump_convert_file_info(name, match, dir_path): + return DumpDecodeFileDesc(name, dir_path, int(match.groups()[-4]), op_name=match.group(2), + op_type=match.group(1), task_id=int(match.group(3)), anchor_type=match.groups()[-3], + anchor_idx=int(match.groups()[-2])) + + @staticmethod + def _gen_numpy_file_info(name, math, dir_path): + return FileDesc(name, dir_path) + + @staticmethod + def check_executable_file(path): + check_path_owner_consistent(path) + check_other_user_writable(path) + check_path_executable(path) + + def execute_command(self, cmd): + if not cmd: + self.log.error("Commond is None") + return -1 + self.log.debug("[RUN CMD]: %s", cmd) + cmd = cmd.split(" ") + complete_process = 
subprocess.run(cmd, shell=False)
+        return complete_process.returncode
+
+    def print_panel(self, content, title='', fit=True):  # Show content in a Panel (Panel.fit when fit=True).
+        if not Panel:  # Panel is falsy: fall back to a plain print of the content
+            print(content)
+            return
+        if fit:
+            self.print(Panel.fit(content, title=title))
+        else:
+            self.print(Panel(content, title=title))
+
+    def check_msaccucmp(self, target_file):  # Run "<self.python> <target_file> --help"; return target_file on exit 0.
+        self.log.info("Try to auto detect file with name: %s.", target_file)
+        result = subprocess.run(
+            [self.python, target_file, "--help"], stdout=subprocess.PIPE)
+        if result.returncode == 0:
+            self.log.info("Check [%s] success.", target_file)
+        else:  # non-zero exit code: report and raise PARSE_MSACCUCMP_ERROR
+            self.log.error("Check msaccucmp failed in dir %s" % target_file)
+            self.log.error("Please specify a valid msaccucmp.py path or install the cann package")
+            raise ParseException(ParseException.PARSE_MSACCUCMP_ERROR)
+        return target_file
+
+    def create_dir(self, path):  # Create path (mode 0o750) unless it already exists; raises ParseException on OSError.
+        path = self.path_strip(path)
+        if os.path.exists(path):  # existing paths are returned untouched, before any name validation
+            return
+        self.check_path_name(path)
+        try:
+            os.makedirs(path, mode=0o750)
+        except OSError as e:
+            self.log.error("Failed to create %s.", path)
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from e
+
+    def gen_npy_info_txt(self, source_data):  # Format the npy_info() statistics of source_data as one text line.
+        shape, dtype, max_data, min_data, mean = \
+            self.npy_info(source_data)
+        return \
+            '[Shape: %s] [Dtype: %s] [Max: %s] [Min: %s] [Mean: %s]' % (shape, dtype, max_data, min_data, mean)
+
+    def save_npy_to_txt(self, data, dst_file='', align=0):  # Write data to dst_file as text rows of `align` values.
+        if os.path.exists(dst_file):  # never overwrite an existing text dump
+            self.log.info("Dst file %s exists, will not save new one.", dst_file)
+            return
+        shape = data.shape
+        data = data.flatten()
+        if align == 0:
+            align = 1 if len(shape) == 0 else shape[-1]  # default row width: last dimension (1 for 0-d arrays)
+        elif data.size % align != 0:
+            pad_array = np.zeros((align - data.size % align,))  # zero-pad so the data fills whole rows
+            data = np.append(data, pad_array)
+        np.savetxt(dst_file, data.reshape((-1, align)), delimiter=' ', fmt='%g')
+        change_mode(dst_file, FileCheckConst.DATA_FILE_AUTHORITY)  # presumably sets file permissions - TODO confirm
+
+    def list_convert_files(self, path, external_pattern=""):  # List files under path matching OFFLINE_DUMP_CONVERT_PATTERN.
+        return self._list_file_with_pattern(
+            path, Const.OFFLINE_DUMP_CONVERT_PATTERN, external_pattern, self._gen_npu_dump_convert_file_info
+        )
+
+    def list_numpy_files(self, path, extern_pattern=''):  # List files under path matching Const.NUMPY_PATTERN.
+        return self._list_file_with_pattern(path, Const.NUMPY_PATTERN, extern_pattern,
+                                            self._gen_numpy_file_info)
+
+    def create_columns(self, content):  # Wrap content in Columns; raises PARSE_NO_MODULE_ERROR when Columns is falsy.
+        if not Columns:
+            self.log.error("No Module named rich, please install it")
+            raise ParseException(ParseException.PARSE_NO_MODULE_ERROR)
+        return Columns(content)
+
+    def create_table(self, title, columns):  # Build a Table with one folding column per name in `columns`.
+        if not Table:
+            self.log.error("No Module named rich, please install it and restart parse tool")
+            raise ParseException(ParseException.PARSE_NO_MODULE_ERROR)
+        table = Table(title=title)
+        for column_name in columns:
+            table.add_column(column_name, overflow='fold')
+        return table
+
+    def check_path_valid(self, path):  # Raise PARSE_INVALID_PATH_ERROR unless path passes every check below.
+        path = self.path_strip(path)
+        if not path or not os.path.exists(path):  # must be non-empty and exist
+            self.log.error("The path %s does not exist." % path)
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+        if os.path.islink(path):  # the path itself must not be a symbolic link
+            self.log.error('The file path {} is a soft link.'.format(path))
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+        if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
+                Const.FILE_NAME_LENGTH:  # same length and pattern checks as check_path_name
+            self.log.error('The file path length exceeds limit.')
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+        if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):
+            self.log.error('The file path {} contains special characters.'.format(path))
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+        if os.path.isfile(path):  # size limits apply to regular files only, chosen by suffix
+            file_size = os.path.getsize(path)
+            if path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB:
+                self.log.error('The file {} size is greater than 1GB.'.format(path))
+                raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+            if path.endswith(Const.NPY_SUFFIX) and file_size > Const.TEN_GB:
+                self.log.error('The file {} size is greater than 10GB.'.format(path))
+                raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+
+    def check_files_in_path(self, path):  # Raise PARSE_INVALID_PATH_ERROR when path is an empty directory.
+        if os.path.isdir(path) and len(os.listdir(path)) == 0:
+            self.log.error("No files in %s." % path)
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+
+    def npy_info(self, source_data):  # Return (shape, dtype, max, min, mean) of a non-empty, non-object ndarray.
+        if isinstance(source_data, np.ndarray):
+            data = source_data
+        else:
+            self.log.error("Invalid data, data is not ndarray")
+            raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
+        if data.dtype == 'object':  # object arrays are rejected before any statistic is computed
+            self.log.error("Invalid data, data is object.")
+            raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
+        if np.size(data) == 0:  # empty arrays are rejected before max()/min()/mean() are called
+            self.log.error("Invalid data, data is empty")
+            raise ParseException(ParseException.PARSE_INVALID_DATA_ERROR)
+        return data.shape, data.dtype, data.max(), data.min(), data.mean()
+
+    def _list_file_with_pattern(self, path, pattern, extern_pattern, gen_info_func):  # Return {file name: info} under path.
+        self.check_path_valid(path)
+        file_list = {}  # keyed by bare file name, so a later directory's entry replaces an earlier same-named one
+        re_pattern = re.compile(pattern)
+        # NOTE(review): followlinks=True walks into symlinked directories although check_path_valid rejects a
+        # symlinked root - confirm this is intended.
+        for dir_path, dir_names, file_names in os.walk(path, followlinks=True):
+            for name in file_names:
+                match = re_pattern.match(name)
+                if not match:
+                    continue
+                if extern_pattern != '' and not re.match(extern_pattern, name):  # optional second filter
+                    continue
+                file_list[name] = gen_info_func(name, match, dir_path)
+        return file_list
+
+    def check_path_format(self, path, suffix):  # Require path to be a regular file ending with suffix, else raise.
+        if os.path.isfile(path):
+            if not path.endswith(suffix):
+                self.log.error("%s is not a %s file." % (path, suffix))
+                raise ParseException(ParseException.PARSE_INVALID_FILE_FORMAT_ERROR)
+        elif os.path.isdir(path):
+            self.log.error("Please specify a single file path")
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+        else:
+            self.log.error("The file path %s is invalid" % path)
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+
+    def check_path_name(self, path):  # Validate the length and allowed characters of path; it need not exist.
+        if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \
+                Const.FILE_NAME_LENGTH:  # limits apply to the resolved full path and to the base name
+            self.log.error('The file path length exceeds limit.')
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
+        if not re.match(Const.FILE_PATTERN, os.path.realpath(path)):  # the resolved path must match FILE_PATTERN
+            self.log.error('The file path {} contains special characters.'.format(path))
+            raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR)
diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py
index 10dde7e0894..83bb2459457 100644
--- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py
+++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py
@@ -1,87 +1,87 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-"""
-# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved.
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-""" -import json -import numpy as np - -from .config import Const -from .utils import Util -from .parse_exception import ParseException - - -class Visualization: - def __init__(self): - self.util = Util() - - def print_npy_summary(self, target_file): - try: - np_data = np.load(target_file, allow_pickle=True) - except UnicodeError as e: - self.util.log.error("%s %s" % ("UnicodeError", str(e))) - self.util.log.warning("Please check the npy file") - raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e - table = self.util.create_table('', ['Index', 'Data']) - flatten_data = np_data.flatten() - for i in range(min(16, int(np.ceil(flatten_data.size / 8)))): - last_idx = min(flatten_data.size, i * 8 + 8) - table.add_row(str(i * 8), ' '.join(flatten_data[i * 8: last_idx].astype('str').tolist())) - summary = ['[yellow]%s[/yellow]' % self.util.gen_npy_info_txt(np_data), 'Path: %s' % target_file, - "TextFile: %s.txt" % target_file] - self.util.print_panel(self.util.create_columns([table, "\n".join(summary)]), target_file) - self.util.save_npy_to_txt(np_data, target_file + "txt") - - def print_npy_data(self, file_name): - file_name = self.util.path_strip(file_name) - self.util.check_path_valid(file_name) - self.util.check_path_format(file_name, Const.NPY_SUFFIX) - return self.print_npy_summary(file_name) - - def parse_pkl(self, path, api_name): - path = self.util.path_strip(path) - self.util.check_path_valid(path) - self.util.check_path_format(path, Const.PKL_SUFFIX) - with open(path, "r") as pkl_handle: - title_printed = False - while True: - pkl_line = pkl_handle.readline() - if pkl_line == '\n': - continue - if len(pkl_line) == 0: - break - try: - msg = json.loads(pkl_line) - except json.JSONDecodeError as e: - self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line)) - self.util.log.warning("Please check the pkl file") - raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e - info_prefix = msg[0] - if not 
info_prefix.startswith(api_name): - continue - if info_prefix.find("stack_info") != -1 and len(msg) == 2: - print("\nTrace back({}):".format(msg[0])) - if msg[1] and len(msg[1]) > 4: - for item in reversed(msg[1]): - print(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) - print(" {}".format(item[3])) - continue - if len(msg) > 5: - summery_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \ - .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2]) - if not title_printed: - print("\nStatistic Info:") - title_printed = True - print(summery_info) - pkl_handle.close() +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""
+import json
+import numpy as np
+
+from .config import Const
+from .utils import Util
+from .parse_exception import ParseException
+
+
+class Visualization:
+    def __init__(self):
+        self.util = Util()
+
+    def print_npy_summary(self, target_file):
+        """Print up to 16 rows of 8 values plus summary statistics of an npy file, then save it as <file>.txt."""
+        try:
+            # NOTE(security): allow_pickle=True unpickles object arrays and can run code from an untrusted file.
+            np_data = np.load(target_file, allow_pickle=True)
+        except UnicodeError as e:
+            self.util.log.error("%s %s" % ("UnicodeError", str(e)))
+            self.util.log.warning("Please check the npy file")
+            raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e
+        table = self.util.create_table('', ['Index', 'Data'])
+        flatten_data = np_data.flatten()
+        for i in range(min(16, int(np.ceil(flatten_data.size / 8)))):  # one table row per 8 values, 16 rows at most
+            last_idx = min(flatten_data.size, i * 8 + 8)
+            table.add_row(str(i * 8), ' '.join(flatten_data[i * 8: last_idx].astype('str').tolist()))
+        summary = ['[yellow]%s[/yellow]' % self.util.gen_npy_info_txt(np_data), 'Path: %s' % target_file,
+                   "TextFile: %s.txt" % target_file]
+        self.util.print_panel(self.util.create_columns([table, "\n".join(summary)]), target_file)
+        self.util.save_npy_to_txt(np_data, target_file + ".txt")
+
+    def print_npy_data(self, file_name):
+        """Validate that file_name is an existing file with the Const.NPY_SUFFIX suffix, then print its summary."""
+        file_name = self.util.path_strip(file_name)
+        self.util.check_path_valid(file_name)
+        self.util.check_path_format(file_name, Const.NPY_SUFFIX)
+        return self.print_npy_summary(file_name)
+
+    def parse_pkl(self, path, api_name):
+        """Print stack traces and statistics of the JSON-lines records in path whose name starts with api_name."""
+        path = self.util.path_strip(path)
+        self.util.check_path_valid(path)
+        self.util.check_path_format(path, Const.PKL_SUFFIX)
+        with open(path, "r") as pkl_handle:  # the with block closes the file; no explicit close() is needed
+            title_printed = False
+            for pkl_line in pkl_handle:  # iteration stops at EOF, replacing the readline()/empty-string sentinel
+                if pkl_line == '\n':
+                    continue
+                try:
+                    msg = json.loads(pkl_line)
+                except json.JSONDecodeError as e:
+                    self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line))
+                    self.util.log.warning("Please check the pkl file")
+                    raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e
+                info_prefix = msg[0]  # first element of each record is its name
+                if not info_prefix.startswith(api_name):
+                    continue
+                if info_prefix.find("stack_info") != -1 and len(msg) == 2:  # [name, frames] stack record
+                    print("\nTrace back({}):".format(msg[0]))
+                    if msg[1] and len(msg[1]) > 4:
+                        for item in reversed(msg[1]):
+                            print(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2]))
+                            print(" {}".format(item[3]))
+                    continue
+                if len(msg) > 5:  # statistics record: dtype at [3], shape at [4], (max, min, mean) at [5]
+                    summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \
+                        .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2])
+                    if not title_printed:  # print the section heading once, before the first statistics line
+                        print("\nStatistic Info:")
+                        title_printed = True
+                    print(summary_info)
-- 
Gitee