From 81451797db254422ac95c8b082ed3cba85795159 Mon Sep 17 00:00:00 2001 From: wangchao Date: Fri, 20 Oct 2023 15:51:41 +0800 Subject: [PATCH 1/3] add file check --- .../api_accuracy_checker/dump/dump.py | 22 +- .../run_ut/run_overflow_check.py | 13 +- .../api_accuracy_checker/run_ut/run_ut.py | 22 +- .../ptdbg_ascend/common/file_check_util.py | 237 ++++++++++++++++++ 4 files changed, 272 insertions(+), 22 deletions(-) create mode 100644 debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 08385882d3..cbfa30f3f8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -17,12 +17,15 @@ from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json -from api_accuracy_checker.common.utils import print_error_log +from api_accuracy_checker.common.utils import print_error_log, CompareException from api_accuracy_checker.hook_module.register_hook import initialize_hook from api_accuracy_checker.common.config import msCheckerConfig def set_dump_switch(switch): + if switch not in ["ON", "OFF"]: + print_error_log("Please set switch with 'ON' or 'OFF'.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) if switch == "ON": initialize_hook(pretest_hook) initialize_output_json() @@ -39,8 +42,8 @@ class DumpUtil(object): @staticmethod def get_dump_switch(): return DumpUtil.dump_switch == "ON" - - @staticmethod + + @staticmethod def incr_iter_num_maybe_exit(): if DumpUtil.call_num == msCheckerConfig.target_iter or not msCheckerConfig.enable_dataloader: set_dump_switch("ON") @@ -48,7 +51,7 @@ class DumpUtil(object): raise Exception("Model pretest: exit after iteration {}".format(msCheckerConfig.target_iter)) else: set_dump_switch("OFF") - DumpUtil.call_num += 1 + DumpUtil.call_num += 1 class DumpConst: @@ -59,7 +62,7 @@ class DumpConst: def pretest_info_dump(name, out_feat, module, phase): if not DumpUtil.get_dump_switch(): - return + return if phase == DumpConst.forward: api_info = ForwardAPIInfo(name, module.input_args, module.input_kwargs) elif phase == DumpConst.backward: @@ -68,14 +71,15 @@ def pretest_info_dump(name, out_feat, module, phase): msg = "Unexpected training phase {}.".format(phase) print_error_log(msg) raise NotImplementedError(msg) - + write_api_info_json(api_info) + def pretest_hook(name, phase): def pretest_info_dump_hook(module, in_feat, out_feat): pretest_info_dump(name, out_feat, module, phase) if hasattr(module, "input_args"): - del module.input_args + del module.input_args if hasattr(module, "input_kwargs"): - del module.input_kwargs - return pretest_info_dump_hook + del module.input_kwargs + return pretest_info_dump_hook diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py index 7c0fa0f6a6..1348d9dd38 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_overflow_check.py @@ -9,6 +9,7 @@ from api_accuracy_checker.run_ut.run_ut import exec_api, generate_npu_params, ru from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileChecker, FileCheckConst NO_GRAD_APIS = ["hardtanh"] @@ -133,14 +134,14 @@ def _run_overflow_check(): args = parser.parse_args(sys.argv[1:]) torch.npu.set_compile_mode(jit_compile=args.jit_compile) npu_device = "npu:" + str(args.device_id) - forward_file = os.path.realpath(args.forward_input_file) + forward_file_checker = FileChecker(args.forward_input_file, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE, + file_type=FileCheckConst.JSON_SUFFIX) + forward_file = forward_file_checker.common_check() backward_file = "" if args.backward_input_file: - backward_file = os.path.realpath(args.backward_input_file) - if not backward_file.endswith(".json"): - raise ValueError("The backward_input_file should be a json file!") - if not forward_file.endswith(".json"): - raise ValueError("The forward_input_file should be a json file!") + backward_file_checker = FileChecker(args.backward_input_file, FileCheckConst.FILE, + ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) + backward_file = backward_file_checker.common_check() try: torch.npu.set_device(npu_device) except Exception: diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 034586cc1d..236726b3db 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -16,12 +16,16 @@ from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate from ut_api_info import UtAPIInfo from api_accuracy_checker.common.config import msCheckerConfig +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileChecker, FileCheckConst + NO_GRAD_APIS = ["hardtanh"] def init_environment(): cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "../hook_module/support_wrap_ops.yaml") + yaml_path_checker = FileChecker(yaml_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE) + yaml_path = yaml_path_checker.common_check() with open(yaml_path, 'r') as f: WrapFunctionalOps = yaml.safe_load(f).get('functional') for f in dir(torch.nn.functional): @@ -136,7 +140,6 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out) - def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): in_fwd_data_list = [] [api_type, api_name, _] = api_full_name.split("*") @@ -208,8 +211,9 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np def initialize_save_error_data(): - error_data_path = os.path.realpath(msCheckerConfig.error_data_path) - check_file_or_directory_path(error_data_path, True) + error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, + ability=FileCheckConst.WRITE_ABLE) + error_data_path = error_data_path_checker.common_check() initialize_save_path(error_data_path, 'ut_error_data') @@ -244,11 +248,15 @@ def _run_ut(): except Exception: print_error_log(f"Set NPU device id failed. device id is: {args.device_id}") raise NotImplementedError - forward_file = os.path.realpath(args.forward_input_file) - backward_file = os.path.realpath(args.backward_input_file) - if not forward_file.endswith(".json") or not backward_file.endswith(".json"): - raise ValueError("The forward_input_file and backward_input_file should be a json file!") + forward_file_checker = FileChecker(args.forward_input_file, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE, + file_type=FileCheckConst.JSON_SUFFIX) + forward_file = forward_file_checker.common_check() + backward_file_checker = FileChecker(args.backward_input_file, FileCheckConst.FILE, + ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) + backward_file = backward_file_checker.common_check() out_path = os.path.realpath(args.out_path) if args.out_path else "./" + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() save_error_data = args.save_error_data if save_error_data: initialize_save_error_data() diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py new file mode 100644 index 0000000000..fcf166949b --- /dev/null +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py @@ -0,0 +1,237 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import os +import re + +from .utils import print_warn_log, print_error_log + + +class FileCheckConst: + """ + Class for file check const + """ + READ_ABLE = "read" + WRITE_ABLE = "write" + READ_WRITE_ABLE = "read and write" + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + JSON_SUFFIX = ".json" + MAX_PKL_SIZE = 1 * 1024 * 1024 * 1024 + MAX_NUMPY_SIZE = 10 * 1024 * 1024 * 1024 + MAX_JSON_SIZE = 1 * 1024 * 1024 * 1024 + DIR = "dir" + FILE = "file" + DATA_DIR_AUTHORITY = 0o750 + DATA_FILE_AUTHORITY = 0o640 + + +class FileCheckException(Exception): + """ + Class for File Check Exception + """ + NONE_ERROR = 0 + INVALID_PATH_ERROR = 1 + INVALID_FILE_TYPE_ERROR = 2 + INVALID_PARAM_ERROR = 3 + INVALID_PERMISSION_ERROR = 3 + + def __init__(self, code, error_info: str = ""): + super(FileCheckException, self).__init__() + self.code = code + self.error_info = error_info + + def __str__(self): + return self.error_info + + +class FileChecker: + """ + The class for check file. + + Attributes: + file_path: The file or dictionary path to be verified. + ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability + file_type(str): The correct file type for file + """ + def __init__(self, file_path, path_type, ability=None, file_type=None): + self.file_path = file_path + self.path_type = self._check_path_type(path_type) + self.ability = ability + self.file_type = file_type + + @staticmethod + def _check_path_type(path_type): + if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: + print_error_log(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + raise FileCheckException(FileCheckException.INVALID_PARAM_ERROR) + return path_type + + def common_check(self): + """ + 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 + 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 + """ + check_link(self.file_path) + check_path_length(self.file_path) + check_path_exists(self.file_path) + self.check_path_ability() + check_path_owner_consistent(self.file_path) + check_path_pattern_vaild(self.file_path) + check_common_file_size(self.file_path) + check_file_suffix(self.file_path, self.file_type) + return os.path.realpath(self.file_path) + + def check_path_ability(self): + if self.ability == FileCheckConst.WRITE_ABLE: + check_path_writability(self.file_path) + if self.ability == FileCheckConst.READ_ABLE: + check_path_readability(self.file_path) + if self.ability == FileCheckConst.READ_WRITE_ABLE: + check_path_readability(self.file_path) + check_path_writability(self.file_path) + + +def check_link(path): + abs_path = os.path.abspath(path) + if os.path.islink(abs_path): + print_error_log('The file path {} is a soft link.'.format(path)) + raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + + +def check_path_length(path): + if len(os.path.realpath(path)) > FileCheckConst.DIRECTORY_LENGTH or \ + len(os.path.basename(path)) > FileCheckConst.FILE_NAME_LENGTH: + print_error_log('The file path length exceeds limit.') + raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + + +def check_path_exists(path): + real_path = os.path.realpath(path) + if not os.path.exists(real_path): + print_error_log('The file path %s does not exist.' % path) + raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + + +def check_path_readability(path): + real_path = os.path.realpath(path) + if not os.access(real_path, os.R_OK): + print_error_log('The file path %s is not readable.' % path) + raise FileCheckException(FileCheckException.PERMISSION_ERROR) + + +def check_path_writability(path): + real_path = os.path.realpath(path) + if not os.access(real_path, os.W_OK): + print_error_log('The file path %s is not writable.' % path) + raise FileCheckException(FileCheckException.PERMISSION_ERROR) + + +def _user_interactive_confirm(message): + while True: + check_message = input(message + " Enter 'c' to continue or enter 'e' to exit: ") + if check_message == "c": + break + elif check_message == "e": + print_warn_log("User canceled.") + raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + else: + print("Input is error, please enter 'c' or 'e'.") + + +def check_path_owner_consistent(path): + real_path = os.path.realpath(path) + file_owner = os.stat(real_path).st_uid + if file_owner != os.getuid(): + _user_interactive_confirm('The file path %s may be insecure because is does not belong to you.' + 'Do you want to continue?' % path) + + +def check_path_pattern_vaild(path): + if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): + print_error_log('The file path {} contains special characters.'.format(path)) + raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + + +def check_file_size(file_path, max_size): + real_path = os.path.realpath(file_path) + file_size = os.path.getsize(real_path) + if file_size >= max_size: + _user_interactive_confirm(f'The size of file path {file_path} exceeds {max_size} bytes.' + f'Do you want to continue?') + + +def check_common_file_size(file_path): + if os.path.isfile(file_path): + if file_path.endswith(FileCheckConst.PKL_SUFFIX): + check_file_size(file_path, FileCheckConst.MAX_PKL_SIZE) + if file_path.endswith(FileCheckConst.NUMPY_SUFFIX): + check_file_size(file_path, FileCheckConst.MAX_NUMPY_SIZE) + if file_path.endswith(FileCheckConst.JSON_SUFFIX): + check_file_size(file_path, FileCheckConst.MAX_JSON_SIZE) + + +def check_file_suffix(file_path, file_suffix): + if file_suffix: + real_path = os.path.realpath(file_path) + if not real_path.endswith(file_suffix): + print_error_log(f"The {file_path} should be a {file_suffix} file!") + raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) + + +def check_file_type(file_path, file_type): + real_path = os.path.realpath(file_path) + if file_type == FileCheckConst.FILE: + if not os.path.isfile(real_path): + print_error_log(f"The {file_path} should be a file!") + raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) + if file_type == FileCheckConst.DIR: + if not os.path.isdir(real_path): + print_error_log(f"The {file_path} should be a dictionary!") + raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) + + +def create_directory(dir_path): + """ + Function Description: + creating a directory with specified permissions + Parameter: + dir_path: directory path + Exception Description: + when invalid data throw exception + """ + dir_path = os.path.realpath(dir_path) + if not os.path.exists(dir_path): + try: + os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) + except OSError as ex: + print_error_log( + 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) + raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + + +def change_mode(path, mode): + if not os.path.exists(path) or os.islink(path): + return + try: + os.chmod(path, mode) + except PermissionError as ex: + print_error_log('Failed to change {} authority. {}'.format(path, str(ex))) + raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) + -- Gitee From 30b081107035a03f5d2d46a63bc706c9e2501ab7 Mon Sep 17 00:00:00 2001 From: wangchao Date: Fri, 20 Oct 2023 17:47:19 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E6=8E=A5=E5=8F=A3=E5=85=A5=E5=8F=82?= =?UTF-8?q?=E6=A0=A1=E9=AA=8C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../src/python/ptdbg_ascend/common/file_check_util.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py index fcf166949b..a2c257256a 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py @@ -133,14 +133,14 @@ def check_path_readability(path): real_path = os.path.realpath(path) if not os.access(real_path, os.R_OK): print_error_log('The file path %s is not readable.' % path) - raise FileCheckException(FileCheckException.PERMISSION_ERROR) + raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) def check_path_writability(path): real_path = os.path.realpath(path) if not os.access(real_path, os.W_OK): print_error_log('The file path %s is not writable.' % path) - raise FileCheckException(FileCheckException.PERMISSION_ERROR) + raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) def _user_interactive_confirm(message): @@ -164,7 +164,7 @@ def check_path_owner_consistent(path): def check_path_pattern_vaild(path): - if not re.match(FileCheckConst.FILE_PATTERN, os.path.realpath(path)): + if not re.match(FileCheckConst.FILE_VALID_PATTERN, os.path.realpath(path)): print_error_log('The file path {} contains special characters.'.format(path)) raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) @@ -227,7 +227,7 @@ def create_directory(dir_path): def change_mode(path, mode): - if not os.path.exists(path) or os.islink(path): + if not os.path.exists(path) or os.path.islink(path): return try: os.chmod(path, mode) -- Gitee From 1b1a86849d7cd18cac37e1280cee089134747062 Mon Sep 17 00:00:00 2001 From: wangchao Date: Mon, 23 Oct 2023 17:27:45 +0800 Subject: [PATCH 3/3] file check --- .../api_accuracy_checker/common/config.py | 99 +- .../api_accuracy_checker/common/utils.py | 1234 +++++++++-------- .../api_accuracy_checker/dump/info_dump.py | 6 +- .../hook_module/wrap_functional.py | 136 +- .../hook_module/wrap_tensor.py | 134 +- .../hook_module/wrap_torch.py | 220 +-- .../api_accuracy_checker/run_ut/run_ut.py | 554 ++++---- debug/accuracy_tools/ptdbg_ascend/__init__.py | 17 + .../ptdbg_ascend/common/file_check_util.py | 51 + 9 files changed, 1266 insertions(+), 1185 deletions(-) create mode 100644 debug/accuracy_tools/ptdbg_ascend/__init__.py diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 2bf26d7355..41de51f40c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -1,49 +1,52 @@ -import yaml -import os -from api_accuracy_checker.common.utils import check_file_or_directory_path - -class Config: - def __init__(self, yaml_file): - check_file_or_directory_path(yaml_file, False) - with open(yaml_file, 'r') as file: - config = yaml.safe_load(file) - self.config = {key: self.validate(key, value) for key, value in config.items()} - - def validate(self, key, value): - validators = { - 'dump_path': str, - 'jit_compile': bool, - 'compile_option': str, - 'compare_algorithm': str, - 'real_data': bool, - 'dump_step': int, - 'error_data_path': str, - 'enable_dataloader': bool, - 'target_iter': int, - 'precision': int - } - if not isinstance(value, validators[key]): - raise ValueError(f"{key} must be {validators[key].__name__} type") - if key == 'target_iter' and value < 0: - raise ValueError("target_iter must be greater than 0") - if key == 'precision' and value < 0: - raise ValueError("precision must be greater than 0") - return value - - def __getattr__(self, item): - return self.config[item] - - def __str__(self): - return '\n'.join(f"{key}={value}" for key, value in self.config.items()) - - def update_config(self, **kwargs): - for key, value in kwargs.items(): - if key in self.config: - self.config[key] = self.validate(key, value) - else: - raise ValueError(f"Invalid key '{key}'") - - -cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -yaml_path = os.path.join(cur_path, "config.yaml") +import yaml +import os +from api_accuracy_checker.common.utils import check_file_or_directory_path + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + + +class Config: + def __init__(self, yaml_file): + check_file_or_directory_path(yaml_file, False) + with FileOpen(yaml_file, 'r') as file: + config = yaml.safe_load(file) + self.config = {key: self.validate(key, value) for key, value in config.items()} + + def validate(self, key, value): + validators = { + 'dump_path': str, + 'jit_compile': bool, + 'compile_option': str, + 'compare_algorithm': str, + 'real_data': bool, + 'dump_step': int, + 'error_data_path': str, + 'enable_dataloader': bool, + 'target_iter': int, + 'precision': int + } + if not isinstance(value, validators[key]): + raise ValueError(f"{key} must be {validators[key].__name__} type") + if key == 'target_iter' and value < 0: + raise ValueError("target_iter must be greater than 0") + if key == 'precision' and value < 0: + raise ValueError("precision must be greater than 0") + return value + + def __getattr__(self, item): + return self.config[item] + + def __str__(self): + return '\n'.join(f"{key}={value}" for key, value in self.config.items()) + + def update_config(self, **kwargs): + for key, value in kwargs.items(): + if key in self.config: + self.config[key] = self.validate(key, value) + else: + raise ValueError(f"Invalid key '{key}'") + + +cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +yaml_path = os.path.join(cur_path, "config.yaml") msCheckerConfig = Config(yaml_path) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 065c0fe46d..da31000727 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -1,617 +1,619 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -import collections -import json -import os -import random -import re -import stat -import subprocess -import sys -import time -from datetime import datetime, timezone - -import numpy as np -import torch -import csv - -try: - import torch_npu -except ImportError: - IS_GPU = True -else: - IS_GPU = False - -torch_without_guard_version_list = ['2.1'] -for version in torch_without_guard_version_list: - if torch.__version__.startswith(version): - torch_without_guard_version = True - break - else: - torch_without_guard_version = False -if not IS_GPU and not torch_without_guard_version: - from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard - -device = collections.namedtuple('device', ['type', 'index']) - - -class Const: - """ - Class for const - """ - MODEL_TYPE = ['.onnx', '.pb', '.om'] - DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*" - SEMICOLON = ";" - COLON = ":" - EQUAL = "=" - COMMA = "," - DOT = "." - DUMP_RATIO_MAX = 100 - SUMMERY_DATA_NUMS = 256 - ONE_HUNDRED_MB = 100*1024*1024 - FLOAT_EPSILON = np.finfo(float).eps - SUPPORT_DUMP_MODE = ['api', 'acl'] - ON = 'ON' - OFF = 'OFF' - BACKWARD = 'backward' - FORWARD = 'forward' - FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] - BOOL_TYPE = [bool, np.uint8] - INT_TYPE = [np.int32, np.int64] - - # dump mode - ALL = "all" - LIST = "list" - RANGE = "range" - STACK = "stack" - ACL = "acl" - API_LIST = "api_list" - API_STACK = "api_stack" - DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] - - API_PATTERN = r"^[A-Za-z0-9]+[_]+([A-Za-z0-9]+[_]*[A-Za-z0-9]+)[_]+[0-9]+[_]+[A-Za-z0-9]+" - WRITE_FLAGS = os.O_WRONLY | os.O_CREAT - WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR - - RAISE_PRECISION = { - "torch.float16" : "torch.float32", - "torch.bfloat16" : "torch.float32", - "torch.float32" : "torch.float64" - } - CONVERT = { - "int32_to_int64": ["torch.int32", "torch.int64"], - } - - CONVERT_API = { - "int32_to_int64": ["cross_entropy"] - } - -class CompareConst: - """ - Class for compare module const - """ - # compare result column name - NPU_NAME = "NPU Name" - BENCH_NAME = "Bench Name" - NPU_DTYPE = "NPU Tensor Dtype" - BENCH_DTYPE = "Bench Tensor Dtype" - NPU_SHAPE = "NPU Tensor Shape" - BENCH_SHAPE = "Bench Tensor Shape" - NPU_MAX = "NPU max" - NPU_MIN = "NPU min" - NPU_MEAN = "NPU mean" - BENCH_MAX = "Bench max" - BENCH_MIN = "Bench min" - BENCH_MEAN = "Bench mean" - COSINE = "Cosine" - MAX_ABS_ERR = "MaxAbsErr" - ACCURACY = "Accuracy Reached or Not" - STACK = "NPU_Stack_Info" - ERROR_MESSAGE = "Err_message" - - # compare result data - NAN = 'Nan' - SHAPE_UNMATCH = 'shape unmatched' - DTYPE_UNMATCH = 'dtype unmatched' - - # accuracy standards - COS_THRESHOLD = 0.99 - MAX_ABS_ERR_THRESHOLD = 0.001 - COS_MAX_THRESHOLD = 0.9 - MAX_ABS_ERR_MAX_THRESHOLD = 1 - ACCURACY_CHECK_YES = "Yes" - ACCURACY_CHECK_NO = "No" - ACCURACY_CHECK_UNMATCH = "Unmatched" - - # error message - NO_BENCH = "No bench data matched." - - -class VersionCheck: - """ - Class for TorchVersion - """ - V1_8 = "1.8" - V1_11 = "1.11" - - @staticmethod - def check_torch_version(version): - torch_version = torch.__version__ - if torch_version.startswith(version): - return True - else: - return False - - -class CompareException(Exception): - """ - Class for Accuracy Compare Exception - """ - NONE_ERROR = 0 - INVALID_PATH_ERROR = 1 - OPEN_FILE_ERROR = 2 - CLOSE_FILE_ERROR = 3 - READ_FILE_ERROR = 4 - WRITE_FILE_ERROR = 5 - INVALID_FILE_ERROR = 6 - PERMISSION_ERROR = 7 - INDEX_OUT_OF_BOUNDS_ERROR = 8 - NO_DUMP_FILE_ERROR = 9 - INVALID_DATA_ERROR = 10 - INVALID_PARAM_ERROR = 11 - INVALID_DUMP_RATIO = 12 - INVALID_DUMP_FILE = 13 - UNKNOWN_ERROR = 14 - INVALID_DUMP_MODE = 15 - PARSE_FILE_ERROR = 16 - INVALID_COMPARE_MODE = 17 - - def __init__(self, code, error_info: str = ""): - super(CompareException, self).__init__() - self.code = code - self.error_info = error_info - - def __str__(self): - return self.error_info - -class DumpException(CompareException): - pass - -def read_json(file): - with open(file, 'r') as f: - obj = json.load(f) - return obj - -def write_csv(data, filepath): - with open(filepath, 'a') as f: - writer = csv.writer(f) - writer.writerows(data) - -def _print_log(level, msg): - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getgid() - print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg) - sys.stdout.flush() - - -def print_info_log(info_msg): - """ - Function Description: - print info log. - Parameter: - info_msg: the info message. - """ - _print_log("INFO", info_msg) - - -def print_error_log(error_msg): - """ - Function Description: - print error log. - Parameter: - error_msg: the error message. - """ - _print_log("ERROR", error_msg) - - -def print_warn_log(warn_msg): - """ - Function Description: - print warn log. - Parameter: - warn_msg: the warning message. - """ - _print_log("WARNING", warn_msg) - - -def check_mode_valid(mode): - if mode not in Const.DUMP_MODE: - msg = "Current mode '%s' is not supported. Please use the field in %s" % \ - (mode, Const.DUMP_MODE) - raise CompareException(CompareException.INVALID_DUMP_MODE, msg) - - -def check_object_type(check_object, allow_type): - """ - Function Description: - Check if the object belongs to a certain data type - Parameter: - check_object: the object to be checked - allow_type: legal data type - Exception Description: - when invalid data throw exception - """ - if not isinstance(check_object, allow_type): - print_error_log(f"{check_object} not of {allow_type} type") - raise CompareException(CompareException.INVALID_DATA_ERROR) - - -def check_file_or_directory_path(path, isdir=False): - """ - Function Description: - check whether the path is valid - Parameter: - path: the path to check - isdir: the path is dir or file - Exception Description: - when invalid data throw exception - """ - if isdir: - if not os.path.exists(path): - print_error_log('The path {} is not exist.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.path.isdir(path): - print_error_log('The path {} is not a directory.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.access(path, os.W_OK): - print_error_log( - 'The path {} does not have permission to write. Please check the path permission'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - else: - if not os.path.isfile(path): - print_error_log('{} is an invalid file or non-exist.'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - if not os.access(path, os.R_OK): - print_error_log( - 'The path {} does not have permission to read. Please check the path permission'.format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) - -def _check_pkl(pkl_file_handle, file_name): - tensor_line = pkl_file_handle.readline() - if len(tensor_line) == 0: - print_error_log("dump file {} have empty line!".format(file_name)) - raise CompareException(CompareException.INVALID_DUMP_FILE) - pkl_file_handle.seek(0, 0) - - -def check_file_mode(npu_pkl, bench_pkl, stack_mode): - npu_pkl_name = os.path.split(npu_pkl)[-1] - bench_pkl_name = os.path.split(bench_pkl)[-1] - - if not npu_pkl_name.startswith("api_stack") and not bench_pkl_name.startswith("api_stack"): - if stack_mode: - print_error_log("The current file does not contain stack information, please turn off the stack_mode") - raise CompareException(CompareException.INVALID_COMPARE_MODE) - elif npu_pkl_name.startswith("api_stack") and bench_pkl_name.startswith("api_stack"): - if not stack_mode: - print_error_log("The current file contains stack information, please turn on the stack_mode") - raise CompareException(CompareException.INVALID_COMPARE_MODE) - else: - print_error_log("The dump mode of the two files is not same, please check the dump files") - raise CompareException(CompareException.INVALID_COMPARE_MODE) - - -def check_file_size(input_file, max_size): - try: - file_size = os.path.getsize(input_file) - except OSError as os_error: - print_error_log('Failed to open "%s". %s' % (input_file, str(os_error))) - raise CompareException(CompareException.INVALID_FILE_ERROR) - if file_size > max_size: - print_error_log('The size (%d) of %s exceeds (%d) bytes, tools not support.' - % (file_size, input_file, max_size)) - raise CompareException(CompareException.INVALID_FILE_ERROR) - - -def get_dump_data_path(dump_dir): - """ - Function Description: - traverse directories and obtain the absolute path of dump data - Parameter: - dump_dir: dump data directory - Return Value: - dump data path,file is exist or file is not exist - """ - dump_data_path = None - file_is_exist = False - - check_file_or_directory_path(dump_dir, True) - for dir_path, sub_paths, files in os.walk(dump_dir): - if len(files) != 0: - dump_data_path = dir_path - file_is_exist = True - break - dump_data_path = dir_path - return dump_data_path, file_is_exist - - -def get_api_name_from_matcher(name): - api_matcher = re.compile(Const.API_PATTERN) - match = api_matcher.match(name) - return match.group(1) if match else "" - - -def modify_dump_path(dump_path, mode): - if mode == Const.ALL: - return dump_path - file_name = os.path.split(dump_path) - mode_file_name = mode + "_" + file_name[-1] - return os.path.join(file_name[0], mode_file_name) - - -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - if not os.path.exists(dir_path): - try: - os.makedirs(dir_path, mode=0o700) - except OSError as ex: - print_error_log( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) - - -def execute_command(cmd): - """ - Function Description: - run the following command - Parameter: - cmd: command - Exception Description: - when invalid command throw exception - """ - print_info_log('Execute command:%s' % cmd) - process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - while process.poll() is None: - line = process.stdout.readline() - line = line.strip() - if line: - print(line) - if process.returncode != 0: - print_error_log('Failed to execute command:%s' % " ".join(cmd)) - raise CompareException(CompareException.INVALID_DATA_ERROR) - - -def save_numpy_data(file_path, data): - """ - save_numpy_data - """ - if not os.path.exists(os.path.dirname(file_path)): - os.makedirs(os.path.dirname(file_path)) - np.save(file_path, data) - - -def parse_arg_value(values): - """ - parse dynamic arg value of atc cmdline - """ - value_list = [] - for item in values.split(Const.SEMICOLON): - value_list.append(parse_value_by_comma(item)) - return value_list - - -def parse_value_by_comma(value): - """ - parse value by comma, like '1,2,4,8' - """ - value_list = [] - value_str_list = value.split(Const.COMMA) - for value_str in value_str_list: - value_str = value_str.strip() - if value_str.isdigit() or value_str == '-1': - value_list.append(int(value_str)) - else: - print_error_log("please check your input shape.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - return value_list - - -def get_data_len_by_shape(shape): - data_len = 1 - for item in shape: - if item == -1: - print_error_log("please check your input shape, one dim in shape is -1.") - return -1 - data_len = data_len * item - return data_len - - -def add_time_as_suffix(name): - return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) - - -def get_time(): - return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") - - -def format_value(value): - return '{:.6f}'.format(value) - - -def torch_device_guard(func): - if IS_GPU or torch_without_guard_version: - return func - # Parse args/kwargs matched torch.device objects - - @torch_npu_device_guard - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - return wrapper - - -def seed_all(seed=1234, mode=False): - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.use_deterministic_algorithms(mode) - if IS_GPU: - torch.cuda.manual_seed_all(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.enable = False - torch.backends.cudnn.benchmark = False - else: - torch_npu.npu.manual_seed_all(seed) - torch_npu.npu.manual_seed(seed) - - -def get_process_rank(model): - print_info_log("Rank id is not provided. Trying to get the rank id of the model.") - try: - device = next(model.parameters()).device - except StopIteration: - print_warn_log('There is no parameter in the model. Fail to get rank id.') - return 0, False - if device.type == 'cpu': - print_warn_log("Warning: the debugger is unable to get the rank id. " - "This may cause the dumpped data to be corrupted in the " - "case of distributed training. (You may ignore this if you are using only one card.) " - "Transfer the model to npu or gpu before register_hook() to avoid this warning.") - return 0, False - else: - return device.index, True - - -def get_json_contents(file_path): - ops = get_file_content_bytes(file_path) - return json.loads(ops) - - -def get_file_content_bytes(file): - check_input_file_valid(file) - with open(file, 'rb') as file_handle: - return file_handle.read() - - -def islink(path): - path = os.path.abspath(path) - return os.path.islink(path) - - -class SoftlinkCheckException(Exception): - pass - - -MAX_JSON_FILE_SIZE = 10 * 1024 ** 2 -LINUX_FILE_NAME_LENGTH_LIMIT = 200 - - -def check_path_length_valid(path): - path = os.path.realpath(path) - return len(os.path.basename(path)) <= LINUX_FILE_NAME_LENGTH_LIMIT - - -def check_path_pattern_valid(path): - pattern = re.compile(r'(\.|/|:|_|-|\s|[~0-9a-zA-Z])+') - if not pattern.fullmatch(path): - raise ValueError('Only the following characters are allowed in the path: A-Z a-z 0-9 - _ . / :') - - -def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): - if islink(input_path): - raise SoftlinkCheckException("Input path doesn't support soft link.") - - input_path = os.path.realpath(input_path) - if not os.path.exists(input_path): - raise ValueError('Input file %s does not exist!' % input_path) - - if not os.access(input_path, os.R_OK): - raise PermissionError('Input file %s is not readable!' % input_path) - - check_path_pattern_valid(input_path) - - if not check_path_length_valid(input_path): - raise ValueError("The real path or file_name of input is too long.") - - if os.path.getsize(input_path) > max_file_size: - raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') - - -def check_need_convert(api_name): - convert_type = None - for key, value in Const.CONVERT_API.items(): - if api_name not in value: - continue - else: - convert_type = key - return convert_type - -def api_info_preprocess(api_name, api_info_dict): - """ - Function Description: - Preprocesses the API information. - Parameter: - api_name: Name of the API. - api_info_dict: argument of the API. - Return api_info_dict: - convert_type: Type of conversion. - api_info_dict: Processed argument of the API. - """ - convert_type = check_need_convert(api_name) - if api_name == 'cross_entropy': - api_info_dict = cross_entropy_process(api_info_dict) - return convert_type, api_info_dict - -def cross_entropy_process(api_info_dict): - """ - Function Description: - Preprocesses the cross_entropy API information. - Parameter: - api_info_dict: argument of the API. - Return api_info_dict: - api_info_dict: Processed argument of the API. - """ - if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: - if api_info_dict['args'][1]['Min'] <= 0: - api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. - return api_info_dict - -def initialize_save_path(save_path, dir_name): - data_path = os.path.join(save_path, dir_name) - if os.path.exists(data_path): - raise ValueError(f"file {data_path} already exists, please remove it first") - else: - os.mkdir(data_path, mode = 0o750) - check_file_or_directory_path(data_path, True) - -def write_pt(file_path, tensor): - if os.path.exists(file_path): - raise ValueError(f"File {file_path} already exists") - torch.save(tensor, file_path) - full_path = os.path.abspath(file_path) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +import collections +import json +import os +import random +import re +import stat +import subprocess +import sys +import time +from datetime import datetime, timezone + +import numpy as np +import torch +import csv + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + +try: + import torch_npu +except ImportError: + IS_GPU = True +else: + IS_GPU = False + +torch_without_guard_version_list = ['2.1'] +for version in torch_without_guard_version_list: + if torch.__version__.startswith(version): + torch_without_guard_version = True + break + else: + torch_without_guard_version = False +if not IS_GPU and not torch_without_guard_version: + from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard + +device = collections.namedtuple('device', ['type', 'index']) + + +class Const: + """ + Class for const + """ + MODEL_TYPE = ['.onnx', '.pb', '.om'] + DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*" + SEMICOLON = ";" + COLON = ":" + EQUAL = "=" + COMMA = "," + DOT = "." + DUMP_RATIO_MAX = 100 + SUMMERY_DATA_NUMS = 256 + ONE_HUNDRED_MB = 100*1024*1024 + FLOAT_EPSILON = np.finfo(float).eps + SUPPORT_DUMP_MODE = ['api', 'acl'] + ON = 'ON' + OFF = 'OFF' + BACKWARD = 'backward' + FORWARD = 'forward' + FLOAT_TYPE = [np.half, np.single, float, np.double, np.float64, np.longdouble, np.float32, np.float16] + BOOL_TYPE = [bool, np.uint8] + INT_TYPE = [np.int32, np.int64] + + # dump mode + ALL = "all" + LIST = "list" + RANGE = "range" + STACK = "stack" + ACL = "acl" + API_LIST = "api_list" + API_STACK = "api_stack" + DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] + + API_PATTERN = r"^[A-Za-z0-9]+[_]+([A-Za-z0-9]+[_]*[A-Za-z0-9]+)[_]+[0-9]+[_]+[A-Za-z0-9]+" + WRITE_FLAGS = os.O_WRONLY | os.O_CREAT + WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR + + RAISE_PRECISION = { + "torch.float16" : "torch.float32", + "torch.bfloat16" : "torch.float32", + "torch.float32" : "torch.float64" + } + CONVERT = { + "int32_to_int64": ["torch.int32", "torch.int64"], + } + + CONVERT_API = { + "int32_to_int64": ["cross_entropy"] + } + +class CompareConst: + """ + Class for compare module const + """ + # compare result column name + NPU_NAME = "NPU Name" + BENCH_NAME = "Bench Name" + NPU_DTYPE = "NPU Tensor Dtype" + BENCH_DTYPE = "Bench Tensor Dtype" + NPU_SHAPE = "NPU Tensor Shape" + BENCH_SHAPE = "Bench Tensor Shape" + NPU_MAX = "NPU max" + NPU_MIN = "NPU min" + NPU_MEAN = "NPU mean" + BENCH_MAX = "Bench max" + BENCH_MIN = "Bench min" + BENCH_MEAN = "Bench mean" + COSINE = "Cosine" + MAX_ABS_ERR = "MaxAbsErr" + ACCURACY = "Accuracy Reached or Not" + STACK = "NPU_Stack_Info" + ERROR_MESSAGE = "Err_message" + + # compare result data + NAN = 'Nan' + SHAPE_UNMATCH = 'shape unmatched' + DTYPE_UNMATCH = 'dtype unmatched' + + # accuracy standards + COS_THRESHOLD = 0.99 + MAX_ABS_ERR_THRESHOLD = 0.001 + COS_MAX_THRESHOLD = 0.9 + MAX_ABS_ERR_MAX_THRESHOLD = 1 + ACCURACY_CHECK_YES = "Yes" + ACCURACY_CHECK_NO = "No" + ACCURACY_CHECK_UNMATCH = "Unmatched" + + # error message + NO_BENCH = "No bench data matched." + + +class VersionCheck: + """ + Class for TorchVersion + """ + V1_8 = "1.8" + V1_11 = "1.11" + + @staticmethod + def check_torch_version(version): + torch_version = torch.__version__ + if torch_version.startswith(version): + return True + else: + return False + + +class CompareException(Exception): + """ + Class for Accuracy Compare Exception + """ + NONE_ERROR = 0 + INVALID_PATH_ERROR = 1 + OPEN_FILE_ERROR = 2 + CLOSE_FILE_ERROR = 3 + READ_FILE_ERROR = 4 + WRITE_FILE_ERROR = 5 + INVALID_FILE_ERROR = 6 + PERMISSION_ERROR = 7 + INDEX_OUT_OF_BOUNDS_ERROR = 8 + NO_DUMP_FILE_ERROR = 9 + INVALID_DATA_ERROR = 10 + INVALID_PARAM_ERROR = 11 + INVALID_DUMP_RATIO = 12 + INVALID_DUMP_FILE = 13 + UNKNOWN_ERROR = 14 + INVALID_DUMP_MODE = 15 + PARSE_FILE_ERROR = 16 + INVALID_COMPARE_MODE = 17 + + def __init__(self, code, error_info: str = ""): + super(CompareException, self).__init__() + self.code = code + self.error_info = error_info + + def __str__(self): + return self.error_info + +class DumpException(CompareException): + pass + +def read_json(file): + with FileOpen(file, 'r') as f: + obj = json.load(f) + return obj + +def write_csv(data, filepath): + with FileOpen(filepath, 'a') as f: + writer = csv.writer(f) + writer.writerows(data) + +def _print_log(level, msg): + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) + pid = os.getgid() + print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg) + sys.stdout.flush() + + +def print_info_log(info_msg): + """ + Function Description: + print info log. + Parameter: + info_msg: the info message. + """ + _print_log("INFO", info_msg) + + +def print_error_log(error_msg): + """ + Function Description: + print error log. + Parameter: + error_msg: the error message. + """ + _print_log("ERROR", error_msg) + + +def print_warn_log(warn_msg): + """ + Function Description: + print warn log. + Parameter: + warn_msg: the warning message. + """ + _print_log("WARNING", warn_msg) + + +def check_mode_valid(mode): + if mode not in Const.DUMP_MODE: + msg = "Current mode '%s' is not supported. Please use the field in %s" % \ + (mode, Const.DUMP_MODE) + raise CompareException(CompareException.INVALID_DUMP_MODE, msg) + + +def check_object_type(check_object, allow_type): + """ + Function Description: + Check if the object belongs to a certain data type + Parameter: + check_object: the object to be checked + allow_type: legal data type + Exception Description: + when invalid data throw exception + """ + if not isinstance(check_object, allow_type): + print_error_log(f"{check_object} not of {allow_type} type") + raise CompareException(CompareException.INVALID_DATA_ERROR) + + +def check_file_or_directory_path(path, isdir=False): + """ + Function Description: + check whether the path is valid + Parameter: + path: the path to check + isdir: the path is dir or file + Exception Description: + when invalid data throw exception + """ + if isdir: + if not os.path.exists(path): + print_error_log('The path {} is not exist.'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + if not os.path.isdir(path): + print_error_log('The path {} is not a directory.'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + if not os.access(path, os.W_OK): + print_error_log( + 'The path {} does not have permission to write. Please check the path permission'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + else: + if not os.path.isfile(path): + print_error_log('{} is an invalid file or non-exist.'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + if not os.access(path, os.R_OK): + print_error_log( + 'The path {} does not have permission to read. Please check the path permission'.format(path)) + raise CompareException(CompareException.INVALID_PATH_ERROR) + +def _check_pkl(pkl_file_handle, file_name): + tensor_line = pkl_file_handle.readline() + if len(tensor_line) == 0: + print_error_log("dump file {} have empty line!".format(file_name)) + raise CompareException(CompareException.INVALID_DUMP_FILE) + pkl_file_handle.seek(0, 0) + + +def check_file_mode(npu_pkl, bench_pkl, stack_mode): + npu_pkl_name = os.path.split(npu_pkl)[-1] + bench_pkl_name = os.path.split(bench_pkl)[-1] + + if not npu_pkl_name.startswith("api_stack") and not bench_pkl_name.startswith("api_stack"): + if stack_mode: + print_error_log("The current file does not contain stack information, please turn off the stack_mode") + raise CompareException(CompareException.INVALID_COMPARE_MODE) + elif npu_pkl_name.startswith("api_stack") and bench_pkl_name.startswith("api_stack"): + if not stack_mode: + print_error_log("The current file contains stack information, please turn on the stack_mode") + raise CompareException(CompareException.INVALID_COMPARE_MODE) + else: + print_error_log("The dump mode of the two files is not same, please check the dump files") + raise CompareException(CompareException.INVALID_COMPARE_MODE) + + +def check_file_size(input_file, max_size): + try: + file_size = os.path.getsize(input_file) + except OSError as os_error: + print_error_log('Failed to open "%s". %s' % (input_file, str(os_error))) + raise CompareException(CompareException.INVALID_FILE_ERROR) + if file_size > max_size: + print_error_log('The size (%d) of %s exceeds (%d) bytes, tools not support.' + % (file_size, input_file, max_size)) + raise CompareException(CompareException.INVALID_FILE_ERROR) + + +def get_dump_data_path(dump_dir): + """ + Function Description: + traverse directories and obtain the absolute path of dump data + Parameter: + dump_dir: dump data directory + Return Value: + dump data path,file is exist or file is not exist + """ + dump_data_path = None + file_is_exist = False + + check_file_or_directory_path(dump_dir, True) + for dir_path, sub_paths, files in os.walk(dump_dir): + if len(files) != 0: + dump_data_path = dir_path + file_is_exist = True + break + dump_data_path = dir_path + return dump_data_path, file_is_exist + + +def get_api_name_from_matcher(name): + api_matcher = re.compile(Const.API_PATTERN) + match = api_matcher.match(name) + return match.group(1) if match else "" + + +def modify_dump_path(dump_path, mode): + if mode == Const.ALL: + return dump_path + file_name = os.path.split(dump_path) + mode_file_name = mode + "_" + file_name[-1] + return os.path.join(file_name[0], mode_file_name) + + +def create_directory(dir_path): + """ + Function Description: + creating a directory with specified permissions + Parameter: + dir_path: directory path + Exception Description: + when invalid data throw exception + """ + if not os.path.exists(dir_path): + try: + os.makedirs(dir_path, mode=0o700) + except OSError as ex: + print_error_log( + 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) + raise CompareException(CompareException.INVALID_PATH_ERROR) + + +def execute_command(cmd): + """ + Function Description: + run the following command + Parameter: + cmd: command + Exception Description: + when invalid command throw exception + """ + print_info_log('Execute command:%s' % cmd) + process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + while process.poll() is None: + line = process.stdout.readline() + line = line.strip() + if line: + print(line) + if process.returncode != 0: + print_error_log('Failed to execute command:%s' % " ".join(cmd)) + raise CompareException(CompareException.INVALID_DATA_ERROR) + + +def save_numpy_data(file_path, data): + """ + save_numpy_data + """ + if not os.path.exists(os.path.dirname(file_path)): + os.makedirs(os.path.dirname(file_path)) + np.save(file_path, data) + + +def parse_arg_value(values): + """ + parse dynamic arg value of atc cmdline + """ + value_list = [] + for item in values.split(Const.SEMICOLON): + value_list.append(parse_value_by_comma(item)) + return value_list + + +def parse_value_by_comma(value): + """ + parse value by comma, like '1,2,4,8' + """ + value_list = [] + value_str_list = value.split(Const.COMMA) + for value_str in value_str_list: + value_str = value_str.strip() + if value_str.isdigit() or value_str == '-1': + value_list.append(int(value_str)) + else: + print_error_log("please check your input shape.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + return value_list + + +def get_data_len_by_shape(shape): + data_len = 1 + for item in shape: + if item == -1: + print_error_log("please check your input shape, one dim in shape is -1.") + return -1 + data_len = data_len * item + return data_len + + +def add_time_as_suffix(name): + return '{}_{}.csv'.format(name, time.strftime("%Y%m%d%H%M%S", time.localtime(time.time()))) + + +def get_time(): + return datetime.now(tz=timezone.utc).strftime("%Y%m%d_%H%M%S") + + +def format_value(value): + return '{:.6f}'.format(value) + + +def torch_device_guard(func): + if IS_GPU or torch_without_guard_version: + return func + # Parse args/kwargs matched torch.device objects + + @torch_npu_device_guard + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + + +def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + if IS_GPU: + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enable = False + torch.backends.cudnn.benchmark = False + else: + torch_npu.npu.manual_seed_all(seed) + torch_npu.npu.manual_seed(seed) + + +def get_process_rank(model): + print_info_log("Rank id is not provided. Trying to get the rank id of the model.") + try: + device = next(model.parameters()).device + except StopIteration: + print_warn_log('There is no parameter in the model. Fail to get rank id.') + return 0, False + if device.type == 'cpu': + print_warn_log("Warning: the debugger is unable to get the rank id. " + "This may cause the dumpped data to be corrupted in the " + "case of distributed training. (You may ignore this if you are using only one card.) " + "Transfer the model to npu or gpu before register_hook() to avoid this warning.") + return 0, False + else: + return device.index, True + + +def get_json_contents(file_path): + ops = get_file_content_bytes(file_path) + return json.loads(ops) + + +def get_file_content_bytes(file): + check_input_file_valid(file) + with FileOpen(file, 'rb') as file_handle: + return file_handle.read() + + +def islink(path): + path = os.path.abspath(path) + return os.path.islink(path) + + +class SoftlinkCheckException(Exception): + pass + + +MAX_JSON_FILE_SIZE = 10 * 1024 ** 2 +LINUX_FILE_NAME_LENGTH_LIMIT = 200 + + +def check_path_length_valid(path): + path = os.path.realpath(path) + return len(os.path.basename(path)) <= LINUX_FILE_NAME_LENGTH_LIMIT + + +def check_path_pattern_valid(path): + pattern = re.compile(r'(\.|/|:|_|-|\s|[~0-9a-zA-Z])+') + if not pattern.fullmatch(path): + raise ValueError('Only the following characters are allowed in the path: A-Z a-z 0-9 - _ . / :') + + +def check_input_file_valid(input_path, max_file_size=MAX_JSON_FILE_SIZE): + if islink(input_path): + raise SoftlinkCheckException("Input path doesn't support soft link.") + + input_path = os.path.realpath(input_path) + if not os.path.exists(input_path): + raise ValueError('Input file %s does not exist!' % input_path) + + if not os.access(input_path, os.R_OK): + raise PermissionError('Input file %s is not readable!' % input_path) + + check_path_pattern_valid(input_path) + + if not check_path_length_valid(input_path): + raise ValueError("The real path or file_name of input is too long.") + + if os.path.getsize(input_path) > max_file_size: + raise ValueError(f'The file is too large, exceeds {max_file_size // 1024 ** 2}MB') + + +def check_need_convert(api_name): + convert_type = None + for key, value in Const.CONVERT_API.items(): + if api_name not in value: + continue + else: + convert_type = key + return convert_type + +def api_info_preprocess(api_name, api_info_dict): + """ + Function Description: + Preprocesses the API information. + Parameter: + api_name: Name of the API. + api_info_dict: argument of the API. + Return api_info_dict: + convert_type: Type of conversion. + api_info_dict: Processed argument of the API. + """ + convert_type = check_need_convert(api_name) + if api_name == 'cross_entropy': + api_info_dict = cross_entropy_process(api_info_dict) + return convert_type, api_info_dict + +def cross_entropy_process(api_info_dict): + """ + Function Description: + Preprocesses the cross_entropy API information. + Parameter: + api_info_dict: argument of the API. + Return api_info_dict: + api_info_dict: Processed argument of the API. + """ + if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: + if api_info_dict['args'][1]['Min'] <= 0: + api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. + return api_info_dict + +def initialize_save_path(save_path, dir_name): + data_path = os.path.join(save_path, dir_name) + if os.path.exists(data_path): + raise ValueError(f"file {data_path} already exists, please remove it first") + else: + os.mkdir(data_path, mode = 0o750) + check_file_or_directory_path(data_path, True) + +def write_pt(file_path, tensor): + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + torch.save(tensor, file_path) + full_path = os.path.abspath(file_path) return full_path \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index f2a96bd0fa..bb3da868fd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -7,6 +7,8 @@ from .api_info import ForwardAPIInfo, BackwardAPIInfo from ..common.utils import check_file_or_directory_path, initialize_save_path from ..common.config import msCheckerConfig +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + lock = threading.Lock() def write_api_info_json(api_info): @@ -27,10 +29,10 @@ def write_api_info_json(api_info): def write_json(file_path, data, indent=None): check_file_or_directory_path(os.path.dirname(file_path),True) if not os.path.exists(file_path): - with open(file_path, 'w') as f: + with FileOpen(file_path, 'w') as f: f.write("{\n}") lock.acquire() - with open(file_path, 'a+') as f: + with FileOpen(file_path, 'a+') as f: fcntl.flock(f, fcntl.LOCK_EX) try: f.seek(0, os.SEEK_END) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py index dbe27a134c..1900686265 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_functional.py @@ -1,67 +1,69 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import os - -import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with open(yaml_path, 'r') as f: - WrapFunctionalOps = yaml.safe_load(f).get('functional') - -for f in dir(torch.nn.functional): - locals().update({f: getattr(torch.nn.functional, f)}) - - -def get_functional_ops(): - global WrapFunctionalOps - _all_functional_ops = dir(torch.nn.functional) - return set(WrapFunctionalOps) & set(_all_functional_ops) - - -class HOOKFunctionalOP(object): - pass - - -class FunctionalOPTemplate(HOOKModule): - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Functional*" + str(op_name) + "*" - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return eval(self.op_name_)(*args, **kwargs) - - -def wrap_functional_op(op_name, hook): - def functional_op_template(*args, **kwargs): - return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) - - return functional_op_template - - -def wrap_functional_ops_and_bind(hook): - _functional_ops = get_functional_ops() - for op_name in _functional_ops: - setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os + +import torch +import yaml + +from api_accuracy_checker.hook_module.hook_module import HOOKModule +from api_accuracy_checker.common.utils import torch_device_guard + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapFunctionalOps = yaml.safe_load(f).get('functional') + +for f in dir(torch.nn.functional): + locals().update({f: getattr(torch.nn.functional, f)}) + + +def get_functional_ops(): + global WrapFunctionalOps + _all_functional_ops = dir(torch.nn.functional) + return set(WrapFunctionalOps) & set(_all_functional_ops) + + +class HOOKFunctionalOP(object): + pass + + +class FunctionalOPTemplate(HOOKModule): + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "Functional*" + str(op_name) + "*" + if need_hook: + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return eval(self.op_name_)(*args, **kwargs) + + +def wrap_functional_op(op_name, hook): + def functional_op_template(*args, **kwargs): + return FunctionalOPTemplate(op_name, hook)(*args, **kwargs) + + return functional_op_template + + +def wrap_functional_ops_and_bind(hook): + _functional_ops = get_functional_ops() + for op_name in _functional_ops: + setattr(HOOKFunctionalOP, "wrap_" + op_name, wrap_functional_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py index 93d92923c6..490c2a08d7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_tensor.py @@ -1,66 +1,68 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import os - -import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with open(yaml_path, 'r') as f: - WrapTensorOps = yaml.safe_load(f).get('tensor') - - -def get_tensor_ops(): - global WrapTensorOps - _tensor_ops = dir(torch._C._TensorBase) - return set(WrapTensorOps) & set(_tensor_ops) - - -class HOOKTensor(object): - pass - - -class TensorOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Tensor*" + str(op_name) + "*" - if need_hook: - super().__init__(hook) - - @torch_device_guard - def forward(self, *args, **kwargs): - return getattr(torch._C._TensorBase, str(self.op_name_))(*args, **kwargs) - - -def wrap_tensor_op(op_name, hook): - - def tensor_op_template(*args, **kwargs): - return TensorOPTemplate(op_name, hook)(*args, **kwargs) - - return tensor_op_template - - -def wrap_tensor_ops_and_bind(hook): - _tensor_ops = get_tensor_ops() - for op_name in _tensor_ops: - setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os + +import torch +import yaml + +from api_accuracy_checker.hook_module.hook_module import HOOKModule +from api_accuracy_checker.common.utils import torch_device_guard + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapTensorOps = yaml.safe_load(f).get('tensor') + + +def get_tensor_ops(): + global WrapTensorOps + _tensor_ops = dir(torch._C._TensorBase) + return set(WrapTensorOps) & set(_tensor_ops) + + +class HOOKTensor(object): + pass + + +class TensorOPTemplate(HOOKModule): + + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "Tensor*" + str(op_name) + "*" + if need_hook: + super().__init__(hook) + + @torch_device_guard + def forward(self, *args, **kwargs): + return getattr(torch._C._TensorBase, str(self.op_name_))(*args, **kwargs) + + +def wrap_tensor_op(op_name, hook): + + def tensor_op_template(*args, **kwargs): + return TensorOPTemplate(op_name, hook)(*args, **kwargs) + + return tensor_op_template + + +def wrap_tensor_ops_and_bind(hook): + _tensor_ops = get_tensor_ops() + for op_name in _tensor_ops: + setattr(HOOKTensor, "wrap_" + str(op_name), wrap_tensor_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py index 07a037b779..1a0bc1c71f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/api_accuracy_checker/hook_module/wrap_torch.py @@ -1,109 +1,111 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import os - -import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard - -cur_path = os.path.dirname(os.path.realpath(__file__)) -yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") -with open(yaml_path, 'r') as f: - WrapTorchOps = yaml.safe_load(f).get('torch') - - -def get_torch_ops(): - global WrapTorchOps - _torch_ops = dir(torch._C._VariableFunctionsClass) - return set(WrapTorchOps) & set(_torch_ops) - - -class HOOKTorchOP(object): - pass - - -class TorchOPTemplate(HOOKModule): - - def __init__(self, op_name, hook, need_hook=True): - self.op_name_ = op_name - self.prefix_op_name_ = "Torch*" + str(op_name) + "*" - if need_hook: - super().__init__(hook) - - def input_param_need_adapt(self): - special_op_list = ["broadcast_tensors"] - for item in special_op_list: - if item in self.op_name_: - return True - return False - - def einsum_adapt(self, *args): - if len(args) < 2: - raise ValueError('einsum(): must specify the equation string and at least one operand, ' - 'or at least one operand and its subscripts list') - equation = None - operands = None - if isinstance(args[0], torch.Tensor): - def parse_subscript(n: int) -> str: - if n == Ellipsis: - return '...' - if n >= 0 and n < 26: - return chr(ord('A') + n) - if n >= 26 and n < 52: - return chr(ord('a') + n - 26) - raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52]') - equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2]) - - if len(args) % 2 == 1: - equation += '->' + ''.join(parse_subscript(s) for s in args[-1]) - operands = args[:-1:2] - else: - operands = args[::2] - else: - equation = args[0] - operands = args[1:] - - if len(operands) == 1 and isinstance(operands[0], (list, tuple)): - _operands = operands[0] - return self.einsum_adapt(equation, *_operands) - return equation, operands - - @torch_device_guard - def forward(self, *args, **kwargs): - if self.input_param_need_adapt(): - return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(args, **kwargs) - else: - if self.op_name_ == 'einsum': - args = self.einsum_adapt(*args) - return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) - - -def wrap_torch_op(op_name, hook): - - def torch_op_template(*args, **kwargs): - return TorchOPTemplate(op_name, hook)(*args, **kwargs) - - return torch_op_template - - -def wrap_torch_ops_and_bind(hook): - _torch_ops = get_torch_ops() - for op_name in _torch_ops: - setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" + +import os + +import torch +import yaml + +from api_accuracy_checker.hook_module.hook_module import HOOKModule +from api_accuracy_checker.common.utils import torch_device_guard + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen + +cur_path = os.path.dirname(os.path.realpath(__file__)) +yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") +with FileOpen(yaml_path, 'r') as f: + WrapTorchOps = yaml.safe_load(f).get('torch') + + +def get_torch_ops(): + global WrapTorchOps + _torch_ops = dir(torch._C._VariableFunctionsClass) + return set(WrapTorchOps) & set(_torch_ops) + + +class HOOKTorchOP(object): + pass + + +class TorchOPTemplate(HOOKModule): + + def __init__(self, op_name, hook, need_hook=True): + self.op_name_ = op_name + self.prefix_op_name_ = "Torch*" + str(op_name) + "*" + if need_hook: + super().__init__(hook) + + def input_param_need_adapt(self): + special_op_list = ["broadcast_tensors"] + for item in special_op_list: + if item in self.op_name_: + return True + return False + + def einsum_adapt(self, *args): + if len(args) < 2: + raise ValueError('einsum(): must specify the equation string and at least one operand, ' + 'or at least one operand and its subscripts list') + equation = None + operands = None + if isinstance(args[0], torch.Tensor): + def parse_subscript(n: int) -> str: + if n == Ellipsis: + return '...' + if n >= 0 and n < 26: + return chr(ord('A') + n) + if n >= 26 and n < 52: + return chr(ord('a') + n - 26) + raise ValueError('einsum(): subscript in subscript list is not within the valid range [0, 52]') + equation = ','.join(''.join(parse_subscript(s) for s in l) for l in args[1::2]) + + if len(args) % 2 == 1: + equation += '->' + ''.join(parse_subscript(s) for s in args[-1]) + operands = args[:-1:2] + else: + operands = args[::2] + else: + equation = args[0] + operands = args[1:] + + if len(operands) == 1 and isinstance(operands[0], (list, tuple)): + _operands = operands[0] + return self.einsum_adapt(equation, *_operands) + return equation, operands + + @torch_device_guard + def forward(self, *args, **kwargs): + if self.input_param_need_adapt(): + return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(args, **kwargs) + else: + if self.op_name_ == 'einsum': + args = self.einsum_adapt(*args) + return getattr(torch._C._VariableFunctionsClass, str(self.op_name_))(*args, **kwargs) + + +def wrap_torch_op(op_name, hook): + + def torch_op_template(*args, **kwargs): + return TorchOPTemplate(op_name, hook)(*args, **kwargs) + + return torch_op_template + + +def wrap_torch_ops_and_bind(hook): + _torch_ops = get_torch_ops() + for op_name in _torch_ops: + setattr(HOOKTorchOP, "wrap_" + op_name, wrap_torch_op(op_name, hook)) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 236726b3db..b6272ce0f1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -1,277 +1,277 @@ -import argparse -import os -import copy -import sys -import torch_npu -import yaml -import torch -from tqdm import tqdm -from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ - print_error_log, check_file_or_directory_path, initialize_save_path, Const -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate -from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate -from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from ut_api_info import UtAPIInfo -from api_accuracy_checker.common.config import msCheckerConfig - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileChecker, FileCheckConst - -NO_GRAD_APIS = ["hardtanh"] - - -def init_environment(): - cur_path = os.path.dirname(os.path.realpath(__file__)) - yaml_path = os.path.join(cur_path, "../hook_module/support_wrap_ops.yaml") - yaml_path_checker = FileChecker(yaml_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE) - yaml_path = yaml_path_checker.common_check() - with open(yaml_path, 'r') as f: - WrapFunctionalOps = yaml.safe_load(f).get('functional') - for f in dir(torch.nn.functional): - if f != "__name__": - locals().update({f: getattr(torch.nn.functional, f)}) - - -init_environment() - - -def exec_api(api_type, api_name, args, kwargs): - if api_type == "Functional": - functional_api = FunctionalOPTemplate(api_name, str, False) - out = functional_api.forward(*args, **kwargs) - if api_type == "Tensor": - tensor_api = TensorOPTemplate(api_name, str, False) - out = tensor_api.forward(*args, **kwargs) - if api_type == "Torch": - torch_api = TorchOPTemplate(api_name, str, False) - out = torch_api.forward(*args, **kwargs) - return out - - -def generate_npu_params(input_args, input_kwargs, need_backward): - def recursive_arg_to_npu(arg_in): - if isinstance(arg_in, (list, tuple)): - return type(arg_in)(recursive_arg_to_npu(arg) for arg in arg_in) - elif isinstance(arg_in, torch.Tensor): - if need_backward and arg_in.requires_grad: - arg_in = arg_in.clone().detach().to("npu").requires_grad_() - temp_arg_in = arg_in * 1 - arg_in = temp_arg_in.type_as(arg_in) - arg_in.retain_grad() - return arg_in - else: - return arg_in.clone().detach().to("npu") - else: - return arg_in - - npu_args = recursive_arg_to_npu(input_args) - npu_kwargs = {key: recursive_arg_to_npu(value) for key, value in input_kwargs.items()} - return npu_args, npu_kwargs - -def generate_cpu_params(input_args, input_kwargs, need_backward): - first_dtype = None - def recursive_arg_to_cpu(arg_in): - nonlocal first_dtype - if isinstance(arg_in, (list, tuple)): - return type(arg_in)(recursive_arg_to_cpu(arg) for arg in arg_in) - elif isinstance(arg_in, torch.Tensor): - if need_backward and arg_in.requires_grad: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach().requires_grad_() - if first_dtype is None: - first_dtype = arg_in.dtype - else: - arg_in = arg_in.clone().detach().requires_grad_() - temp_arg_in = arg_in * 1 - arg_in = temp_arg_in.type_as(arg_in) - arg_in.retain_grad() - return arg_in - else: - if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: - arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() - if first_dtype is None: - first_dtype = arg_in.dtype - return arg_in - return arg_in.clone().detach() - else: - return arg_in - - cpu_args = recursive_arg_to_cpu(input_args) - cpu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()} - return cpu_args, cpu_kwargs - -def run_ut(forward_file, backward_file, out_path, save_error_data): - print_info_log("start UT test") - forward_content = get_json_contents(forward_file) - backward_content = get_json_contents(backward_file) - api_setting_dict = get_json_contents("torch_ut_setting.json") - compare = Comparator(out_path) - for api_full_name, api_info_dict in tqdm(forward_content.items()): - try: - data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, - data_info.bench_out, - data_info.npu_out, - data_info.bench_grad_out, - data_info.npu_grad_out) - if save_error_data: - do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) - except Exception as err: - [_, api_name, _] = api_full_name.split("*") - if "expected scalar type Long" in str(err): - print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " - f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") - else: - print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) - compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err))) - compare.print_pretest_result() - - -def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): - if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") - for element in data_info.in_fwd_data_list: - UtAPIInfo(api_full_name + '.forward.input', element) - UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) - UtAPIInfo(api_full_name + '.forward.output.npu', data_info.npu_out) - UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) - UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) - UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out) - - -def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): - in_fwd_data_list = [] - [api_type, api_name, _] = api_full_name.split("*") - args, kwargs, need_grad = get_api_info(api_info_dict, api_name) - in_fwd_data_list.append(args) - in_fwd_data_list.append(kwargs) - need_backward = api_full_name in backward_content and api_name[-1] != "_" - need_backward = need_backward and need_grad - if not need_grad: - print_warn_log("%s involves in-place operations, skip backward" % api_full_name) - cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) - npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) - grad_out, npu_grad_out = None, None - if kwargs.get("device"): - del kwargs["device"] - out = exec_api(api_type, api_name, cpu_args, cpu_kwargs) - npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) - grad_input_index = api_setting_dict.get(api_name) - grad_index = None - grad = None - if grad_input_index is not None: - grad_index = grad_input_index.get('grad_index') - - if need_backward: - grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, cpu_args, backward_content, grad_index, npu_args, - npu_out, out) - if grad_index is not None: - return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], grad, in_fwd_data_list) - return UtDataInfo(grad_out, npu_grad_out, npu_out, out, grad, in_fwd_data_list) - - -def get_api_info(api_info_dict, api_name): - convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) - need_grad = True - if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): - need_grad = False - if api_name[-1] == "_" or api_name in NO_GRAD_APIS: - need_grad = False - args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) - return args, kwargs, need_grad - - -def run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out): - backward_args = backward_content[api_full_name] - grad = gen_args(backward_args)[0] - cpu_grad, _ = generate_cpu_params(grad, {}, False) - if grad_index is not None: - out[grad_index].backward(cpu_grad) - elif isinstance(out, (list, tuple)): - raise NotImplementedError("Multiple backward is not supported.") - else: - out.backward(cpu_grad) - args_grad = [] - for arg in args: - if isinstance(arg, torch.Tensor): - args_grad.append(arg.grad) - grad_out = args_grad - npu_grad = grad.clone().detach().npu() - if grad_index is not None: - npu_out[grad_index].backward(npu_grad) - else: - npu_out.backward(npu_grad) - npu_args_grad = [] - for arg in npu_args: - if isinstance(arg, torch.Tensor): - npu_args_grad.append(arg.grad) - npu_grad_out = npu_args_grad - return grad_out, npu_grad_out, grad, npu_grad - - -def initialize_save_error_data(): - error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, - ability=FileCheckConst.WRITE_ABLE) - error_data_path = error_data_path_checker.common_check() - initialize_save_path(error_data_path, 'ut_error_data') - - -def _run_ut_parser(parser): - parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", type=str, - help=" The api param tool forward result file: generate from api param tool, " - "a json file.", - required=True) - parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", type=str, - help=" The api param tool backward result file: generate from api param tool, " - "a json file.", - required=True) - parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, - help=" The ut task result out path.", - required=False) - parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", - help=" Save compare failed api output.", required=False) - parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", - help=" whether to turn on jit compile", required=False) - parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set NPU device id to run ut", - default=0, required=False) - - -def _run_ut(): - parser = argparse.ArgumentParser() - _run_ut_parser(parser) - args = parser.parse_args(sys.argv[1:]) - torch.npu.set_compile_mode(jit_compile=args.jit_compile) - npu_device = "npu:" + str(args.device_id) - try: - torch.npu.set_device(npu_device) - except Exception: - print_error_log(f"Set NPU device id failed. device id is: {args.device_id}") - raise NotImplementedError - forward_file_checker = FileChecker(args.forward_input_file, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE, - file_type=FileCheckConst.JSON_SUFFIX) - forward_file = forward_file_checker.common_check() - backward_file_checker = FileChecker(args.backward_input_file, FileCheckConst.FILE, - ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) - backward_file = backward_file_checker.common_check() - out_path = os.path.realpath(args.out_path) if args.out_path else "./" - out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) - out_path = out_path_checker.common_check() - save_error_data = args.save_error_data - if save_error_data: - initialize_save_error_data() - run_ut(forward_file, backward_file, out_path, save_error_data) - - -class UtDataInfo: - def __init__(self, bench_grad_out, npu_grad_out, npu_out, bench_out, grad_in, in_fwd_data_list): - self.bench_grad_out = bench_grad_out - self.npu_grad_out = npu_grad_out - self.npu_out = npu_out - self.bench_out = bench_out - self.grad_in = grad_in - self.in_fwd_data_list = in_fwd_data_list - -if __name__ == '__main__': - _run_ut() - print_info_log("UT task completed.") +import argparse +import os +import copy +import sys +import torch_npu +import yaml +import torch +from tqdm import tqdm +from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args +from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ + print_error_log, check_file_or_directory_path, initialize_save_path, Const +from api_accuracy_checker.compare.compare import Comparator +from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate +from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate +from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate +from ut_api_info import UtAPIInfo +from api_accuracy_checker.common.config import msCheckerConfig + +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileChecker, FileCheckConst, FileOpen + +NO_GRAD_APIS = ["hardtanh"] + + +def init_environment(): + cur_path = os.path.dirname(os.path.realpath(__file__)) + yaml_path = os.path.join(cur_path, "../hook_module/support_wrap_ops.yaml") + yaml_path_checker = FileChecker(yaml_path, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE) + yaml_path = yaml_path_checker.common_check() + with FileOpen(yaml_path, 'r') as f: + WrapFunctionalOps = yaml.safe_load(f).get('functional') + for f in dir(torch.nn.functional): + if f != "__name__": + locals().update({f: getattr(torch.nn.functional, f)}) + + +init_environment() + + +def exec_api(api_type, api_name, args, kwargs): + if api_type == "Functional": + functional_api = FunctionalOPTemplate(api_name, str, False) + out = functional_api.forward(*args, **kwargs) + if api_type == "Tensor": + tensor_api = TensorOPTemplate(api_name, str, False) + out = tensor_api.forward(*args, **kwargs) + if api_type == "Torch": + torch_api = TorchOPTemplate(api_name, str, False) + out = torch_api.forward(*args, **kwargs) + return out + + +def generate_npu_params(input_args, input_kwargs, need_backward): + def recursive_arg_to_npu(arg_in): + if isinstance(arg_in, (list, tuple)): + return type(arg_in)(recursive_arg_to_npu(arg) for arg in arg_in) + elif isinstance(arg_in, torch.Tensor): + if need_backward and arg_in.requires_grad: + arg_in = arg_in.clone().detach().to("npu").requires_grad_() + temp_arg_in = arg_in * 1 + arg_in = temp_arg_in.type_as(arg_in) + arg_in.retain_grad() + return arg_in + else: + return arg_in.clone().detach().to("npu") + else: + return arg_in + + npu_args = recursive_arg_to_npu(input_args) + npu_kwargs = {key: recursive_arg_to_npu(value) for key, value in input_kwargs.items()} + return npu_args, npu_kwargs + +def generate_cpu_params(input_args, input_kwargs, need_backward): + first_dtype = None + def recursive_arg_to_cpu(arg_in): + nonlocal first_dtype + if isinstance(arg_in, (list, tuple)): + return type(arg_in)(recursive_arg_to_cpu(arg) for arg in arg_in) + elif isinstance(arg_in, torch.Tensor): + if need_backward and arg_in.requires_grad: + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: + arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach().requires_grad_() + if first_dtype is None: + first_dtype = arg_in.dtype + else: + arg_in = arg_in.clone().detach().requires_grad_() + temp_arg_in = arg_in * 1 + arg_in = temp_arg_in.type_as(arg_in) + arg_in.retain_grad() + return arg_in + else: + if str(arg_in.dtype) in Const.RAISE_PRECISION.keys() and arg_in.dtype != first_dtype: + arg_in = arg_in.clone().type(eval(Const.RAISE_PRECISION[str(arg_in.dtype)])).detach() + if first_dtype is None: + first_dtype = arg_in.dtype + return arg_in + return arg_in.clone().detach() + else: + return arg_in + + cpu_args = recursive_arg_to_cpu(input_args) + cpu_kwargs = {key: recursive_arg_to_cpu(value) for key, value in input_kwargs.items()} + return cpu_args, cpu_kwargs + +def run_ut(forward_file, backward_file, out_path, save_error_data): + print_info_log("start UT test") + forward_content = get_json_contents(forward_file) + backward_content = get_json_contents(backward_file) + api_setting_dict = get_json_contents("torch_ut_setting.json") + compare = Comparator(out_path) + for api_full_name, api_info_dict in tqdm(forward_content.items()): + try: + data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, + data_info.bench_out, + data_info.npu_out, + data_info.bench_grad_out, + data_info.npu_grad_out) + if save_error_data: + do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) + except Exception as err: + [_, api_name, _] = api_full_name.split("*") + if "expected scalar type Long" in str(err): + print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " + f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") + else: + print_error_log(f"Run {api_full_name} UT Error: %s" % str(err)) + compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err))) + compare.print_pretest_result() + + +def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): + if not is_fwd_success or not is_bwd_success: + api_full_name = api_full_name.replace("*", ".") + for element in data_info.in_fwd_data_list: + UtAPIInfo(api_full_name + '.forward.input', element) + UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) + UtAPIInfo(api_full_name + '.forward.output.npu', data_info.npu_out) + UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) + UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) + UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out) + + +def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): + in_fwd_data_list = [] + [api_type, api_name, _] = api_full_name.split("*") + args, kwargs, need_grad = get_api_info(api_info_dict, api_name) + in_fwd_data_list.append(args) + in_fwd_data_list.append(kwargs) + need_backward = api_full_name in backward_content and api_name[-1] != "_" + need_backward = need_backward and need_grad + if not need_grad: + print_warn_log("%s involves in-place operations, skip backward" % api_full_name) + cpu_args, cpu_kwargs = generate_cpu_params(args, kwargs, need_backward) + npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) + grad_out, npu_grad_out = None, None + if kwargs.get("device"): + del kwargs["device"] + out = exec_api(api_type, api_name, cpu_args, cpu_kwargs) + npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) + grad_input_index = api_setting_dict.get(api_name) + grad_index = None + grad = None + if grad_input_index is not None: + grad_index = grad_input_index.get('grad_index') + + if need_backward: + grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, cpu_args, backward_content, grad_index, npu_args, + npu_out, out) + if grad_index is not None: + return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], grad, in_fwd_data_list) + return UtDataInfo(grad_out, npu_grad_out, npu_out, out, grad, in_fwd_data_list) + + +def get_api_info(api_info_dict, api_name): + convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) + need_grad = True + if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): + need_grad = False + if api_name[-1] == "_" or api_name in NO_GRAD_APIS: + need_grad = False + args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type) + return args, kwargs, need_grad + + +def run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out): + backward_args = backward_content[api_full_name] + grad = gen_args(backward_args)[0] + cpu_grad, _ = generate_cpu_params(grad, {}, False) + if grad_index is not None: + out[grad_index].backward(cpu_grad) + elif isinstance(out, (list, tuple)): + raise NotImplementedError("Multiple backward is not supported.") + else: + out.backward(cpu_grad) + args_grad = [] + for arg in args: + if isinstance(arg, torch.Tensor): + args_grad.append(arg.grad) + grad_out = args_grad + npu_grad = grad.clone().detach().npu() + if grad_index is not None: + npu_out[grad_index].backward(npu_grad) + else: + npu_out.backward(npu_grad) + npu_args_grad = [] + for arg in npu_args: + if isinstance(arg, torch.Tensor): + npu_args_grad.append(arg.grad) + npu_grad_out = npu_args_grad + return grad_out, npu_grad_out, grad, npu_grad + + +def initialize_save_error_data(): + error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, + ability=FileCheckConst.WRITE_ABLE) + error_data_path = error_data_path_checker.common_check() + initialize_save_path(error_data_path, 'ut_error_data') + + +def _run_ut_parser(parser): + parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", type=str, + help=" The api param tool forward result file: generate from api param tool, " + "a json file.", + required=True) + parser.add_argument("-backward", "--backward_input_file", dest="backward_input_file", default="", type=str, + help=" The api param tool backward result file: generate from api param tool, " + "a json file.", + required=True) + parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, + help=" The ut task result out path.", + required=False) + parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", + help=" Save compare failed api output.", required=False) + parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", + help=" whether to turn on jit compile", required=False) + parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set NPU device id to run ut", + default=0, required=False) + + +def _run_ut(): + parser = argparse.ArgumentParser() + _run_ut_parser(parser) + args = parser.parse_args(sys.argv[1:]) + torch.npu.set_compile_mode(jit_compile=args.jit_compile) + npu_device = "npu:" + str(args.device_id) + try: + torch.npu.set_device(npu_device) + except Exception: + print_error_log(f"Set NPU device id failed. device id is: {args.device_id}") + raise NotImplementedError + forward_file_checker = FileChecker(args.forward_input_file, FileCheckConst.FILE, ability=FileCheckConst.READ_ABLE, + file_type=FileCheckConst.JSON_SUFFIX) + forward_file = forward_file_checker.common_check() + backward_file_checker = FileChecker(args.backward_input_file, FileCheckConst.FILE, + ability=FileCheckConst.READ_ABLE, file_type=FileCheckConst.JSON_SUFFIX) + backward_file = backward_file_checker.common_check() + out_path = os.path.realpath(args.out_path) if args.out_path else "./" + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + save_error_data = args.save_error_data + if save_error_data: + initialize_save_error_data() + run_ut(forward_file, backward_file, out_path, save_error_data) + + +class UtDataInfo: + def __init__(self, bench_grad_out, npu_grad_out, npu_out, bench_out, grad_in, in_fwd_data_list): + self.bench_grad_out = bench_grad_out + self.npu_grad_out = npu_grad_out + self.npu_out = npu_out + self.bench_out = bench_out + self.grad_in = grad_in + self.in_fwd_data_list = in_fwd_data_list + +if __name__ == '__main__': + _run_ut() + print_info_log("UT task completed.") diff --git a/debug/accuracy_tools/ptdbg_ascend/__init__.py b/debug/accuracy_tools/ptdbg_ascend/__init__.py new file mode 100644 index 0000000000..6e8d69a7ab --- /dev/null +++ b/debug/accuracy_tools/ptdbg_ascend/__init__.py @@ -0,0 +1,17 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +""" +# Copyright (C) 2022-2023. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +""" \ No newline at end of file diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py index a2c257256a..aeac1b0149 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/file_check_util.py @@ -108,6 +108,57 @@ class FileChecker: check_path_writability(self.file_path) +class FileOpen: + """ + The class for open file by a safe way. + + Attributes: + file_path: The file or dictionary path to be opened. + mode(str): The file open mode + """ + SUPPORT_READ_MODE = ["r", "rb"] + SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] + SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] + + def __init__(self, file_path, mode): + self.file_path = file_path + self.mode = mode + self._handle = None + + def __enter__(self): + self.check_file_path() + self._handle = open(self.file_path, self.mode) + return self._handle + + def __exit__(self, exc_type, exc_val, exc_tb): + if self._handle: + self._handle.close() + + def check_file_path(self): + support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE + if self.mode not in support_mode: + print_error_log("File open not support %s mode" % self.mode) + check_link(self.file_path) + check_path_length(self.file_path) + self.check_ability_and_owner() + check_path_pattern_vaild(self.file_path) + if os.path.exists(self.file_path): + check_common_file_size(self.file_path) + + def check_ability_and_owner(self): + if self.mode in self.SUPPORT_READ_MODE: + check_path_exists(self.file_path) + check_path_readability(self.file_path) + check_path_owner_consistent(self.file_path) + if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path): + check_path_writability(self.file_path) + check_path_owner_consistent(self.file_path) + if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path): + check_path_readability(self.file_path) + check_path_writability(self.file_path) + check_path_owner_consistent(self.file_path) + + def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): -- Gitee