From 6159c86a7142e66dd0e8a193d3ff873134c3ed6a Mon Sep 17 00:00:00 2001 From: l30036321 Date: Thu, 11 Jul 2024 14:54:29 +0800 Subject: [PATCH] code refactoring for data collect --- .../{pytorch => core}/common/exceptions.py | 0 .../{pytorch => core}/common/file_check.py | 34 +- debug/accuracy_tools/atat/core/common/log.py | 52 ++ .../atat/core/{ => common}/utils.py | 76 +-- .../accuracy_tools/atat/core/common_config.py | 35 +- .../data_dump}/data_collector.py | 115 +--- .../core/data_dump/data_processor/base.py | 262 +++++++++ .../core/data_dump/data_processor/factory.py | 65 +++ .../data_processor/mindspore_processor.py | 120 ++++ .../data_processor/pytorch_processor.py} | 529 ++++++------------ .../data_dump}/json_writer.py | 37 +- .../functional => core/data_dump}/scope.py | 3 +- .../atat/core/file_check_util.py | 319 ----------- debug/accuracy_tools/atat/core/log.py | 56 -- .../atat/mindspore/dump/api_kbk_dump.py | 8 +- .../atat/mindspore/dump/kernel_graph_dump.py | 8 +- .../atat/mindspore/ms_config.py | 2 +- .../kernel_graph_overflow_check.py | 10 +- .../atat/pytorch/advisor/advisor.py | 15 +- .../atat/pytorch/advisor/advisor_result.py | 13 +- .../api_accuracy_checker/common/config.py | 2 +- .../api_accuracy_checker/common/utils.py | 24 +- .../compare/api_precision_compare.py | 24 +- .../api_accuracy_checker/compare/compare.py | 7 +- .../compare/compare_utils.py | 7 +- .../api_accuracy_checker/dump/api_info.py | 0 .../run_ut/data_generate.py | 12 +- .../run_ut/multi_run_ut.py | 32 +- .../run_ut/run_overflow_check.py | 24 +- .../api_accuracy_checker/run_ut/run_ut.py | 28 +- .../api_accuracy_checker/test/run_ut.py | 0 .../test/ut/run_ut/test_data_generate.py | 0 .../atat/pytorch/common/__init__.py | 2 - .../accuracy_tools/atat/pytorch/common/log.py | 78 +-- .../atat/pytorch/common/parse_json.py | 2 +- .../atat/pytorch/common/recursive.py | 31 - .../atat/pytorch/common/utils.py | 3 +- .../atat/pytorch/compare/acc_compare.py | 54 +- .../pytorch/compare/distributed_compare.py | 15 +- .../atat/pytorch/compare/highlight.py | 2 +- .../atat/pytorch/compare/match.py | 4 +- .../atat/pytorch/compare/npy_compare.py | 5 +- .../atat/pytorch/debugger/debugger_config.py | 14 +- .../pytorch/debugger/precision_debugger.py | 10 +- .../atat/pytorch/free_benchmark/__init__.py | 4 +- .../pytorch/free_benchmark/common/params.py | 5 +- .../free_benchmark/compare/grad_saver.py | 10 +- .../compare/single_benchmark.py | 4 +- .../atat/pytorch/free_benchmark/main.py | 8 +- .../perturbed_layers/npu/add_noise.py | 15 +- .../perturbed_layers/npu/bit_noise.py | 15 +- .../perturbed_layers/npu/change_value.py | 6 +- .../perturbed_layers/npu/improve_precision.py | 4 +- .../perturbed_layers/npu/no_change.py | 4 +- .../perturbed_layers/run_cpu.py | 4 +- .../result_handlers/base_handler.py | 8 +- .../result_handlers/check_handler.py | 5 +- .../result_handlers/fix_handler.py | 4 +- .../result_handlers/handler_factory.py | 1 - .../result_handlers/preheat_handler.py | 8 +- .../atat/pytorch/functional/__init__.py | 4 - .../atat/pytorch/functional/dump_module.py | 13 +- .../atat/pytorch/functional/repair.py | 90 --- .../pytorch/functional/step_post_process.py | 43 -- .../atat/pytorch/hook_module/hook_module.py | 1 + .../atat/pytorch/hook_module/utils.py | 2 +- .../atat/pytorch/hook_module/wrap_aten.py | 2 +- .../pytorch/hook_module/wrap_distributed.py | 2 +- .../pytorch/hook_module/wrap_functional.py | 6 +- .../pytorch/hook_module/wrap_npu_custom.py | 2 +- .../atat/pytorch/hook_module/wrap_tensor.py | 2 +- 
.../atat/pytorch/hook_module/wrap_torch.py | 2 +- .../atat/pytorch/hook_module/wrap_vf.py | 2 +- .../atat/pytorch/module_processer.py | 2 +- .../accuracy_tools/atat/pytorch/pt_config.py | 5 +- debug/accuracy_tools/atat/pytorch/service.py | 64 +-- .../atat/test/core_ut/test_utils.py | 6 +- .../atat/test/mindspore_ut/test_ms_config.py | 2 +- .../atat/test/pytorch_ut/test_pt_config.py | 2 +- debug/accuracy_tools/atat/test/run_ut.py | 12 +- 80 files changed, 1080 insertions(+), 1428 deletions(-) rename debug/accuracy_tools/atat/{pytorch => core}/common/exceptions.py (100%) rename debug/accuracy_tools/atat/{pytorch => core}/common/file_check.py (89%) create mode 100644 debug/accuracy_tools/atat/core/common/log.py rename debug/accuracy_tools/atat/core/{ => common}/utils.py (89%) rename debug/accuracy_tools/atat/{pytorch/functional => core/data_dump}/data_collector.py (56%) create mode 100644 debug/accuracy_tools/atat/core/data_dump/data_processor/base.py create mode 100644 debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py create mode 100644 debug/accuracy_tools/atat/core/data_dump/data_processor/mindspore_processor.py rename debug/accuracy_tools/atat/{pytorch/functional/data_processor.py => core/data_dump/data_processor/pytorch_processor.py} (44%) rename debug/accuracy_tools/atat/{pytorch/functional => core/data_dump}/json_writer.py (80%) rename debug/accuracy_tools/atat/{pytorch/functional => core/data_dump}/scope.py (99%) delete mode 100644 debug/accuracy_tools/atat/core/file_check_util.py delete mode 100644 debug/accuracy_tools/atat/core/log.py delete mode 100644 debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py delete mode 100644 debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py delete mode 100644 debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py delete mode 100644 debug/accuracy_tools/atat/pytorch/common/recursive.py delete mode 100644 debug/accuracy_tools/atat/pytorch/functional/repair.py delete mode 100644 debug/accuracy_tools/atat/pytorch/functional/step_post_process.py diff --git a/debug/accuracy_tools/atat/pytorch/common/exceptions.py b/debug/accuracy_tools/atat/core/common/exceptions.py similarity index 100% rename from debug/accuracy_tools/atat/pytorch/common/exceptions.py rename to debug/accuracy_tools/atat/core/common/exceptions.py diff --git a/debug/accuracy_tools/atat/pytorch/common/file_check.py b/debug/accuracy_tools/atat/core/common/file_check.py similarity index 89% rename from debug/accuracy_tools/atat/pytorch/common/file_check.py rename to debug/accuracy_tools/atat/core/common/file_check.py index 3204652583..29ad59d28e 100644 --- a/debug/accuracy_tools/atat/pytorch/common/file_check.py +++ b/debug/accuracy_tools/atat/core/common/file_check.py @@ -17,7 +17,7 @@ import os import re -from .log import print_error_log, print_warn_log +from .log import logger from .exceptions import FileCheckException from .utils import Const @@ -78,7 +78,7 @@ class FileChecker: @staticmethod def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - print_error_log(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') + logger.error(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type @@ -144,7 +144,7 @@ class FileOpen: def check_file_path(self): support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE 
if self.mode not in support_mode: - print_error_log("File open not support %s mode" % self.mode) + logger.error("File open not support %s mode" % self.mode) raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) @@ -171,7 +171,7 @@ class FileOpen: def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): - print_error_log('The file path {} is a soft link.'.format(path)) + logger.error('The file path {} is a soft link.'.format(path)) raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) @@ -179,58 +179,58 @@ def check_path_length(path, name_length=None): file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ len(os.path.basename(path)) > file_max_name_length: - print_error_log('The file path length exceeds limit.') + logger.error('The file path length exceeds limit.') raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): - print_error_log('The file path %s does not exist.' % path) + logger.error('The file path %s does not exist.' % path) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): - print_error_log('The file path %s is not readable.' % path) + logger.error('The file path %s is not readable.' % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): - print_error_log('The file path %s is not writable.' % path) + logger.error('The file path %s is not writable.' % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): - print_error_log('The file path %s is not executable.' % path) + logger.error('The file path %s is not executable.' % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: - print_error_log('The file path %s may be insecure because other users have write permissions. ' % path) + logger.error('The file path %s may be insecure because other users have write permissions. ' % path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): - print_error_log('The file path %s may be insecure because is does not belong to you.' % path) + logger.error('The file path %s may be insecure because is does not belong to you.' 
% path) raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - print_error_log('The file path {} contains special characters.'.format(path)) + logger.error('The file path {} contains special characters.'.format(path)) raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): file_size = os.path.getsize(file_path) if file_size >= max_size: - print_error_log(f'The size of file path {file_path} exceeds {max_size} bytes.') + logger.error(f'The size of file path {file_path} exceeds {max_size} bytes.') raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) @@ -245,18 +245,18 @@ def check_common_file_size(file_path): def check_file_suffix(file_path, file_suffix): if file_suffix: if not file_path.endswith(file_suffix): - print_error_log(f"The {file_path} should be a {file_suffix} file!") + logger.error(f"The {file_path} should be a {file_suffix} file!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def check_path_type(file_path, file_type): if file_type == FileCheckConst.FILE: if not os.path.isfile(file_path): - print_error_log(f"The {file_path} should be a file!") + logger.error(f"The {file_path} should be a file!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) if file_type == FileCheckConst.DIR: if not os.path.isdir(file_path): - print_error_log(f"The {file_path} should be a dictionary!") + logger.error(f"The {file_path} should be a dictionary!") raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) @@ -298,4 +298,4 @@ def change_mode(path, mode): def path_len_exceeds_limit(file_path): return len(os.path.realpath(file_path)) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH + len(os.path.basename(file_path)) > FileCheckConst.FILE_NAME_LENGTH \ No newline at end of file diff --git a/debug/accuracy_tools/atat/core/common/log.py b/debug/accuracy_tools/atat/core/common/log.py new file mode 100644 index 0000000000..445f948165 --- /dev/null +++ b/debug/accuracy_tools/atat/core/common/log.py @@ -0,0 +1,52 @@ +import os +import time +import sys + + +class BaseLogger: + def __init__(self): + self.warning_level = "WARNING" + self.error_level = "ERROR" + self.info_level = "INFO" + + def get_rank(self): + """ + This should be implemented by subclasses to return the appropriate rank. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + def _print_log(self, level, msg, end='\n'): + """ + Print a log message with the given level and message. 
+ """ + current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + pid = os.getpid() + full_msg = f"{current_time} ({pid}) [{level}] {msg}" + print(full_msg, end=end) + sys.stdout.flush() + + def info(self, msg): + self._print_log(self.info_level, msg) + + def error(self, msg): + self._print_log(self.error_level, msg) + + def warning(self, msg): + self._print_log(self.warning_level, msg) + + def on_rank_0(self, func): + def wrapper(*args, **kwargs): + if self.get_rank() == 0: + return func(*args, **kwargs) + return wrapper + + def info_on_rank_0(self, msg): + return self.on_rank_0(self.info)(msg) + + def error_on_rank_0(self, msg): + return self.on_rank_0(self.error)(msg) + + def warning_on_rank_0(self, msg): + return self.on_rank_0(self.warning)(msg) + +logger = BaseLogger() \ No newline at end of file diff --git a/debug/accuracy_tools/atat/core/utils.py b/debug/accuracy_tools/atat/core/common/utils.py similarity index 89% rename from debug/accuracy_tools/atat/core/utils.py rename to debug/accuracy_tools/atat/core/common/utils.py index 47e54ffa6d..819034b5b5 100644 --- a/debug/accuracy_tools/atat/core/utils.py +++ b/debug/accuracy_tools/atat/core/common/utils.py @@ -26,8 +26,8 @@ from datetime import datetime, timezone from pathlib import Path import numpy as np -from .file_check_util import FileOpen, FileChecker, FileCheckConst -from .log import print_info_log, print_warn_log, print_error_log +from .file_check import FileOpen, FileChecker, FileCheckConst +from .log import logger device = collections.namedtuple('device', ['type', 'index']) @@ -104,6 +104,8 @@ class Const: TENSOR = "tensor" OVERFLOW_CHECK = "overflow_check" FREE_BENCHMARK = "free_benchmark" + DATA = "data" + class CompareConst: """ @@ -263,12 +265,12 @@ def make_dump_path_if_not_exists(dump_path): try: Path(dump_path).mkdir(mode=0o750, exist_ok=True, parents=True) except OSError as ex: - print_error_log( + logger.error( 'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex))) raise CompareException(CompareException.INVALID_PATH_ERROR) from ex else: if not os.path.isdir(dump_path): - print_error_log('{} already exists and is not a directory.'.format(dump_path)) + logger.error('{} already exists and is not a directory.'.format(dump_path)) def check_mode_valid(mode, scope=None, api_list=None): @@ -300,13 +302,13 @@ def check_mode_valid(mode, scope=None, api_list=None): def check_switch_valid(switch): if switch not in ["ON", "OFF"]: - print_error_log("Please set switch with 'ON' or 'OFF'.") + logger.error("Please set switch with 'ON' or 'OFF'.") raise CompareException(CompareException.INVALID_PARAM_ERROR) def check_dump_mode_valid(dump_mode): if not isinstance(dump_mode, list): - print_warn_log("Please set dump_mode as a list.") + logger.error("Please set dump_mode as a list.") dump_mode = [dump_mode] if not all(mode in ["all", "forward", "backward", "input", "output"] for mode in dump_mode): raise ValueError("Please set dump_mode as a list containing one or more of the following: 'all', 'forward', 'backward', 'input', 'output'.") @@ -327,14 +329,14 @@ def check_summary_mode_valid(summary_mode): def check_summary_only_valid(summary_only): if not isinstance(summary_only, bool): - print_error_log("Params summary_only only support True or False.") + logger.error("Params summary_only only support True or False.") raise CompareException(CompareException.INVALID_PARAM_ERROR) return summary_only def check_compare_param(input_parma, output_path, stack_mode=False, 
summary_compare=False, md5_compare=False): if not (isinstance(input_parma, dict) and isinstance(output_path, str)): - print_error_log("Invalid input parameters") + logger.error("Invalid input parameters") raise CompareException(CompareException.INVALID_PARAM_ERROR) check_file_or_directory_path(input_parma.get("npu_json_path"), False) check_file_or_directory_path(input_parma.get("bench_json_path"), False) @@ -351,7 +353,7 @@ def check_compare_param(input_parma, output_path, stack_mode=False, summary_comp def check_configuration_param(stack_mode=False, auto_analyze=True, fuzzy_match=False): if not (isinstance(stack_mode, bool) and isinstance(auto_analyze, bool) and isinstance(fuzzy_match, bool)): - print_error_log("Invalid input parameters which should be only bool type.") + logger.error("Invalid input parameters which should be only bool type.") raise CompareException(CompareException.INVALID_PARAM_ERROR) @@ -379,7 +381,7 @@ def is_starts_with(string, prefix_list): def _check_json(json_file_handle, file_name): tensor_line = json_file_handle.readline() if not tensor_line: - print_error_log("dump file {} have empty line!".format(file_name)) + logger.error("dump file {} have empty line!".format(file_name)) raise CompareException(CompareException.INVALID_DUMP_FILE) json_file_handle.seek(0, 0) @@ -394,10 +396,10 @@ def check_file_size(input_file, max_size): try: file_size = os.path.getsize(input_file) except OSError as os_error: - print_error_log('Failed to open "%s". %s' % (input_file, str(os_error))) + logger.error('Failed to open "%s". %s' % (input_file, str(os_error))) raise CompareException(CompareException.INVALID_FILE_ERROR) from os_error if file_size > max_size: - print_error_log('The size (%d) of %s exceeds (%d) bytes, tools not support.' + logger.error('The size (%d) of %s exceeds (%d) bytes, tools not support.' % (file_size, input_file, max_size)) raise CompareException(CompareException.INVALID_FILE_ERROR) @@ -437,7 +439,7 @@ def remove_path(path): else: shutil.rmtree(path) except PermissionError as err: - print_error_log("Failed to delete {}. Please check the permission.".format(path)) + logger.error("Failed to delete {}. 
Please check the permission.".format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) from err @@ -484,7 +486,7 @@ def create_directory(dir_path): try: os.makedirs(dir_path, mode=0o700) except OSError as ex: - print_error_log( + logger.error( 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) raise CompareException(CompareException.INVALID_PATH_ERROR) from ex @@ -498,7 +500,7 @@ def execute_command(cmd): Exception Description: when invalid command throw exception """ - print_info_log('Execute command:%s' % cmd) + logger.info('Execute command:%s' % cmd) process = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) while process.poll() is None: line = process.stdout.readline() @@ -506,7 +508,7 @@ def execute_command(cmd): if line: print(line) if process.returncode != 0: - print_error_log('Failed to execute command:%s' % " ".join(cmd)) + logger.error('Failed to execute command:%s' % " ".join(cmd)) raise CompareException(CompareException.INVALID_DATA_ERROR) @@ -530,7 +532,7 @@ def parse_value_by_comma(value): if value_str.isdigit() or value_str == '-1': value_list.append(int(value_str)) else: - print_error_log("please check your input shape.") + logger.error("please check your input shape.") raise CompareException(CompareException.INVALID_PARAM_ERROR) return value_list @@ -539,7 +541,7 @@ def get_data_len_by_shape(shape): data_len = 1 for item in shape: if item == -1: - print_error_log("please check your input shape, one dim in shape is -1.") + logger.error("please check your input shape, one dim in shape is -1.") return -1 data_len = data_len * item return data_len @@ -564,25 +566,25 @@ def format_value(value): def check_seed_all(seed, mode): if isinstance(seed, int): if seed < 0 or seed > Const.MAX_SEED_VALUE: - print_error_log(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + logger.error(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") raise CompareException(CompareException.INVALID_PARAM_ERROR) else: - print_error_log(f"Seed must be integer.") + logger.error(f"Seed must be integer.") raise CompareException(CompareException.INVALID_PARAM_ERROR) if not isinstance(mode, bool): - print_error_log(f"seed_all mode must be bool.") + logger.error(f"seed_all mode must be bool.") raise CompareException(CompareException.INVALID_PARAM_ERROR) def get_process_rank(model): - print_info_log("Rank id is not provided. Trying to get the rank id of the model.") + logger.info("Rank id is not provided. Trying to get the rank id of the model.") try: local_device = next(model.parameters()).device except StopIteration: - print_warn_log('There is no parameter in the model. Fail to get rank id.') + logger.warning('There is no parameter in the model. Fail to get rank id.') return 0, False if local_device.type == 'cpu': - print_warn_log("Warning: the debugger is unable to get the rank id. " + logger.warning("Warning: the debugger is unable to get the rank id. " "This may cause the dumpped data to be corrupted in the " "case of distributed training. (You may ignore this if you are using only one card.) " "Transfer the model to npu or gpu before register_hook() to avoid this warning.") @@ -603,43 +605,43 @@ def generate_compare_script(dump_path, pkl_file_path, dump_switch_mode): code_temp = ftemp.read() fout.write(code_temp % (pkl_file_path, dump_path, is_api_stack)) except OSError: - print_error_log(f"Failed to open file. 
Please check file {template_path} or path {pkl_dir}.") + logger.error(f"Failed to open file. Please check file {template_path} or path {pkl_dir}.") - print_info_log(f"Generate compare script successfully which is {compare_script_path}.") + logger.info(f"Generate compare script successfully which is {compare_script_path}.") def check_file_valid(file_path): if os.path.islink(file_path): - print_error_log('The file path {} is a soft link.'.format(file_path)) + logger.error('The file path {} is a soft link.'.format(file_path)) raise CompareException(CompareException.INVALID_PATH_ERROR) if len(os.path.realpath(file_path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(file_path)) > \ Const.FILE_NAME_LENGTH: - print_error_log('The file path length exceeds limit.') + logger.error('The file path length exceeds limit.') raise CompareException(CompareException.INVALID_PATH_ERROR) if not re.match(Const.FILE_PATTERN, os.path.realpath(file_path)): - print_error_log('The file path {} contains special characters.'.format(file_path)) + logger.error('The file path {} contains special characters.'.format(file_path)) raise CompareException(CompareException.INVALID_PATH_ERROR) if os.path.isfile(file_path): file_size = os.path.getsize(file_path) if file_path.endswith(Const.PKL_SUFFIX) and file_size > Const.ONE_GB: - print_error_log('The file {} size is greater than 1GB.'.format(file_path)) + logger.error('The file {} size is greater than 1GB.'.format(file_path)) raise CompareException(CompareException.INVALID_PATH_ERROR) if file_path.endswith(Const.NUMPY_SUFFIX) and file_size > Const.TEN_GB: - print_error_log('The file {} size is greater than 10GB.'.format(file_path)) + logger.error('The file {} size is greater than 10GB.'.format(file_path)) raise CompareException(CompareException.INVALID_PATH_ERROR) def check_path_before_create(path): if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \ Const.FILE_NAME_LENGTH: - print_error_log('The file path length exceeds limit.') + logger.error('The file path length exceeds limit.') raise CompareException(CompareException.INVALID_PATH_ERROR) if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): - print_error_log('The file path {} contains special characters.'.format(path)) + logger.error('The file path {} contains special characters.'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) @@ -667,14 +669,14 @@ def task_dumppath_get(input_param): npu_json_path = input_param.get("npu_json_path", None) bench_json_path = input_param.get("bench_json_path", None) if not npu_json_path or not bench_json_path: - print_error_log(f"Please check the json path is valid.") + logger.error(f"Please check the json path is valid.") raise CompareException(CompareException.INVALID_PATH_ERROR) with FileOpen(npu_json_path, 'r') as npu_f: npu_json_data = json.load(npu_f) with FileOpen(bench_json_path, 'r') as bench_f: bench_json_data = json.load(bench_f) if npu_json_data['task'] != bench_json_data['task']: - print_error_log(f"Please check the dump task is consistent.") + logger.error(f"Please check the dump task is consistent.") raise CompareException(CompareException.INVALID_TASK_ERROR) if npu_json_data['task'] == Const.TENSOR: summary_compare = False @@ -686,7 +688,7 @@ def task_dumppath_get(input_param): else: summary_compare = True else: - print_error_log(f"Compare is not required for overflow_check or free_benchmark.") + logger.error(f"Compare is not required for overflow_check or free_benchmark.") raise 
CompareException(CompareException.INVALID_TASK_ERROR) input_param['npu_dump_data_dir'] = npu_json_data['dump_data_dir'] input_param['bench_dump_data_dir'] = bench_json_data['dump_data_dir'] @@ -699,6 +701,6 @@ def get_header_index(header_name, summary_compare=False): else: header = CompareConst.COMPARE_RESULT_HEADER[:] if header_name not in header: - print_error_log(f"{header_name} not in data name") + logger.error(f"{header_name} not in data name") raise CompareException(CompareException.INVALID_PARAM_ERROR) return header.index(header_name) diff --git a/debug/accuracy_tools/atat/core/common_config.py b/debug/accuracy_tools/atat/core/common_config.py index ee045d3c52..93ceffe72b 100644 --- a/debug/accuracy_tools/atat/core/common_config.py +++ b/debug/accuracy_tools/atat/core/common_config.py @@ -1,7 +1,8 @@ -from .utils import Const +from .common.utils import Const +from .common.exceptions import MsaccException +from .common.log import logger -# 公共配置类 class CommonConfig: def __init__(self, json_config): self.task = json_config.get('task') @@ -17,19 +18,26 @@ class CommonConfig: def _check_config(self): if self.task and self.task not in Const.TASK_LIST: - raise Exception("task is invalid") + logger.error("task is invalid, it should be one of {}".format(Const.TASK_LIST)) + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if self.rank is not None and not isinstance(self.rank, list): - raise Exception("rank is invalid") + logger.error("rank is invalid, it should be a list") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if self.step is not None and not isinstance(self.step, list): - raise Exception("step is invalid") + logger.error("step is invalid, it should be a list") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if self.level and self.level not in Const.LEVEL_LIST: - raise Exception("level is invalid") + logger.error("level is invalid, it should be one of {}".format(Const.LEVEL_LIST)) + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if self.seed is not None and not isinstance(self.seed, int): - raise Exception("seed is invalid") + logger.error("seed is invalid, it should be an integer") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if not isinstance(self.is_deterministic, bool): - raise Exception("is_deterministic is invalid") + logger.error("is_deterministic is invalid, it should be a boolean") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if not isinstance(self.enable_dataloader, bool): - raise Exception("enable_dataloader is invalid") + logger.error("enable_dataloader is invalid, it should be a boolean") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) # 基础配置类 @@ -46,9 +54,12 @@ class BaseConfig: def check_config(self): if self.scope is not None and not isinstance(self.scope, list): - raise Exception("scope is invalid") + logger.error("scope is invalid, it should be a list") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if self.list is not None and not isinstance(self.list, list): - raise Exception("list is invalid") + logger.error("list is invalid, it should be a list") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if self.data_mode is not None and not isinstance(self.data_mode, list): - raise Exception("data_mode is invalid") + logger.error("data_mode is invalid, it should be a list") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py 
b/debug/accuracy_tools/atat/core/data_dump/data_collector.py similarity index 56% rename from debug/accuracy_tools/atat/pytorch/functional/data_collector.py rename to debug/accuracy_tools/atat/core/data_dump/data_collector.py index 8e4011a054..6eb5b76681 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_collector.py @@ -1,20 +1,11 @@ -import os -import torch +import os -from .data_processor import build_data_processor, DataProcessor +from .scope import build_scope, ListScope from .json_writer import DataWriter -from .scope import build_scope, ListScope -from ..common.log import print_info_log, print_warn_log +from ..common.log import logger from ..common.utils import Const -from ..module_processer import ModuleProcesser - -try: - import torch_npu -except ImportError: - pass - -forward_init_status = False +from .data_processor.factory import DataProcessorFactory def build_data_collector(config): @@ -22,19 +13,17 @@ def build_data_collector(config): class DataCollector: - overflow_task = "overflow_check" - tensor_task = "tensor" - freebenchmark_task = "free_benchmark" multi_output_apis = ["_sort_", "npu_flash_attention"] - tasks_need_tensor_data = [overflow_task, tensor_task, freebenchmark_task] + tasks_need_tensor_data = [Const.OVERFLOW_CHECK, Const.TENSOR, Const.FREE_BENCHMARK] level_without_construct = ["L1", "L2"] def __init__(self, config): self.config = config self.data_writer = DataWriter() - self.data_processor = build_data_processor(config, self.data_writer) + self.data_processor = DataProcessorFactory.create_processor(self.config, self.data_writer) + self.module_processor = DataProcessorFactory.get_module_processor(self.config.framework) if self.config.framework == Const.PT_FRAMEWORK else None self.module_count = {} - if config.task == DataCollector.freebenchmark_task: + if self.config.task == Const.FREE_BENCHMARK: self.scope = build_scope(ListScope, self.config.scope, self.config.list) else: self.scope = build_scope(None, self.config.scope, self.config.list) @@ -46,7 +35,7 @@ class DataCollector: @property def dump_file_path(self): return self.data_writer.dump_file_path - + @staticmethod def check_scope_and_pid(scope, name, pid): return (not scope or scope.check(name)) and pid == os.getpid() @@ -54,10 +43,10 @@ class DataCollector: @staticmethod def is_inplace(module): return getattr(module, "op_is_inplace", False) - + def if_return_forward_new_output(self): return self.data_processor.if_return_forward_new_output() - + def get_forward_new_output(self): return self.data_processor.get_forward_new_output() @@ -68,7 +57,7 @@ class DataCollector: self.data_writer.write_json() def update_data(self, data_info, msg=''): - if self.config.task == DataProcessor.overflow: + if self.config.task == Const.OVERFLOW_CHECK: if self.data_processor.has_overflow: self.data_writer.update_data(data_info) msg += "Overflow detected." 
@@ -79,12 +68,12 @@ class DataCollector: return msg def pre_forward_data_collect(self, name, module, pid, module_input_output): - backward_name = name.replace("forward", "backward") + backward_name = name.replace(Const.FORWARD, Const.BACKWARD) if self.check_scope_and_pid(self.scope, backward_name, pid): self.data_processor.analyze_pre_forward(backward_name, module, module_input_output) if not self.is_inplace(module): return - print_info_log(f"API {name} is inplace.") + logger.info(f"API {name} is inplace.") if self.check_scope_and_pid(self.scope, name, pid): data_info = self.data_processor.analyze_pre_forward_inplace(name, module_input_output) self.update_data(data_info) @@ -94,14 +83,14 @@ class DataCollector: if not self.check_scope_and_pid(self.scope, name, pid): return - if self.config.level == "L2": - self.acl_dump(module, module_input_output, name) - return - if not self.is_inplace(module): data_info = self.data_processor.analyze_forward(name, module, module_input_output) else: data_info = self.data_processor.analyze_forward_inplace(name, module_input_output) + + if self.config.level == "L2": + return + self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) self.handle_data(name, data_info) @@ -115,14 +104,15 @@ class DataCollector: def update_construct(self, name): if self.config.level not in DataCollector.level_without_construct: - self.data_writer.update_construct({name: ModuleProcesser.api_parent_node}) - self.data_writer.update_construct(ModuleProcesser.module_node) + self.data_writer.update_construct({name: self.module_processor.api_parent_node}) + self.data_writer.update_construct(self.module_processor.module_node) def handle_data(self, name, data_info): msg = f"msProbe is collecting data on {name}. " + if data_info: msg = self.update_data(data_info, msg) - print_info_log(msg) + logger.info(msg) self.data_writer.flush_data_when_buffer_is_full() def module_count_func(self, name, name_template): @@ -148,65 +138,6 @@ class DataCollector: def update_dump_paths(self, *args): self.data_writer.update_dump_paths(*args) self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level) - + def update_iter(self, current_iter): self.data_processor.update_iter(current_iter) - - def acl_dump(self, module, module_input_output, module_name): - if self.config.is_forward_acl_dump: - self.forward_acl_dump(module, module_input_output, module_name) - else: - self.dump_mode_backward_acl_dump(module, module_input_output, module_name) - - def op_need_trigger(self, module_name): - if 'Tensor___getitem___' in module_name: - return True - return False - - def forward_acl_dump(self, module, module_input_output, module_name): - global forward_init_status - if not forward_init_status: - forward_init_status = True - torch_npu.npu.synchronize() - torch_npu.npu.init_dump() - torch_npu.npu.set_dump(self.config.acl_config) - torch_npu.npu.synchronize() - if self.op_need_trigger(module_name): - module.forward(*module_input_output.args, **module_input_output.kwargs).cpu() - else: - module.forward(*module_input_output.args, **module_input_output.kwargs) - torch_npu.npu.synchronize() - torch_npu.npu.finalize_dump() - torch_npu.npu.synchronize() - forward_init_status = False - print_info_log("Dump %s op file." 
% module_name) - - def acl_backward_dump_status(self, output, grad, module_name): - if isinstance(output, torch.Tensor): - output.backward(grad, retain_graph=True) - return True - - for api_name in DataCollector.multi_output_apis: - if api_name in module_name: - output[0].backward(grad, retain_graph=True) - return True - return False - - def dump_mode_backward_acl_dump(self, module, module_input_output, module_name): - global forward_init_status - grad_path = self.config.backward_input.get(module_name) - if not forward_init_status: - forward_init_status = True - output = module.forward(*module_input_output.args, **module_input_output.kwargs) - grad = torch.load(grad_path).to("npu").requires_grad_() - torch_npu.npu.init_dump() - torch_npu.npu.set_dump(self.config.acl_config) - torch_npu.npu.synchronize() - if not self.acl_backward_dump_status(output, grad, module_name): - print_warn_log("The output of {} is not of tensor type and cannot be automatically derived. " - "you can manually construct a single API backward case for ACL dump.".format( - module_name)) - torch_npu.npu.synchronize() - torch_npu.npu.finalize_dump() - forward_init_status = False - print_info_log("Dump %s op file." % module_name) diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py new file mode 100644 index 0000000000..9785bb53aa --- /dev/null +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/base.py @@ -0,0 +1,262 @@ +import os +import inspect +from dataclasses import dataclass +from typing import Tuple, Dict, Optional, Any + +import numpy as np +from ...common.log import logger +from ...common.utils import Const + + +@dataclass +class ModuleForwardInputsOutputs: + args: Optional[Tuple] + kwargs: Optional[Dict] + output: Any + + @property + def args_tuple(self): + if not isinstance(self.args, tuple): + return (self.args, ) + else: + return self.args + + @property + def output_tuple(self): + if not isinstance(self.output, tuple): + return (self.output, ) + else: + return self.output + + def concat_args_and_kwargs(self): + args = self.args + tuple(self.kwargs.values()) + return args + + +@dataclass +class ModuleBackwardInputsOutputs: + grad_output: Optional[Tuple] + grad_input: Optional[Tuple] + + @property + def grad_input_tuple(self): + if not isinstance(self.grad_input, tuple): + return (self.grad_input, ) + else: + return self.grad_input + + @property + def grad_output_tuple(self): + if not isinstance(self.grad_output, tuple): + return (self.grad_output, ) + else: + return self.grad_output + + +class TensorStatInfo: + def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): + self.max = max_val + self.min = min_val + self.mean = mean_val + self.norm = norm_val + + +class BaseDataProcessor: + _recursive_key_stack = [] + special_type = ( + np.integer, np.floating, np.bool_, np.complexfloating, np.str_, np.byte, np.unicode_, + bool, int, float, str, slice + ) + + def __init__(self, config, data_writer): + self.data_writer = data_writer + self.config = config + self.api_info_struct = {} + self.stack_info_struct = {} + self.current_api_or_module_name = None + self.api_data_category = None + self.has_overflow = False + self.current_iter = 0 + self._return_forward_new_output = False + self._forward_new_output = None + + @staticmethod + def analyze_api_call_stack(name): + stack_str = [] + for (_, path, line, func, code, _) in inspect.stack()[5:]: + if not code: + continue + stack_line = " ".join([ + 
"File", ", ".join([ + path, + " ".join(["line", str(line)]), + " ".join(["in", func]), + " ".join(["\n", code[0].strip()]) + ]) + ]) + stack_str.append(stack_line) + stack_info_struct = {name: stack_str} + return stack_info_struct + + @staticmethod + def _convert_numpy_to_builtin(arg): + type_mapping = { + np.integer: int, + np.floating: float, + np.bool_: bool, + np.complexfloating: complex, + np.str_: str, + np.byte: bytes, + np.unicode_: str + } + for numpy_type, builtin_type in type_mapping.items(): + if isinstance(arg, numpy_type): + return builtin_type(arg), type(arg).__name__ + return arg, '' + + @staticmethod + def _analyze_numpy(value, numpy_type): + single_arg = {} + single_arg.update({"type": numpy_type}) + single_arg.update({"value": value}) + return single_arg + + @staticmethod + def _analyze_builtin(arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({"type": "slice"}) + single_arg.update({"value": [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({"type": type(arg).__name__}) + single_arg.update({"value": arg}) + return single_arg + + @classmethod + def get_special_types(cls): + return cls.special_type + + @classmethod + def recursive_apply_transform(cls, args, transform): + if isinstance(args, cls.get_special_types()): + arg_transform = transform(args, cls._recursive_key_stack) + return arg_transform + elif isinstance(args, (list, tuple)): + transform_result = [] + for i, arg in enumerate(args): + cls._recursive_key_stack.append(str(i)) + transform_result.append(cls.recursive_apply_transform(arg, transform)) + cls._recursive_key_stack.pop() + return type(args)(transform_result) + elif isinstance(args, dict): + transform_result = {} + for k, arg in args.items(): + cls._recursive_key_stack.append(str(k)) + transform_result[k] = cls.recursive_apply_transform(arg, transform) + cls._recursive_key_stack.pop() + return transform_result + else: + logger.warning(f"Data type {type(args)} is not supported.") + + def if_return_forward_new_output(self): + return self._return_forward_new_output + + def get_forward_new_output(self): + self._return_forward_new_output = False + return self._forward_new_output + + def update_iter(self, current_iter): + self.current_iter = current_iter + + def visit_and_clear_overflow_status(self, api_or_module_name): + if self.current_api_or_module_name != api_or_module_name: + self.current_api_or_module_name = api_or_module_name + self.has_overflow = False + + def is_dump_for_data_mode(self, forward_backward, input_output): + """ + Compare the parameters with data_mode to determine whether to dump. + + Args: + forward_backward(str): The forward or backward mode to check. + input_output(str): The input or output mode to check. + + Return: + bool: True if the parameters are in data_mode or data_mode is all, False otherwise. 
+ """ + return (Const.ALL in self.config.data_mode or + forward_backward in self.config.data_mode or + input_output in self.config.data_mode) + + def analyze_pre_forward(self, name, module, + module_input_output: ModuleForwardInputsOutputs): + pass + + def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): + api_info_struct = {} + if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): # check whether data_mode contains forward or input + api_info_struct[name] = {} + self.api_data_category = Const.INPUT + args_info_list = self.analyze_element(module_input_output.args_tuple) + api_info_struct[name][Const.INPUT_ARGS] = args_info_list + + self.api_data_category = Const.KWARGS + kwargs_info_list = self.analyze_element(module_input_output.kwargs) + api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list + + if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): # check whether data_mode contains forward or output + api_info_struct[name] = api_info_struct.get(name, {}) + self.api_data_category = Const.OUTPUT + output_info_list = self.analyze_element(module_input_output.output_tuple) + api_info_struct[name][Const.OUTPUT] = output_info_list + + return api_info_struct + + def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs): + api_info_struct = {} + if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): + api_info_struct[name] = {} + self.api_data_category = Const.INPUT + args_info_list = self.analyze_element(module_input_output.args_tuple) + api_info_struct[name][Const.INPUT_ARGS] = args_info_list + + self.api_data_category = Const.KWARGS + kwargs_info_list = self.analyze_element(module_input_output.kwargs) + api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list + + return api_info_struct + + def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs): + concat_args = module_input_output.concat_args_and_kwargs() + api_info_struct = {} + if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): + api_info_struct[name] = {} + self.api_data_category = Const.OUTPUT + output_info_list = self.analyze_element(concat_args) + api_info_struct[name][Const.OUTPUT] = output_info_list + + return api_info_struct + + def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): + api_info_struct = {} + if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): + api_info_struct[name] = {} + self.api_data_category = Const.OUTPUT + input_info_list = self.analyze_element(module_input_output.grad_input_tuple) + api_info_struct[name][Const.GRAD_INPUT] = input_info_list + + if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT): + api_info_struct[name] = api_info_struct.get(name, {}) + self.api_data_category = Const.INPUT + output_info_list = self.analyze_element(module_input_output.grad_output_tuple) + api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list + + return api_info_struct + + def get_save_file_path(self, suffix): + file_format = "pt" if self.config.framework == Const.PT_FRAMEWORK else "npy" + self.data_path = self.data_writer.dump_tensor_data_dir + dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + + suffix + Const.SEP + file_format) + file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + return dump_data_name, file_path \ No newline at end of file diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py 
b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py new file mode 100644 index 0000000000..b9d1618938 --- /dev/null +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/factory.py @@ -0,0 +1,65 @@ +from ...common.utils import Const + + +class DataProcessorFactory: + _data_processor = {} + _module_processor = {} + + @classmethod + def register_processor(cls, framework, task, processor_class): + key = (framework, task) + cls._data_processor[key] = processor_class + + @classmethod + def register_module_processor(cls, framework, processor_class): + cls._module_processor[framework] = processor_class + + @classmethod + def get_module_processor(cls, framework): + processor_class = cls._module_processor.get(framework) + if not processor_class: + raise ValueError(f"ModuleProcesser not found for framework: {framework}") + return processor_class + + @classmethod + def create_processor(cls, config, data_writer): + cls.register_processors(config.framework) + + task = Const.KERNEL_DUMP if config.level == "L2" else config.task + key = (config.framework, task) + processor_class = cls._data_processor.get(key) + if not processor_class: + raise ValueError(f"Processor not found for framework: {config.framework}, task: {config.task}") + return processor_class(config, data_writer) + + @classmethod + def register_processors(cls, framework): + if framework == Const.PT_FRAMEWORK: + from .pytorch_processor import ( + StatisticsDataProcessor as PytorchStatisticsDataProcessor, + TensorDataProcessor as PytorchTensorDataProcessor, + OverflowCheckDataProcessor as PytorchOverflowCheckDataProcessor, + FreeBenchmarkDataProcessor as PytorchFreeBenchmarkDataProcessor, + KernelDumpDataProcessor as PytorchKernelDumpDataProcessor + ) + from ....pytorch.module_processer import ModuleProcesser + + cls.register_processor(Const.PT_FRAMEWORK, Const.STATISTICS, PytorchStatisticsDataProcessor) + cls.register_processor(Const.PT_FRAMEWORK, Const.TENSOR, PytorchTensorDataProcessor) + cls.register_processor(Const.PT_FRAMEWORK, Const.OVERFLOW_CHECK, PytorchOverflowCheckDataProcessor) + cls.register_processor(Const.PT_FRAMEWORK, Const.FREE_BENCHMARK, PytorchFreeBenchmarkDataProcessor) + cls.register_processor(Const.PT_FRAMEWORK, Const.KERNEL_DUMP, PytorchKernelDumpDataProcessor) + cls.register_module_processor(Const.PT_FRAMEWORK, ModuleProcesser) + + elif framework == Const.MS_FRAMEWORK: + from .mindspore_processor import ( + StatisticsDataProcessor as MindsporeStatisticsDataProcessor, + TensorDataProcessor as MindsporeTensorDataProcessor, + OverflowCheckDataProcessor as MindsporeOverflowCheckDataProcessor, + FreeBenchmarkDataProcessor as MindsporeFreeBenchmarkDataProcessor + ) + + cls.register_processor(Const.MS_FRAMEWORK, Const.STATISTICS, MindsporeStatisticsDataProcessor) + cls.register_processor(Const.MS_FRAMEWORK, Const.TENSOR, MindsporeTensorDataProcessor) + cls.register_processor(Const.MS_FRAMEWORK, Const.OVERFLOW_CHECK, MindsporeOverflowCheckDataProcessor) + cls.register_processor(Const.MS_FRAMEWORK, Const.FREE_BENCHMARK, MindsporeFreeBenchmarkDataProcessor) \ No newline at end of file diff --git a/debug/accuracy_tools/atat/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/mindspore_processor.py new file mode 100644 index 0000000000..2b5196eea2 --- /dev/null +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/mindspore_processor.py @@ -0,0 +1,120 @@ +import os +import zlib +import mindspore as ms +import numpy as np + +from 
...common.utils import Const +from .base import BaseDataProcessor + + +class MindsporeDataProcessor(BaseDataProcessor): + mindspore_special_type = (ms.Tensor) + + def __init__(self, config, data_writer): + super().__init__(config, data_writer) + self.mindspore_object_key = { + "dtype": self.analyze_dtype_in_kwargs + } + + @staticmethod + def get_md5_for_tensor(x): + if x.dtype == ms.bfloat16: + x = x.to(ms.float32) + tensor_bytes = x.asnumpy().tobytes() + crc32_hash = zlib.crc32(tensor_bytes) + return f"{crc32_hash:08x}" + + @staticmethod + def analyze_dtype_in_kwargs(element): + single_arg = {} + single_arg.update({"type": "mindspore.dtype"}) + single_arg.update({"value": str(element)}) + return single_arg + + @staticmethod + def get_stat_info(data): + if data.size == 0: + tensor_max = None + tensor_min = None + tensor_mean = None + tensor_norm = None + elif data.dtype == np.bool_: + tensor_max = True in data + tensor_min = False not in data + tensor_mean = None + tensor_norm = None + elif not data.shape: + tensor_max = data.astype(np.float32).tolist() + tensor_min = data.astype(np.float32).tolist() + tensor_mean = data.astype(np.float32).tolist() + tensor_norm = None + else: + tensor_max = data.max().astype(np.float32).tolist() + tensor_min = data.min().astype(np.float32).tolist() + tensor_mean = data.astype(np.float32).mean().tolist() + tensor_norm = None + + return tensor_max, tensor_min, tensor_mean, tensor_norm + + @classmethod + def get_special_types(cls): + return super().get_special_types() + cls.mindspore_special_type + + def _analyze_tensor(self, tensor, suffix): + saved_tensor = tensor.asnumpy() + tensor_max, tensor_min, tensor_mean, tensor_norm = self.get_stat_info(saved_tensor) + + tensor_json = {} + tensor_json.update({'type': 'mindspore.Tensor'}) + tensor_json.update({'dtype': str(tensor.dtype)}) + tensor_json.update({"shape": tensor.shape}) + tensor_json.update({"Max": tensor_max}) + tensor_json.update({"Min": tensor_min}) + tensor_json.update({"Mean": tensor_mean}) + tensor_json.update({"Norm": tensor_norm}) + if self.config.summary_mode == "md5": + tensor_md5 = self.get_md5_for_tensor(tensor) + tensor_json.update({"md5": tensor_md5}) + return tensor_json + + def analyze_single_element(self, element, suffix_stack): + if suffix_stack and suffix_stack[-1] in self.mindspore_object_key: + return self.mindspore_object_key[suffix_stack[-1]](element) + + converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) + if converted_numpy is not element: + return self._analyze_numpy(converted_numpy, numpy_type) + if isinstance(element, ms.Tensor): + return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) + + if isinstance(element, (bool, int, float, str, slice)): + return self._analyze_builtin(element) + + def analyze_element(self, element): + return self.recursive_apply_transform(element, self.analyze_single_element) + + +class StatisticsDataProcessor(MindsporeDataProcessor): + pass + + +class TensorDataProcessor(MindsporeDataProcessor): + + def _analyze_tensor(self, tensor, suffix): + self.data_path = self.data_writer.dump_tensor_data_dir + dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + + suffix + ".npy") + file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + single_arg = super()._analyze_tensor(tensor, suffix) + single_arg.update({"data_name": dump_data_name}) + tensor = tensor.asnumpy() + np.save(file_path, tensor) + return single_arg + + +class 
OverflowCheckDataProcessor(MindsporeDataProcessor): + pass + + +class FreeBenchmarkDataProcessor(MindsporeDataProcessor): + pass \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py similarity index 44% rename from debug/accuracy_tools/atat/pytorch/functional/data_processor.py rename to debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py index 116301725b..0cc90c8914 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py +++ b/debug/accuracy_tools/atat/core/data_dump/data_processor/pytorch_processor.py @@ -1,118 +1,34 @@ -import inspect import os import zlib -from dataclasses import dataclass, asdict -from typing import Tuple, List, Dict, Optional, Union - -import numpy as np +import os +from typing import List +from dataclasses import asdict import torch -import torch_npu - -from ..common import recursive_apply_transform -from ..common.exceptions import MsaccException -from ..common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst -from ..common.log import print_warn_log -from ..common.utils import Const -from ..free_benchmark import FreeBenchmarkCheck, UnequalRow - -bits_for_overflow = 8 - - -def build_data_processor(config, data_writer): - if config.task == DataProcessor.full: - return FullTensorDataProcessor(config, data_writer) - elif config.task == DataProcessor.summary: - return DataProcessor(config, data_writer) - elif config.task == DataProcessor.overflow: - return OverflowTensorDataProcessor(config, data_writer) - elif config.task == DataProcessor.free_benchmark: - return FreeBenchmarkDataProcessor(config, data_writer) - else: - raise MsaccException(MsaccException.INVALID_PARAM_ERROR, - "task should be in [{}, {}, {}, {}]".format( - DataProcessor.full, - DataProcessor.summary, - DataProcessor.overflow, - DataProcessor.free_benchmark - )) - - -@dataclass -class ModuleForwardInputsOutputs: - args: Optional[Tuple] - kwargs: Optional[Dict] - output: Union[Tuple, torch.Tensor] - - @property - def args_tuple(self): - if not isinstance(self.args, tuple): - return (self.args,) - else: - return self.args - - @property - def output_tuple(self): - if not isinstance(self.output, tuple): - return (self.output,) - else: - return self.output - - def concat_args_and_kwargs(self): - args = self.args + tuple(self.kwargs.values()) - return args - - -@dataclass -class ModuleBackwardInputsOutputs: - grad_output: Optional[Tuple] - grad_input: Optional[Tuple] - - @property - def grad_input_tuple(self): - if not isinstance(self.grad_input, tuple): - return (self.grad_input,) - else: - return self.grad_input - - @property - def grad_output_tuple(self): - if not isinstance(self.grad_output, tuple): - return (self.grad_output,) - else: - return self.grad_output - - -class TensorStatInfo: - def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): - self.max = max_val - self.min = min_val - self.mean = mean_val - self.norm = norm_val - - -class DataProcessor: - full = "tensor" - summary = "statistics" - overflow = "overflow_check" - free_benchmark = "free_benchmark" - +import numpy as np +from ....pytorch.free_benchmark import FreeBenchmarkCheck, UnequalRow +from ...common.utils import Const +from ...common.file_check import path_len_exceeds_limit, change_mode, FileCheckConst +from ...common.log import logger +from ...common.exceptions import MsaccException +from .base import BaseDataProcessor, 
ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs, TensorStatInfo + +try: + import torch_npu +except ImportError: + pass + + +class PytorchDataProcessor(BaseDataProcessor): + pytorch_special_type = ( + torch.device, torch.dtype, torch.Size, torch.Tensor + ) + def __init__(self, config, data_writer): - self.data_writer = data_writer - self.api_info_struct = {} - self.stack_info_struct = {} + super().__init__(config, data_writer) self.torch_object_key = { "device": self.analyze_device_in_kwargs, "dtype": self.analyze_dtype_in_kwargs } - self.current_api_or_module_name = None - self.config = config - self.api_data_category = None - self.has_overflow = False - self.current_iter = 0 - - # 需要对forward的output进行更改 - self._return_forward_new_output = False - self._forward_new_output = None @staticmethod def get_md5_for_tensor(x): @@ -144,55 +60,7 @@ class DataProcessor: return single_arg @staticmethod - def _convert_numpy_to_builtin(arg): - type_mapping = { - np.integer: int, - np.floating: float, - np.bool_: bool, - np.complexfloating: complex, - np.str_: str, - np.byte: bytes, - np.unicode_: str - } - for numpy_type, builtin_type in type_mapping.items(): - if isinstance(arg, numpy_type): - return builtin_type(arg), type(arg).__name__ - return arg, '' - - @staticmethod - def handle_tensor_extremum_nan_inf(data_clone, operator): - data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) - if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): - return float('nan') - finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) - if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: - finite_values = data_clone[finite_mask] - return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(finite_values).item() - else: - data_no_nan = data_clone[~data_nan] - return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ - torch._C._VariableFunctionsClass.min(data_no_nan).item() - - @staticmethod - def analyze_api_call_stack(name): - stack_str = [] - for (_, path, line, func, code, _) in inspect.stack()[5:]: - if not code: - continue - stack_line = " ".join([ - "File", ", ".join([ - path, - " ".join(["line", str(line)]), - " ".join(["in", func]), - " ".join(["\n", code[0].strip()]) - ]) - ]) - stack_str.append(stack_line) - stack_info_struct = {name: stack_str} - return stack_info_struct - - def get_stat_info(self, data): + def get_stat_info(data): tensor_stat = TensorStatInfo() if data.is_meta: return tensor_stat @@ -216,43 +84,16 @@ class DataProcessor: tensor_stat.min = torch._C._VariableFunctionsClass.min(data_clone).item() tensor_stat.mean = torch._C._VariableFunctionsClass.mean(data_clone).item() tensor_stat.norm = torch._C._VariableFunctionsClass.norm(data_clone).item() - return tensor_stat - - def if_return_forward_new_output(self): - return self._return_forward_new_output - - def get_forward_new_output(self): - self._return_forward_new_output = False - return self._forward_new_output - - def update_iter(self, current_iter): - self.current_iter = current_iter - - def visit_and_clear_overflow_status(self, api_or_module_name): - if self.current_api_or_module_name != api_or_module_name: - self.current_api_or_module_name = api_or_module_name - self.has_overflow = False - - def is_dump_for_data_mode(self, forward_backward, input_output): - """ - Compare the parameters with data_mode to determine whether to dump. 
- - Args: - forward_backward(str): The forward or backward mode to check. - input_output(str): The input or output mode to check. - - Return: - bool: True if the parameters are in data_mode or data_mode is all, False otherwise. - """ - return (Const.ALL in self.config.data_mode or - forward_backward in self.config.data_mode or - input_output in self.config.data_mode) - + + @classmethod + def get_special_types(cls): + return super().get_special_types() + cls.pytorch_special_type + def analyze_single_element(self, element, suffix_stack): if suffix_stack and suffix_stack[-1] in self.torch_object_key: return self.torch_object_key[suffix_stack[-1]](element) - + if isinstance(element, torch.Size): return self._analyze_torch_size(element) @@ -265,119 +106,16 @@ class DataProcessor: if isinstance(element, (bool, int, float, str, slice)): return self._analyze_builtin(element) - return {} def analyze_element(self, element): - return recursive_apply_transform(element, self.analyze_single_element) - - def analyze_pre_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): - pass - - def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): - api_info_struct = {} - if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): - api_info_struct[name] = {} - self.api_data_category = Const.INPUT - args_info_list = self.analyze_element(module_input_output.args_tuple) - api_info_struct[name][Const.INPUT_ARGS] = args_info_list - - self.api_data_category = Const.KWARGS - kwargs_info_list = self.analyze_element(module_input_output.kwargs) - api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list - - if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): - api_info_struct[name] = api_info_struct.get(name, {}) - self.api_data_category = Const.OUTPUT - output_info_list = self.analyze_element(module_input_output.output_tuple) - api_info_struct[name][Const.OUTPUT] = output_info_list - - return api_info_struct - - def analyze_pre_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs): - api_info_struct = {} - if self.is_dump_for_data_mode(Const.FORWARD, Const.INPUT): - api_info_struct[name] = {} - self.api_data_category = Const.INPUT - args_info_list = self.analyze_element(module_input_output.args_tuple) - api_info_struct[name][Const.INPUT_ARGS] = args_info_list - - self.api_data_category = Const.KWARGS - kwargs_info_list = self.analyze_element(module_input_output.kwargs) - api_info_struct[name][Const.INPUT_KWARGS] = kwargs_info_list - - return api_info_struct - - def analyze_forward_inplace(self, name, module_input_output: ModuleForwardInputsOutputs): - concat_args = module_input_output.concat_args_and_kwargs() - api_info_struct = {} - if self.is_dump_for_data_mode(Const.FORWARD, Const.OUTPUT): - api_info_struct[name] = {} - self.api_data_category = Const.OUTPUT - output_info_list = self.analyze_element(concat_args) - api_info_struct[name][Const.OUTPUT] = output_info_list - - return api_info_struct - - def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): - api_info_struct = {} - if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): - api_info_struct[name] = {} - self.api_data_category = Const.OUTPUT - input_info_list = self.analyze_element(module_input_output.grad_input_tuple) - api_info_struct[name][Const.GRAD_INPUT] = input_info_list - - if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT): - api_info_struct[name] = api_info_struct.get(name, {}) - self.api_data_category = 
Const.INPUT - output_info_list = self.analyze_element(module_input_output.grad_output_tuple) - api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list - - return api_info_struct - - def _analyze_numpy(self, value, numpy_type): - single_arg = {} - single_arg.update({"type": numpy_type}) - single_arg.update({"value": value}) - return single_arg - - def _analyze_builtin(self, arg): - single_arg = {} - if isinstance(arg, slice): - single_arg.update({"type": "slice"}) - # slice参数中可能存在tensor类型,json序列化,需要转换为python数值类型 - values = [ - value if not isinstance(value, torch.Tensor) else value.item() - for value in [arg.start, arg.stop, arg.step] - ] - single_arg.update({"value": values}) - else: - single_arg.update({"type": type(arg).__name__}) - single_arg.update({"value": arg}) - return single_arg - - def _analyze_torch_size(self, arg): + return self.recursive_apply_transform(element, self.analyze_single_element) + + def _analyze_torch_size(arg): single_arg = {} single_arg.update({"type": "torch.Size"}) single_arg.update({"value": list(arg)}) return single_arg - def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): - data_clone = tensor.detach() - if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): - if tensor_json[Const.MAX] is None: - return - if np.isinf(tensor_json[Const.MAX]) or np.isnan(tensor_json[Const.MAX]): - tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") - self.has_overflow = True - if np.isinf(tensor_json[Const.MIN]) or np.isnan(tensor_json[Const.MIN]): - tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") - self.has_overflow = True - else: - self.has_overflow = check_overflow_npu() - if self.has_overflow: - clear_overflow_npu() - def _analyze_tensor(self, tensor, suffix): tensor_stat = self.get_stat_info(tensor) @@ -387,7 +125,6 @@ class DataProcessor: tensor_json.update({"shape": tensor.shape}) tensor_json.update({"Max": tensor_stat.max}) tensor_json.update({"Min": tensor_stat.min}) - self._analyze_maybe_overflow_tensor(tensor_json, tensor) tensor_json.update({"Mean": tensor_stat.mean}) tensor_json.update({"Norm": tensor_stat.norm}) tensor_json.update({"requires_grad": tensor.requires_grad}) @@ -398,27 +135,24 @@ class DataProcessor: return tensor_json -class FullTensorDataProcessor(DataProcessor): +class StatisticsDataProcessor(PytorchDataProcessor): + pass - def __init__(self, config, data_writer): - super().__init__(config, data_writer) - self.data_path = self.data_writer.dump_tensor_data_dir +class TensorDataProcessor(PytorchDataProcessor): def _analyze_tensor(self, tensor, suffix): - dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + - suffix + ".pt") - file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + dump_data_name, file_path = self.get_save_file_path(suffix) if not path_len_exceeds_limit(file_path): torch.save(tensor, file_path) change_mode(file_path, FileCheckConst.DATA_FILE_AUTHORITY) else: - print_warn_log(f'The file path {file_path} length exceeds limit.') + logger.warning(f'The file path {file_path} length exceeds limit.') single_arg = super()._analyze_tensor(tensor, suffix) single_arg.update({"data_name": dump_data_name}) return single_arg -class OverflowTensorDataProcessor(DataProcessor): +class OverflowCheckDataProcessor(PytorchDataProcessor): __slots__ = ["cached_tensors_and_file_paths"] def __init__(self, config, data_writer): @@ -426,28 +160,51 @@ 
class OverflowTensorDataProcessor(DataProcessor): self.cached_tensors_and_file_paths = {} self.real_overflow_dump_times = 0 self.overflow_nums = config.overflow_num + self.bits_for_overflow = 8 + + @staticmethod + def overflow_debug_mode_enable(): + overflow_mode = os.getenv(Const.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) + return overflow_mode == Const.ENV_ENABLE + + @staticmethod + def handle_tensor_extremum_nan_inf(data_clone, operator): + data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) + if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): + return float('nan') + finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) + if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: + finite_values = data_clone[finite_mask] + return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(finite_values).item() + else: + data_no_nan = data_clone[~data_nan] + return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(data_no_nan).item() - def _analyze_tensor(self, tensor, suffix): - dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + - suffix + ".pt") - file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) - if not path_len_exceeds_limit(file_path): - self.cached_tensors_and_file_paths.update({file_path: tensor}) + def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): + data_clone = tensor.detach() + if hasattr(torch_npu._C, '_npu_is_support_inf_nan') and torch_npu._C._npu_is_support_inf_nan(): + if tensor_json['Max'] is None: + return + if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "max") + self.has_overflow = True + if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + tensor_json['Min_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(data_clone, "min") + self.has_overflow = True else: - print_warn_log(f'The file path {file_path} length exceeds limit.') - single_arg = super()._analyze_tensor(tensor, suffix) - single_arg.update({"data_name": dump_data_name}) - return single_arg + self.has_overflow = self.check_overflow_npu() + if self.has_overflow: + self.clear_overflow_npu() - def analyze_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): + def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_forward(name, module, module_input_output) self.maybe_save_overflow_data_and_check_overflow_times() return api_info_struct if self.has_overflow else None - def analyze_backward(self, name, module, - module_input_output: ModuleBackwardInputsOutputs): + def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): self.has_overflow = False api_info_struct = super().analyze_backward(name, module, module_input_output) self.maybe_save_overflow_data_and_check_overflow_times() @@ -466,22 +223,46 @@ class OverflowTensorDataProcessor(DataProcessor): if self.overflow_nums == -1: return if self.real_overflow_dump_times >= self.overflow_nums: - raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, - str(self.real_overflow_dump_times)) + raise MsaccException(MsaccException.OVERFLOW_NUMS_ERROR, str(self.real_overflow_dump_times)) + + def 
check_overflow_npu(self): + if self.overflow_debug_mode_enable(): + float_status = torch.zeros(self.bits_for_overflow).npu() + result = torch_npu.npu_get_float_status(float_status, Const.OVERFLOW_DEBUG_MODE) + return bool(result.cpu()[0] != 0) + return torch_npu._C._check_overflow_npu() + + def clear_overflow_npu(self): + if self.overflow_debug_mode_enable(): + float_status = torch.zeros(self.bits_for_overflow).npu() + torch_npu.npu_clear_float_status(float_status, Const.OVERFLOW_DEBUG_MODE) + else: + torch_npu._C._clear_overflow_npu() + + def _analyze_tensor(self, tensor, suffix): + dump_data_name, file_path = self.get_save_file_path(suffix) + if not path_len_exceeds_limit(file_path): + self.cached_tensors_and_file_paths.update({file_path: tensor}) + else: + logger.warning(f'The file path {file_path} length exceeds limit.') + single_arg = super()._analyze_tensor(tensor, suffix) + self._analyze_maybe_overflow_tensor(single_arg, tensor) + single_arg.update({"data_name": dump_data_name}) + return single_arg -class FreeBenchmarkDataProcessor(DataProcessor): +class FreeBenchmarkDataProcessor(PytorchDataProcessor): def __init__(self, config, data_writer): super().__init__(config, data_writer) self.checker = FreeBenchmarkCheck(config=config) - + def update_iter(self, current_iter): self.current_iter = current_iter self.checker.update_iter(current_iter) def update_unequal_rows(self, unequal_rows: List[UnequalRow]): - if len(unequal_rows) == 0: + if not unequal_rows: return for row in unequal_rows: data_dict = asdict(row) @@ -492,8 +273,7 @@ class FreeBenchmarkDataProcessor(DataProcessor): ) return - def analyze_pre_forward(self, name, module, - module_input_output: ModuleForwardInputsOutputs): + def analyze_pre_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): args = module_input_output.args kwargs = module_input_output.kwargs self.checker.pre_forward(name, module, self, args, kwargs) @@ -510,42 +290,69 @@ class FreeBenchmarkDataProcessor(DataProcessor): if self.checker.if_fix(): self._return_forward_new_output = True self._forward_new_output = new_output - return None def analyze_backward(self, name, module, module_input_output: ModuleBackwardInputsOutputs): self.checker.backward(name, module, module_input_output.grad_output) - return None - - -def overflow_debug_mode_enable(): - overflow_mode = os.getenv(OverflowConst.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) - return overflow_mode == Const.ENV_ENABLE - - -def check_overflow_npu(): - if overflow_debug_mode_enable(): - float_status = torch.zeros(bits_for_overflow).npu() - result = torch_npu.npu_get_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE) - if (result.cpu()[0] != 0): - return True + + +class KernelDumpDataProcessor(PytorchDataProcessor): + forward_init_status = False + multi_output_apis = ["_sort_", "npu_flash_attention"] + + def __init__(self, config, data_writer): + super().__init__(config, data_writer) + + def analyze_forward(self, name, module, module_input_output): + if self.config.is_forward_acl_dump: + self.forward_acl_dump(name, module, module_input_output) else: - return False - else: - return torch_npu._C._check_overflow_npu() - - -def clear_overflow_npu(): - if overflow_debug_mode_enable(): - float_status = torch.zeros(bits_for_overflow).npu() - torch_npu.npu_clear_float_status(float_status, OverflowConst.OVERFLOW_DEBUG_MODE) - else: - torch_npu._C._clear_overflow_npu() - - -class OverflowConst: - """ - Class for Overflow - """ - OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" - OVERFLOW_ORIGINAL_MODE = 0 - 
OVERFLOW_DEBUG_MODE = 1 + self.dump_mode_backward_acl_dump(name, module, module_input_output) + + def forward_acl_dump(self, name, module, module_input_output): + if not KernelDumpDataProcessor.forward_init_status: + KernelDumpDataProcessor.forward_init_status = True + torch_npu.npu.synchronize() + torch_npu.npu.init_dump() + torch_npu.npu.set_dump(self.config.acl_config) + torch_npu.npu.synchronize() + if self.op_need_trigger(name): + module.forward(*module_input_output.args, **module_input_output.kwargs).cpu() + else: + module.forward(*module_input_output.args, **module_input_output.kwargs) + torch_npu.npu.synchronize() + torch_npu.npu.finalize_dump() + torch_npu.npu.synchronize() + KernelDumpDataProcessor.forward_init_status = False + logger.info("Dump %s op file." % name) + + def acl_backward_dump_status(self, output, grad, module_name): + if isinstance(output, torch.Tensor): + output.backward(grad, retain_graph=True) + return True + + for api_name in KernelDumpDataProcessor.multi_output_apis: + if api_name in module_name: + output[0].backward(grad, retain_graph=True) + return True + return False + + def dump_mode_backward_acl_dump(self, name, module, module_input_output): + grad_path = self.config.backward_input.get(name) + if not KernelDumpDataProcessor.forward_init_status: + KernelDumpDataProcessor.forward_init_status = True + output = module.forward(*module_input_output.args, **module_input_output.kwargs) + grad = torch.load(grad_path).to("npu").requires_grad_() + torch_npu.npu.init_dump() + torch_npu.npu.set_dump(self.config.acl_config) + torch_npu.npu.synchronize() + if not self.acl_backward_dump_status(output, grad, name): + logger.warning("The output of {} is not of tensor type and cannot be automatically derived. " + "you can manually construct a single API backward case for ACL dump.".format( + name)) + torch_npu.npu.synchronize() + torch_npu.npu.finalize_dump() + KernelDumpDataProcessor.forward_init_status = False + logger.info("Dump %s op file." % name) + + def op_need_trigger(self, module_name): + return 'Tensor.__getitem__.' 
in module_name \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py b/debug/accuracy_tools/atat/core/data_dump/json_writer.py similarity index 80% rename from debug/accuracy_tools/atat/pytorch/functional/json_writer.py rename to debug/accuracy_tools/atat/core/data_dump/json_writer.py index 216d24882e..10070e2941 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py +++ b/debug/accuracy_tools/atat/core/data_dump/json_writer.py @@ -1,14 +1,15 @@ +import os import csv +import fcntl import json -import os from pathlib import Path from ..common.file_check import FileCheckConst, change_mode -from ..common.log import print_info_log_rank_0 +from ..common.log import logger from ..common.utils import Const -class DataWriter: # TODO: UT +class DataWriter: def __init__(self, init_json=None) -> None: self.dump_count = 0 @@ -25,22 +26,22 @@ class DataWriter: # TODO: UT @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): - if len(result) == 0: + if not result: return is_exists = os.path.exists(file_path) append = "a+" if is_exists else "w+" with os.fdopen( - os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" + os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" ) as csv_file: spawn_writer = csv.writer(csv_file) if not is_exists: spawn_writer.writerow(result_header) - spawn_writer.writerows([result, ]) + spawn_writer.writerows([result,]) def initialize_json_file(self, **kwargs): kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) with os.fdopen( - os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' + os.open(self.dump_file_path, Const.OVERWRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), 'w' ) as f: json.dump(kwargs, f) @@ -54,7 +55,7 @@ class DataWriter: # TODO: UT Path(self.construct_file_path).touch() change_mode(self.construct_file_path, FileCheckConst.DATA_FILE_AUTHORITY) - def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, + def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path): self.dump_file_path = dump_file_path self.stack_file_path = stack_file_path @@ -80,8 +81,7 @@ class DataWriter: # TODO: UT self.cache_construct.update(new_data) def write_data_json(self, file_path): - import fcntl - print_info_log_rank_0(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") + logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. 
") if Path(file_path).exists() and os.path.getsize(file_path) > 0: with open(file_path, "r+") as f: fcntl.flock(f, fcntl.LOCK_EX) @@ -99,14 +99,12 @@ class DataWriter: # TODO: UT self.cache_data[Const.DATA].clear() def write_stack_info_json(self, file_path): - import fcntl with open(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(self.cache_stack, f, indent=1) fcntl.flock(f, fcntl.LOCK_UN) def write_construct_info_json(self, file_path): - import fcntl with open(file_path, 'w+') as f: fcntl.flock(f, fcntl.LOCK_EX) json.dump(self.cache_construct, f, indent=1) @@ -116,3 +114,18 @@ class DataWriter: # TODO: UT self.write_data_json(self.dump_file_path) self.write_stack_info_json(self.stack_file_path) self.write_construct_info_json(self.construct_file_path) + + @staticmethod + def write_data_to_csv(result: list, result_header: tuple, file_path: str): + if not result: + return + is_exists = os.path.exists(file_path) + append = "a+" if is_exists else "w+" + with os.fdopen( + os.open(file_path, Const.WRITE_FLAGS, FileCheckConst.DATA_FILE_AUTHORITY), append, newline="" + ) as csv_file: + spawn_writer = csv.writer(csv_file) + if not is_exists: + spawn_writer.writerow(result_header) + spawn_writer.writerows([result,]) + \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/scope.py b/debug/accuracy_tools/atat/core/data_dump/scope.py similarity index 99% rename from debug/accuracy_tools/atat/pytorch/functional/scope.py rename to debug/accuracy_tools/atat/core/data_dump/scope.py index 735c6d9c18..fd84b897ed 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/scope.py +++ b/debug/accuracy_tools/atat/core/data_dump/scope.py @@ -10,7 +10,6 @@ def build_scope(scope_class, scope=None, api_list=None): scope = [] if api_list is None: api_list = [] - if scope_class: return scope_class(scope, api_list) return build_range_scope_according_to_scope_name(scope, api_list) @@ -73,6 +72,7 @@ class BaseScope(ABC): return True return False + class ListScope(BaseScope): @staticmethod def rectify_args(scope, api_list): @@ -94,6 +94,7 @@ class RangeScope(BaseScope, ABC): self.in_scope = False self.is_valid = self.check_scope_is_valid() + @staticmethod def rectify_args(scope, api_list): scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list) diff --git a/debug/accuracy_tools/atat/core/file_check_util.py b/debug/accuracy_tools/atat/core/file_check_util.py deleted file mode 100644 index b10cdd6104..0000000000 --- a/debug/accuracy_tools/atat/core/file_check_util.py +++ /dev/null @@ -1,319 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-""" -import os -import re - -from .log import print_warn_log, print_error_log - - -class FileCheckConst: - """ - Class for file check const - """ - READ_ABLE = "read" - WRITE_ABLE = "write" - READ_WRITE_ABLE = "read and write" - DIRECTORY_LENGTH = 4096 - FILE_NAME_LENGTH = 255 - FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$" - PKL_SUFFIX = ".pkl" - NUMPY_SUFFIX = ".npy" - JSON_SUFFIX = ".json" - PT_SUFFIX = ".pt" - CSV_SUFFIX = ".csv" - YAML_SUFFIX = ".yaml" - MAX_PKL_SIZE = 1 * 1024 * 1024 * 1024 - MAX_NUMPY_SIZE = 10 * 1024 * 1024 * 1024 - MAX_JSON_SIZE = 1 * 1024 * 1024 * 1024 - MAX_PT_SIZE = 10 * 1024 * 1024 * 1024 - MAX_CSV_SIZE = 1 * 1024 * 1024 * 1024 - MAX_YAML_SIZE = 10 * 1024 * 1024 - DIR = "dir" - FILE = "file" - DATA_DIR_AUTHORITY = 0o750 - DATA_FILE_AUTHORITY = 0o640 - FILE_SIZE_DICT = { - PKL_SUFFIX: MAX_PKL_SIZE, - NUMPY_SUFFIX: MAX_NUMPY_SIZE, - JSON_SUFFIX: MAX_JSON_SIZE, - PT_SUFFIX: MAX_PT_SIZE, - CSV_SUFFIX: MAX_CSV_SIZE, - YAML_SUFFIX: MAX_YAML_SIZE - } - - -class FileCheckException(Exception): - """ - Class for File Check Exception - """ - NONE_ERROR = 0 - INVALID_PATH_ERROR = 1 - INVALID_FILE_TYPE_ERROR = 2 - INVALID_PARAM_ERROR = 3 - INVALID_PERMISSION_ERROR = 3 - - def __init__(self, code, error_info: str = ""): - super(FileCheckException, self).__init__() - self.code = code - self.error_info = error_info - - def __str__(self): - return self.error_info - - -class FileChecker: - """ - The class for check file. - - Attributes: - file_path: The file or dictionary path to be verified. - path_type: file or dictionary - ability(str): FileCheckConst.WRITE_ABLE or FileCheckConst.READ_ABLE to set file has writability or readability - file_type(str): The correct file type for file - """ - def __init__(self, file_path, path_type, ability=None, file_type=None, is_script=True): - self.file_path = file_path - self.path_type = self._check_path_type(path_type) - self.ability = ability - self.file_type = file_type - self.is_script = is_script - - @staticmethod - def _check_path_type(path_type): - if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: - print_error_log(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') - raise FileCheckException(FileCheckException.INVALID_PARAM_ERROR) - return path_type - - def common_check(self): - """ - 功能:用户校验基本文件权限:软连接、文件长度、是否存在、读写权限、文件属组、文件特殊字符 - 注意:文件后缀的合法性,非通用操作,可使用其他独立接口实现 - """ - check_path_exists(self.file_path) - check_link(self.file_path) - self.file_path = os.path.realpath(self.file_path) - check_path_length(self.file_path) - check_path_type(self.file_path, self.path_type) - self.check_path_ability() - if self.is_script: - check_path_owner_consistent(self.file_path) - check_path_pattern_vaild(self.file_path) - check_common_file_size(self.file_path) - check_file_suffix(self.file_path, self.file_type) - return self.file_path - - def check_path_ability(self): - if self.ability == FileCheckConst.WRITE_ABLE: - check_path_writability(self.file_path) - if self.ability == FileCheckConst.READ_ABLE: - check_path_readability(self.file_path) - if self.ability == FileCheckConst.READ_WRITE_ABLE: - check_path_readability(self.file_path) - check_path_writability(self.file_path) - - -class FileOpen: - """ - The class for open file by a safe way. - - Attributes: - file_path: The file or dictionary path to be opened. 
- mode(str): The file open mode - """ - SUPPORT_READ_MODE = ["r", "rb"] - SUPPORT_WRITE_MODE = ["w", "wb", "a", "ab"] - SUPPORT_READ_WRITE_MODE = ["r+", "rb+", "w+", "wb+", "a+", "ab+"] - - def __init__(self, file_path, mode, encoding='utf-8'): - self.file_path = file_path - self.mode = mode - self.encoding = encoding - self._handle = None - - def __enter__(self): - self.check_file_path() - binary_mode = "b" - if binary_mode not in self.mode: - self._handle = open(self.file_path, self.mode, encoding=self.encoding) - else: - self._handle = open(self.file_path, self.mode) - return self._handle - - def __exit__(self, exc_type, exc_val, exc_tb): - if self._handle: - self._handle.close() - - def check_file_path(self): - support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE - if self.mode not in support_mode: - print_error_log("File open not support %s mode" % self.mode) - raise FileCheckException(FileCheckException.INVALID_PARAM_ERROR) - check_link(self.file_path) - self.file_path = os.path.realpath(self.file_path) - check_path_length(self.file_path) - self.check_ability_and_owner() - check_path_pattern_vaild(self.file_path) - if os.path.exists(self.file_path): - check_common_file_size(self.file_path) - - def check_ability_and_owner(self): - if self.mode in self.SUPPORT_READ_MODE: - check_path_exists(self.file_path) - check_path_readability(self.file_path) - check_path_owner_consistent(self.file_path) - if self.mode in self.SUPPORT_WRITE_MODE and os.path.exists(self.file_path): - check_path_writability(self.file_path) - check_path_owner_consistent(self.file_path) - if self.mode in self.SUPPORT_READ_WRITE_MODE and os.path.exists(self.file_path): - check_path_readability(self.file_path) - check_path_writability(self.file_path) - check_path_owner_consistent(self.file_path) - - -def check_link(path): - abs_path = os.path.abspath(path) - if os.path.islink(abs_path): - print_error_log('The file path {} is a soft link.'.format(path)) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) - - -def check_path_length(path, name_length=None): - file_max_name_length = name_length if name_length else FileCheckConst.FILE_NAME_LENGTH - if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ - len(os.path.basename(path)) > file_max_name_length: - print_error_log('The file path length exceeds limit.') - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) - - -def check_path_exists(path): - if not os.path.exists(path): - print_error_log('The file path %s does not exist.' % path) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) - - -def check_path_readability(path): - if not os.access(path, os.R_OK): - print_error_log('The file path %s is not readable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) - - -def check_path_writability(path): - if not os.access(path, os.W_OK): - print_error_log('The file path %s is not writable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) - - -def check_path_executable(path): - if not os.access(path, os.X_OK): - print_error_log('The file path %s is not executable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) - - -def check_other_user_writable(path): - st = os.stat(path) - if st.st_mode & 0o002: - _user_interactive_confirm( - 'The file path %s may be insecure because other users have write permissions. ' - 'Do you want to continue?' 
% path) - - -def _user_interactive_confirm(message): - while True: - check_message = input(message + " Enter 'c' to continue or enter 'e' to exit: ") - if check_message == "c": - break - elif check_message == "e": - print_warn_log("User canceled.") - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) - else: - print("Input is error, please enter 'c' or 'e'.") - - -def check_path_owner_consistent(path): - file_owner = os.stat(path).st_uid - if file_owner != os.getuid(): - print_error_log('The file path %s may be insecure because is does not belong to you.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) - - -def check_path_pattern_vaild(path): - if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): - print_error_log('The file path {} contains special characters.'.format(path)) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) - - -def check_file_size(file_path, max_size): - file_size = os.path.getsize(file_path) - if file_size >= max_size: - _user_interactive_confirm(f'The size of file path {file_path} exceeds {max_size} bytes.' - f'Do you want to continue?') - - -def check_common_file_size(file_path): - if os.path.isfile(file_path): - for suffix, max_size in FileCheckConst.FILE_SIZE_DICT.items(): - if file_path.endswith(suffix): - check_file_size(file_path, max_size) - break - - -def check_file_suffix(file_path, file_suffix): - if file_suffix: - if not file_path.endswith(file_suffix): - print_error_log(f"The {file_path} should be a {file_suffix} file!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) - - -def check_path_type(file_path, file_type): - if file_type == FileCheckConst.FILE: - if not os.path.isfile(file_path): - print_error_log(f"The {file_path} should be a file!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) - if file_type == FileCheckConst.DIR: - if not os.path.isdir(file_path): - print_error_log(f"The {file_path} should be a dictionary!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) - - -def create_directory(dir_path): - """ - Function Description: - creating a directory with specified permissions - Parameter: - dir_path: directory path - Exception Description: - when invalid data throw exception - """ - dir_path = os.path.realpath(dir_path) - try: - os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) - except OSError as ex: - print_error_log( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) from ex - - -def change_mode(path, mode): - if not os.path.exists(path) or os.path.islink(path): - return - try: - os.chmod(path, mode) - except PermissionError as ex: - print_error_log('Failed to change {} authority. {}'.format(path, str(ex))) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) from ex - diff --git a/debug/accuracy_tools/atat/core/log.py b/debug/accuracy_tools/atat/core/log.py deleted file mode 100644 index b9ac8f5edf..0000000000 --- a/debug/accuracy_tools/atat/core/log.py +++ /dev/null @@ -1,56 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2024. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" -import os -import time -import sys - - -def _print_log(level, msg, end='\n'): - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getgid() - print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) - sys.stdout.flush() - - -def print_info_log(info_msg, end='\n'): - """ - Function Description: - print info log. - Parameter: - info_msg: the info message. - """ - _print_log("INFO", info_msg, end=end) - - -def print_error_log(error_msg): - """ - Function Description: - print error log. - Parameter: - error_msg: the error message. - """ - _print_log("ERROR", error_msg) - - -def print_warn_log(warn_msg): - """ - Function Description: - print warn log. - Parameter: - warn_msg: the warning message. - """ - _print_log("WARNING", warn_msg) \ No newline at end of file diff --git a/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py b/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py index b0f80f40e5..49dc18fdd3 100644 --- a/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py +++ b/debug/accuracy_tools/atat/mindspore/dump/api_kbk_dump.py @@ -1,9 +1,9 @@ import os import json -from atat.core.utils import make_dump_path_if_not_exists +from debug.accuracy_tools.atat.core.common.utils import make_dump_path_if_not_exists from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.core.log import print_info_log -from atat.core.file_check_util import FileOpen +from atat.core.common.log import logger +from atat.core.common.file_check import FileOpen class ApiKbkDump: @@ -48,7 +48,7 @@ class ApiKbkDump: json_path = os.path.join(json_path, "api_kbk_dump.json") with FileOpen(json_path, 'w') as f: json.dump(self.dump_json, f) - print_info_log(json_path + " has been created.") + logger.info(json_path + " has been created.") os.environ["GRAPH_OP_RUN"] = "1" os.environ["MINDSPORE_DUMP_CONFIG"] = json_path if "MS_ACL_DUMP_CFG_PATH" in os.environ: diff --git a/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py b/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py index f8a10ec1b1..78aad2c076 100644 --- a/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py +++ b/debug/accuracy_tools/atat/mindspore/dump/kernel_graph_dump.py @@ -1,9 +1,9 @@ import os import json -from atat.core.utils import make_dump_path_if_not_exists +from debug.accuracy_tools.atat.core.common.utils import make_dump_path_if_not_exists from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.core.log import print_info_log -from atat.core.file_check_util import FileOpen +from atat.core.common.log import logger +from atat.core.common.file_check import FileOpen class KernelGraphDump: @@ -49,7 +49,7 @@ class KernelGraphDump: json_path = os.path.join(json_path, "kernel_graph_dump.json") with FileOpen(json_path, 'w') as f: json.dump(self.dump_json, f) - print_info_log(json_path + " has been created.") + logger.info(json_path + " has been created.") os.environ["MINDSPORE_DUMP_CONFIG"] = json_path if self.dump_json["common_dump_settings"]["dump_mode"] == 0: if 
self.dump_json["common_dump_settings"]["iteration"] != "all" or \ diff --git a/debug/accuracy_tools/atat/mindspore/ms_config.py b/debug/accuracy_tools/atat/mindspore/ms_config.py index 0d846c4771..02cead32f1 100644 --- a/debug/accuracy_tools/atat/mindspore/ms_config.py +++ b/debug/accuracy_tools/atat/mindspore/ms_config.py @@ -1,6 +1,6 @@ import json from atat.core.common_config import CommonConfig, BaseConfig -from atat.core.file_check_util import FileOpen +from atat.core.common.file_check import FileOpen class TensorConfig(BaseConfig): diff --git a/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py b/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py index 5ef005e59e..3a315d1aea 100644 --- a/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py +++ b/debug/accuracy_tools/atat/mindspore/overflow_check/kernel_graph_overflow_check.py @@ -1,9 +1,9 @@ import os import json -from atat.core.utils import make_dump_path_if_not_exists +from debug.accuracy_tools.atat.core.common.utils import make_dump_path_if_not_exists from atat.mindspore.debugger.debugger_config import DebuggerConfig -from atat.core.log import print_warn_log, print_info_log -from atat.core.file_check_util import FileOpen +from atat.core.common.log import logger +from atat.core.common.file_check import FileOpen class KernelGraphOverflowCheck: @@ -23,7 +23,7 @@ class KernelGraphOverflowCheck: self.dump_json["common_dump_settings"]["path"] = config.dump_path if len(config.step) > 0: - print_warn_log("Step would change to all in this task.") + logger.warning("Step would change to all in this task.") if len(config.rank) > 0: self.dump_json["common_dump_settings"]["support_device"] = config.rank if config.check_mode == "aicore": @@ -39,7 +39,7 @@ class KernelGraphOverflowCheck: json_path = os.path.join(json_path, "kernel_graph_overflow_check.json") with FileOpen(json_path, 'w') as f: json.dump(self.dump_json, f) - print_info_log(json_path + " has been created.") + logger.info(json_path + " has been created.") os.environ["MINDSPORE_DUMP_CONFIG"] = json_path if "MS_ACL_DUMP_CFG_PATH" in os.environ: del os.environ["MS_ACL_DUMP_CFG_PATH"] diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py index db193dcd83..55fce79133 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor.py @@ -20,8 +20,9 @@ import pandas as pd from .advisor_result import AdvisorResult from .advisor_const import AdvisorConst -from ...core.utils import CompareException, CompareConst, Const, print_info_log, print_warn_log, print_error_log -from ...core.file_check_util import FileChecker, FileCheckConst +from ..common.log import logger +from ...core.common.utils import CompareException, CompareConst, Const +from ...core.common.file_check import FileChecker, FileCheckConst class Advisor: @@ -57,7 +58,7 @@ class Advisor: if num_unmatch != 0: for i in range(len(accuracy_unmatched)): item = accuracy_unmatched.iloc[i] - print_warn_log("The tensor name matches but the shape or dtype does not match: {}" + logger.warning("The tensor name matches but the shape or dtype does not match: {}" .format(item[CompareConst.NPU_NAME])) def gen_advisor_result(self, pd_data): @@ -65,7 +66,7 @@ class Advisor: node_name = first_failing_data[CompareConst.NPU_NAME] index = first_failing_data['index'] message = self.gen_advisor_message(node_name) - print_warn_log("Find %s 
accuracy not reached, the line is %s" % (node_name, index)) + logger.warning("Find %s accuracy not reached, the line is %s" % (node_name, index)) result = AdvisorResult(node_name, index, message) return result @@ -88,7 +89,7 @@ class Advisor: def analysis(self): self._check_path_vaild() analyze_data = self._parse_input_data() - print_info_log("Start analyzing the comparison result: %s" % self.file_type) + logger.info("Start analyzing the comparison result: %s" % self.file_type) self.analyze_unmatched(analyze_data) if self.file_type == Const.ALL: failing_data = analyze_data[analyze_data[CompareConst.ACCURACY] == CompareConst.ACCURACY_CHECK_NO] @@ -97,7 +98,7 @@ class Advisor: elif self.file_type == Const.SUMMARY: failing_data = analyze_data[analyze_data[CompareConst.RESULT] == CompareConst.WARNING] if failing_data.empty: - print_info_log("All data from api input/output accuracy reached") + logger.info("All data from api input/output accuracy reached") result = AdvisorResult(AdvisorConst.NO_ERROR_API, AdvisorConst.NO_ERROR_API, AdvisorConst.NO_ERR_SUGGEST) else: result = self.gen_advisor_result(failing_data) @@ -113,7 +114,7 @@ class Advisor: elif {CompareConst.MAX_DIFF, CompareConst.RESULT}.issubset(data_columns): self.file_type = Const.SUMMARY else: - print_error_log('Compare result does not meet the required conditions.') + logger.error('Compare result does not meet the required conditions.') raise CompareException(CompareException.INVALID_DATA_ERROR) df = self.input_data.reset_index() return df diff --git a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py index f8a16d2a70..508b56846e 100644 --- a/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py +++ b/debug/accuracy_tools/atat/pytorch/advisor/advisor_result.py @@ -18,8 +18,9 @@ import os import time from .advisor_const import AdvisorConst -from ...core.utils import Const, print_info_log, print_error_log -from ...core.file_check_util import FileCheckConst, change_mode +from ..common.log import logger +from ...core.common.utils import Const +from ...core.common.file_check import FileCheckConst, change_mode class AdvisorResult: @@ -43,15 +44,15 @@ class AdvisorResult: output_file.writelines(message_list) change_mode(result_file, FileCheckConst.DATA_FILE_AUTHORITY) except IOError as io_error: - print_error_log("Failed to save %s, the reason is %s." % (result_file, io_error)) + logger.error("Failed to save %s, the reason is %s." 
% (result_file, io_error)) else: - print_info_log("The advisor summary is saved in: %s" % result_file) + logger.info("The advisor summary is saved in: %s" % result_file) def print_advisor_log(self): - print_info_log("The summary of the expert advice is as follows: ") + logger.info("The summary of the expert advice is as follows: ") message_list = [AdvisorConst.LINE + AdvisorConst.COLON + str(self.line), AdvisorConst.SUSPECT_NODES + AdvisorConst.COLON + self.suspect_node, AdvisorConst.ADVISOR_SUGGEST + AdvisorConst.COLON + self.advisor_message] for message in message_list: - print_info_log(message) + logger.info(message) return message_list diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py index db6db968bf..21f0986477 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py @@ -2,7 +2,7 @@ import os import yaml from ..common.utils import check_file_or_directory_path from ...hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps -from ...common.file_check import FileOpen +from atat.core.common.file_check import FileOpen WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py index 51d7c556ed..04a2b2f067 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py @@ -28,10 +28,10 @@ except ImportError: else: IS_GPU = False -from atat.pytorch.common.log import print_warn_log, print_error_log -from atat.pytorch.common.file_check import FileCheckConst, FileChecker, FileOpen, change_mode, create_directory +from atat.pytorch.common.log import logger +from atat.core.common.file_check import FileCheckConst, FileChecker, FileOpen, change_mode, create_directory from atat.pytorch.common.utils import Const -from atat.core.utils import CompareException +from debug.accuracy_tools.atat.core.common.utils import CompareException class DumpException(CompareException): @@ -55,7 +55,7 @@ def check_object_type(check_object, allow_type): when invalid data throw exception """ if not isinstance(check_object, allow_type): - print_error_log(f"{check_object} not of {allow_type} type") + logger.error(f"{check_object} not of {allow_type} type") raise CompareException(CompareException.INVALID_DATA_ERROR) @@ -71,24 +71,24 @@ def check_file_or_directory_path(path, isdir=False): """ if isdir: if not os.path.exists(path): - print_error_log('The path {} is not exist.'.format(path)) + logger.error('The path {} is not exist.'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) if not os.path.isdir(path): - print_error_log('The path {} is not a directory.'.format(path)) + logger.error('The path {} is not a directory.'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) if not os.access(path, os.W_OK): - print_error_log( + logger.error( 'The path {} does not have permission to write. 
Please check the path permission'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) else: if not os.path.isfile(path): - print_error_log('{} is an invalid file or non-exist.'.format(path)) + logger.error('{} is an invalid file or non-exist.'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) if not os.access(path, os.R_OK): - print_error_log( + logger.error( 'The path {} does not have permission to read. Please check the path permission'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) @@ -98,10 +98,10 @@ def get_json_contents(file_path): try: json_obj = json.loads(ops) except ValueError as error: - print_error_log('Failed to load "%s". %s' % (file_path, str(error))) + logger.error('Failed to load "%s". %s' % (file_path, str(error))) raise CompareException(CompareException.INVALID_FILE_ERROR) from error if not isinstance(json_obj, dict): - print_error_log('Json file %s, content is not a dictionary!' % file_path) + logger.error('Json file %s, content is not a dictionary!' % file_path) raise CompareException(CompareException.INVALID_FILE_ERROR) return json_obj @@ -161,7 +161,7 @@ def cross_entropy_process(api_info_dict): def initialize_save_path(save_path, dir_name): data_path = os.path.join(save_path, dir_name) if os.path.exists(data_path): - print_warn_log(f"{data_path} already exists, it will be overwritten") + logger.warning(f"{data_path} already exists, it will be overwritten") else: os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) data_path_checker = FileChecker(data_path, FileCheckConst.DIR) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index fcc4ba3d5d..a5ba27d633 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -14,9 +14,9 @@ from atat.pytorch.api_accuracy_checker.compare.compare_utils import CompareConst convert_str_to_float, CompareMessage from atat.pytorch.api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn from atat.pytorch.api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from atat.pytorch.common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create, create_directory -from atat.pytorch.common.log import print_info_log, print_warn_log, print_error_log -from atat.core.utils import CompareException +from atat.core.common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create, create_directory +from atat.pytorch.common.log import logger +from debug.accuracy_tools.atat.core.common.utils import CompareException CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) unsupported_message = 'This data type does not support benchmark compare.' 
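Note on the logging change in the hunks above and below: the module-level print_info_log/print_warn_log/print_error_log helpers are replaced everywhere with a shared logger imported from atat.core.common.log (or the relative ..common.log). That module is not shown in this excerpt, so the following is only a minimal sketch, assuming it wraps the standard logging library with one preconfigured logger; the handler, format string, and level here are assumptions, not the actual implementation.

    # Sketch only -- not the actual atat.core.common.log implementation.
    import logging
    import sys

    logger = logging.getLogger("atat")
    if not logger.handlers:
        _handler = logging.StreamHandler(sys.stdout)
        # Timestamp, pid, and level, roughly mirroring the old print_*_log output.
        _handler.setFormatter(logging.Formatter("%(asctime)s (%(process)d) [%(levelname)s] %(message)s"))
        logger.addHandler(_handler)
        logger.setLevel(logging.INFO)

    # Call sites in these hunks then read:
    # logger.info("Start compare task")
    # logger.warning("...")
    # logger.error("...")

One practical benefit of routing everything through a single logger object is that verbosity and formatting can be adjusted in one place instead of in every print_* helper.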
@@ -152,18 +152,18 @@ def write_detail_csv(content, save_path): def api_precision_compare(config): - print_info_log("Start compare task") - print_info_log(f"Compare task result will be saved in {config.result_csv_path}") - print_info_log(f"Compare task detail will be saved in {config.details_csv_path}") + logger.info("Start compare task") + logger.info(f"Compare task result will be saved in {config.result_csv_path}") + logger.info(f"Compare task detail will be saved in {config.details_csv_path}") try: npu_data = pd.read_csv(config.npu_csv_path) except Exception as err: - print_error_log(f"Open npu csv Error: %s" % str(err)) + logger.error(f"Open npu csv Error: %s" % str(err)) check_csv_columns(npu_data.columns, "npu_csv") try: gpu_data = pd.read_csv(config.gpu_csv_path) except Exception as err: - print_error_log(f"Open gpu csv Error: %s" % str(err)) + logger.error(f"Open gpu csv Error: %s" % str(err)) check_csv_columns(gpu_data.columns, "gpu_csv") detail_csv_title = [ApiPrecisionCompareColumn.get_detail_csv_title()] result_csv_title = [ApiPrecisionCompareColumn.get_result_csv_title()] @@ -172,7 +172,7 @@ def api_precision_compare(config): try: analyse_csv(npu_data, gpu_data, config) except Exception as err: - print_error_log(f"Analyse csv Error: %s" % str(err)) + logger.error(f"Analyse csv Error: %s" % str(err)) change_mode(config.result_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) change_mode(config.details_csv_path, FileCheckConst.DATA_FILE_AUTHORITY) @@ -187,7 +187,7 @@ def analyse_csv(npu_data, gpu_data, config): row_gpu = gpu_data[gpu_data[ApiPrecisionCompareColumn.API_NAME] == full_api_name_with_direction_status] _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".") if row_gpu.empty: - print_warn_log(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.') + logger.warning(f'This API : {full_api_name_with_direction_status} does not exist in the GPU data.') continue if len(row_gpu) > 1: msg = f'This API : {full_api_name_with_direction_status} has multiple records in the GPU data.' 
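For readers of analyse_csv above: each row is keyed by a six-field, dot-separated API name, and the second and fourth fields are unpacked as api_name and direction_status. The exact field layout is not spelled out in this hunk, so the sample value below is only an illustrative guess that matches the unpacking.

    # Hypothetical example; only the split positions are taken from the code above.
    full_api_name_with_direction_status = "Functional.conv2d.0.forward.output.0"
    _, api_name, _, direction_status, _, _ = full_api_name_with_direction_status.split(".")
    assert api_name == "conv2d"
    assert direction_status == "forward"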
@@ -234,7 +234,7 @@ def analyse_csv(npu_data, gpu_data, config): elif direction_status == 'backward': backward_status.append(new_status) else: - print_error_log(f"Invalid direction status: {direction_status}") + logger.error(f"Invalid direction status: {direction_status}") if last_api_name is not None: if last_api_dtype in API_PRECISION_COMPARE_UNSUPPORT_LIST: @@ -389,4 +389,4 @@ def _api_precision_compare_parser(parser): if __name__ == '__main__': _api_precision_compare() - print_info_log("Compare task completed.") + logger.info("Compare task completed.") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py index 350a407747..6ec5b8dbcd 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py @@ -5,7 +5,7 @@ import torch import numpy as np from rich.table import Table from rich.console import Console -from ..common.utils import get_json_contents, write_csv, print_warn_log, Const +from ..common.utils import get_json_contents, write_csv, Const from ..compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, apis_threshold from ..compare.compare_column import CompareColumn @@ -14,7 +14,8 @@ from ..compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, ge get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps from ..common.config import msCheckerConfig -from ...common.file_check import FileOpen +from atat.pytorch.common.log import logger +from atat.core.common.file_check import FileOpen class Comparator: @@ -83,7 +84,7 @@ class Comparator: else: passing_rate = "0%" - print_warn_log("The follwing tables will be deprecated in the future." + logger.warning("The follwing tables will be deprecated in the future." 
"The following results are for reference only.") console = Console() table_total = Table( diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py index 5511da7244..61d59c82c6 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -3,8 +3,9 @@ import os import numpy as np import torch import yaml -from ..common.utils import Const, print_warn_log, CompareException -from ...common.file_check import FileOpen +from ..common.utils import Const, CompareException +from atat.pytorch.common.log import logger +from atat.core.common.file_check import FileOpen current_time = time.strftime("%Y%m%d%H%M%S") @@ -170,7 +171,7 @@ def check_dtype_comparable(x, y): if y.dtype in Const.INT_TYPE: return True return False - print_warn_log(f"Compare: Unexpected dtype {x.dtype}, {y.dtype}") + logger.warning(f"Compare: Unexpected dtype {x.dtype}, {y.dtype}") return False diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py index 723fb8ec66..5ad1ed2cc5 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,8 +20,8 @@ import math import torch import numpy -from ..common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, \ - print_error_log, get_full_data_path, CompareException +from ..common.utils import Const, check_file_or_directory_path, check_object_type, get_full_data_path, CompareException +from atat.pytorch.common.log import logger TORCH_TYPE = ["torch.device", "torch.dtype"] TENSOR_DATA_LIST = ["torch.Tensor", "torch.nn.parameter.Parameter"] @@ -62,7 +62,7 @@ def gen_data(info, need_grad, convert_type, real_data_path=None): try: data = eval(data_type)(data) except Exception as err: - print_error_log("Failed to convert the type to numpy: %s" % str(err)) + logger.error("Failed to convert the type to numpy: %s" % str(err)) elif data_type == "torch.Size": data = torch.Size(info.get("value")) else: @@ -170,7 +170,7 @@ def gen_common_tensor(low_info, high_info, shape, data_dtype, convert_type): low, high = int(low), int(high) tensor = torch.randint(low, high + 1, shape, dtype=eval(data_dtype)) else: - print_error_log('Dtype is not supported: ' + data_dtype) + logger.error('Dtype is not supported: ' + data_dtype) raise NotImplementedError() if tensor.nelement() == 0: return tensor @@ -231,7 +231,7 @@ def gen_args(args_info, need_grad=True, convert_type=None, real_data_path=None): elif arg is None: data = None else: - print_warn_log(f'Warning: {arg} is not supported') + logger.warning(f'Warning: {arg} is not supported') raise NotImplementedError() args_result.append(data) return args_result @@ -304,6 +304,6 @@ def gen_api_params(api_info, need_grad=True, convert_type=None, real_data_path=N if api_info.get("input_args"): args_params = gen_args(api_info.get("input_args"), need_grad, convert_type, real_data_path) else: - print_warn_log(f'Warning: No args in {api_info} ') + 
logger.warning(f'Warning: No args in {api_info} ') args_params = [] return args_params, kwargs_params diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index b4aa2ddeaa..b9d1a4fd1f 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -13,9 +13,9 @@ from atat.pytorch.api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_ get_validated_details_csv_path, preprocess_forward_content from atat.pytorch.api_accuracy_checker.compare.compare import Comparator from atat.pytorch.common import parse_json_info_forward_backward -from atat.pytorch.common.file_check import FileCheckConst, FileChecker, check_file_suffix, check_link, FileOpen, \ +from atat.core.common.file_check import FileCheckConst, FileChecker, check_file_suffix, check_link, FileOpen, \ check_path_before_create, create_directory -from atat.pytorch.common.log import print_error_log, print_warn_log, print_info_log +from atat.pytorch.common.log import logger def split_json_file(input_file, num_splits, filter_api): @@ -57,7 +57,7 @@ def split_json_file(input_file, num_splits, filter_api): def signal_handler(signum, frame): - print_warn_log(f'Signal handler called with signal {signum}') + logger.warning(f'Signal handler called with signal {signum}') raise KeyboardInterrupt() @@ -74,8 +74,8 @@ def run_parallel_ut(config): processes = [] device_id_cycle = cycle(config.device_id) if config.save_error_data_flag: - print_info_log("UT task error datas will be saved") - print_info_log(f"Starting parallel UT with {config.num_splits} processes") + logger.info("UT task error datas will be saved") + logger.info(f"Starting parallel UT with {config.num_splits} processes") progress_bar = tqdm(total=config.total_items, desc="Total items", unit="items") def create_cmd(api_info, dev_id): @@ -105,7 +105,7 @@ def run_parallel_ut(config): print(output, end='') sys.stdout.flush() except ValueError as e: - print_warn_log(f"An error occurred while reading subprocess output: {e}") + logger.warning(f"An error occurred while reading subprocess output: {e}") def update_progress_bar(progress_bar, result_csv_path): while any(process.poll() is None for process in processes): @@ -114,9 +114,9 @@ def run_parallel_ut(config): completed_items = len(result_file.readlines()) - 1 progress_bar.update(completed_items - progress_bar.n) except FileNotFoundError: - print_warn_log(f"Result CSV file not found: {result_csv_path}.") + logger.warning(f"Result CSV file not found: {result_csv_path}.") except Exception as e: - print_error_log(f"An unexpected error occurred while reading result CSV: {e}") + logger.error(f"An unexpected error occurred while reading result CSV: {e}") time.sleep(1) for api_info in config.api_files: @@ -141,27 +141,27 @@ def run_parallel_ut(config): try: os.remove(file) except FileNotFoundError: - print_warn_log(f"File not found and could not be deleted: {file}") + logger.warning(f"File not found and could not be deleted: {file}") try: for process in processes: process.communicate(timeout=None) except KeyboardInterrupt: - print_warn_log("Interrupted by user, terminating processes and cleaning up...") + logger.warning("Interrupted by user, terminating processes and cleaning up...") except Exception as e: - print_error_log(f"An unexpected error occurred: {e}") + logger.error(f"An unexpected error occurred: {e}") 
finally: if progress_bar.n < config.total_items: - print_warn_log("The UT task has not been completed. The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.") + logger.warning("The UT task has not been completed. The parameter '-csv_path' along with the path to the result CSV file will be utilized to resume the UT task.") clean_up() progress_bar_thread.join() try: comparator = Comparator(config.result_csv_path, config.result_csv_path, False) comparator.print_pretest_result() except FileNotFoundError as e: - print_error_log(f"Error: {e}") + logger.error(f"Error: {e}") except Exception as e: - print_error_log(f"An unexpected error occurred: {e}") + logger.error(f"An unexpected error occurred: {e}") def prepare_config(args): @@ -182,8 +182,8 @@ def prepare_config(args): else: result_csv_path = get_validated_result_csv_path(args.result_csv_path, 'result') details_csv_path = get_validated_details_csv_path(result_csv_path) - print_info_log(f"UT task result will be saved in {result_csv_path}") - print_info_log(f"UT task details will be saved in {details_csv_path}") + logger.info(f"UT task result will be saved in {result_csv_path}") + logger.info(f"UT task details will be saved in {details_csv_path}") return ParallelUTConfig(split_files, out_path, args.num_splits, args.save_error_data, args.jit_compile, args.device_id, result_csv_path, total_items, args.real_data_path) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index be25e24b37..0becf86613 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -4,10 +4,10 @@ import sys import torch_npu import torch from tqdm import tqdm +from ..run_ut.run_ut import exec_api, generate_device_params, get_api_info from atat.pytorch.api_accuracy_checker.common.utils import get_json_contents -from atat.pytorch.common.file_check import check_link -from atat.pytorch.common.log import print_info_log, print_warn_log, print_error_log - +from atat.core.common.file_check import check_link +from atat.pytorch.common.log import logger def check_tensor_overflow(x): if isinstance(x, torch.Tensor) and x.numel() != 0 and x.dtype != torch.bool: @@ -45,7 +45,7 @@ def check_data_overflow(x): def run_overflow_check(forward_file): - print_info_log("start UT test") + logger.info("start UT test") forward_content = get_json_contents(forward_file) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: @@ -53,13 +53,13 @@ def run_overflow_check(forward_file): except Exception as err: api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] if "not implemented for 'Half'" in str(err): - print_warn_log(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API " + logger.warning(f"API {api_name} not support half tensor in CPU, please add {api_name} to CONVERT_API " f"'fp16_to_fp32' list in accuracy_tools/api_accuracy_check/common/utils.py file.") elif "expected scalar type Long" in str(err): - print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " + logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") else: - print_error_log(f"Run 
{api_full_name} UT Error: %s" % str(err)) + logger.error(f"Run {api_full_name} UT Error: %s" % str(err)) def run_torch_api(api_full_name, api_info_dict): @@ -68,7 +68,7 @@ def run_torch_api(api_full_name, api_info_dict): api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0] args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='') if not need_grad: - print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." + logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward." % api_full_name) npu_args, npu_kwargs = generate_device_params(args, kwargs, False, api_name) if kwargs.get("device"): @@ -78,9 +78,9 @@ def run_torch_api(api_full_name, api_info_dict): cpu_overflow = check_data_overflow(out) npu_overflow = torch_npu.npu.utils.npu_check_overflow(npu_out) if cpu_overflow == npu_overflow: - print_warn_log("The %s overflow is a normal overflow." % api_full_name) + logger.warning("The %s overflow is a normal overflow." % api_full_name) else: - print_warn_log("The %s overflow is an abnormal overflow." % api_full_name) + logger.warning("The %s overflow is an abnormal overflow." % api_full_name) return @@ -111,11 +111,11 @@ def _run_overflow_check_command(args): try: torch.npu.set_device(npu_device) except Exception as error: - print_error_log(f"Set NPU device id failed. device id is: {args.device_id}") + logger.error(f"Set NPU device id failed. device id is: {args.device_id}") raise NotImplementedError from error run_overflow_check(api_info) if __name__ == '__main__': _run_overflow_check() - print_info_log("UT task completed.") + logger.info("UT task completed.") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index 59ccea4bc6..559055adc3 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -27,9 +27,9 @@ from atat.pytorch.hook_module.wrap_functional import FunctionalOPTemplate from atat.pytorch.hook_module.wrap_torch import TorchOPTemplate from atat.pytorch.api_accuracy_checker.common.config import msCheckerConfig from atat.pytorch.common.parse_json import parse_json_info_forward_backward -from atat.pytorch.common.file_check import FileOpen, FileCheckConst, FileChecker, \ +from atat.core.common.file_check import FileOpen, FileCheckConst, FileChecker, \ change_mode, check_file_suffix, check_link, check_path_before_create, create_directory -from atat.pytorch.common.log import print_info_log, print_warn_log, print_error_log +from atat.pytorch.common.log import logger from atat.pytorch.common.utils import Const current_time = time.strftime("%Y%m%d%H%M%S") @@ -150,12 +150,12 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): def run_ut(config): - print_info_log("start UT test") - print_info_log(f"UT task result will be saved in {config.result_csv_path}") - print_info_log(f"UT task details will be saved in {config.details_csv_path}") + logger.info("start UT test") + logger.info(f"UT task result will be saved in {config.result_csv_path}") + logger.info(f"UT task details will be saved in {config.details_csv_path}") if config.save_error_data: error_data_path = os.path.abspath(os.path.join(msCheckerConfig.error_data_path, UT_ERROR_DATA_DIR)) - print_info_log(f"UT task error_datas will be saved in {error_data_path}") + 
logger.info(f"UT task error_datas will be saved in {error_data_path}")
     compare = Comparator(config.result_csv_path, config.details_csv_path, config.is_continue_run_ut)
     with FileOpen(config.result_csv_path, 'r') as file:
         csv_reader = csv.reader(file)
@@ -182,10 +182,10 @@ def run_ut(config):
         except Exception as err:
             [_, api_name, _] = api_full_name.split(Const.SEP)
             if "expected scalar type Long" in str(err):
-                print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
+                logger.warning(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API "
                                f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.")
             else:
-                print_error_log(f"Run {api_full_name} UT Error: %s" % str(err))
+                logger.error(f"Run {api_full_name} UT Error: %s" % str(err))
                 compare.write_summary_csv((api_full_name, "SKIP", "SKIP", str(err)))
         finally:
             if is_gpu:
@@ -202,7 +202,7 @@ def is_unsupported_api(api_name):
     split_name = api_name.split(Const.SEP)[0]
     flag = split_name in [Const.NPU, Const.DISTRIBUTED]
     if flag:
-        print_info_log(f"{split_name} api is not supported for run ut. SKIP.")
+        logger.info(f"{split_name} api is not supported for run ut. SKIP.")
     return flag
@@ -226,11 +226,11 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict
     in_fwd_data_list.append(kwargs)
     need_backward = api_full_name in backward_content
     if not need_grad:
-        print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward."
+        logger.warning("%s function with out=... arguments don't support automatic differentiation, skip backward."
                        % api_full_name)
     if api_name in not_backward_list:
         need_grad = False
-        print_warn_log(
+        logger.warning(
             "%s function backward result is None, skip backward." % api_full_name)
     need_backward = need_backward and need_grad
     if kwargs.get("device"):
@@ -377,7 +377,7 @@ def preprocess_forward_content(forward_content):
                     existing_kwargs = processed_content[variant].get('kwargs', {})
                     filtered_existing_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in existing_args if isinstance(arg, dict)]
                 except KeyError as e:
-                    print_error_log(f"KeyError: {e} when processing {key}")
+                    logger.error(f"KeyError: {e} when processing {key}")
                 if filtered_existing_args == filtered_new_args and existing_kwargs == new_kwargs:
                     is_duplicate = True
                     break
@@ -408,7 +408,7 @@ def run_ut_command(args):
         else:
             torch.npu.set_device(used_device)
     except Exception as error:
-        print_error_log(f"Set device id failed. device id is: {args.device_id}")
+        logger.error(f"Set device id failed. 
device id is: {args.device_id}") raise NotImplementedError from error check_link(args.api_info_file) api_info = os.path.realpath(args.api_info_file) @@ -451,4 +451,4 @@ class UtDataInfo: if __name__ == '__main__': _run_ut() - print_info_log("UT task completed.") + logger.info("UT task completed.") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/run_ut.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/test_data_generate.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/debug/accuracy_tools/atat/pytorch/common/__init__.py b/debug/accuracy_tools/atat/pytorch/common/__init__.py index b391e10311..8283aa5021 100644 --- a/debug/accuracy_tools/atat/pytorch/common/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/common/__init__.py @@ -1,4 +1,2 @@ -from .recursive import recursive_apply_transform -from .log import print_error_log_rank_0, print_info_log_rank_0, print_warn_log_rank_0 from .parse_json import parse_json_info_forward_backward from .utils import seed_all diff --git a/debug/accuracy_tools/atat/pytorch/common/log.py b/debug/accuracy_tools/atat/pytorch/common/log.py index dddbdbee3e..74faff2029 100644 --- a/debug/accuracy_tools/atat/pytorch/common/log.py +++ b/debug/accuracy_tools/atat/pytorch/common/log.py @@ -3,66 +3,28 @@ import time import sys from .utils import get_rank_if_initialized -from .exceptions import DistributedNotInitializedError +from ...core.common.log import BaseLogger +from ...core.common.exceptions import DistributedNotInitializedError -def on_rank_0(func): - def func_rank_0(*args, **kwargs): - try: - current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - current_rank = None - - if current_rank is None or current_rank == 0: - return func(*args, **kwargs) - - return func_rank_0 - - -def _print_log(level, msg, end='\n'): - current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getpid() - full_msg = current_time + "(" + str(pid) + ")-[" + level + "]" + msg - try: - current_rank = get_rank_if_initialized() - except DistributedNotInitializedError: - current_rank = None - if current_rank is not None: - full_msg = f"[rank {current_rank}]-" + full_msg - print(full_msg, end=end) - sys.stdout.flush() - +class PyTorchLogger(BaseLogger): + def __init__(self): + super().__init__() -def print_info_log(info_msg, end='\n'): - """ - Function Description: - print info log. - Parameter: - info_msg: the info message. - """ - _print_log("INFO", info_msg, end=end) + def get_rank(self): + return get_rank_if_initialized() -def print_error_log(error_msg): - """ - Function Description: - print error log. - Parameter: - error_msg: the error message. - """ - _print_log("ERROR", error_msg) - - -def print_warn_log(warn_msg): - """ - Function Description: - print warn log. - Parameter: - warn_msg: the warning message. 
-    """
-    _print_log("WARNING", warn_msg)
-
-
-print_info_log_rank_0 = on_rank_0(print_info_log)
-print_warn_log_rank_0 = on_rank_0(print_warn_log)
-print_error_log_rank_0 = on_rank_0(print_error_log)
+    def _print_log(self, level, msg, end='\n'):
+        try:
+            current_rank = self.get_rank()
+        except DistributedNotInitializedError:
+            current_rank = None
+        current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+        pid = os.getpid()
+        if current_rank is not None:
+            full_msg = f"{current_time} ({pid}) [rank {current_rank}] [{level}] {msg}"
+        else:
+            full_msg = f"{current_time} ({pid}) [{level}] {msg}"
+        print(full_msg, end=end)
+        sys.stdout.flush()
+
+logger = PyTorchLogger()
\ No newline at end of file
diff --git a/debug/accuracy_tools/atat/pytorch/common/parse_json.py b/debug/accuracy_tools/atat/pytorch/common/parse_json.py
index dc594c4cf8..7900e9d46b 100644
--- a/debug/accuracy_tools/atat/pytorch/common/parse_json.py
+++ b/debug/accuracy_tools/atat/pytorch/common/parse_json.py
@@ -1,5 +1,5 @@
 import json
-from .exceptions import ParseJsonException
+from ...core.common.exceptions import ParseJsonException
 
 
 def parse_json_info_forward_backward(json_path):
diff --git a/debug/accuracy_tools/atat/pytorch/common/recursive.py b/debug/accuracy_tools/atat/pytorch/common/recursive.py
deleted file mode 100644
index 9b222f5f52..0000000000
--- a/debug/accuracy_tools/atat/pytorch/common/recursive.py
+++ /dev/null
@@ -1,31 +0,0 @@
-import numpy as np
-import torch
-
-from .log import print_warn_log
-
-_recursive_key_stack = []
-special_type = (torch.device, torch.dtype, torch.Size, torch.Tensor, np.integer, np.floating, np.bool_, np.complexfloating, \
-    np.str_, np.byte, np.unicode_, bool, int, float, str, slice)
-
-
-def recursive_apply_transform(args, transform):
-    global _recursive_key_stack
-    if isinstance(args, special_type):
-        arg_transform = transform(args, _recursive_key_stack)
-        return arg_transform
-    elif isinstance(args, (list, tuple)):
-        transform_result = []
-        for i, arg in enumerate(args):
-            _recursive_key_stack.append(str(i))
-            transform_result.append(recursive_apply_transform(arg, transform))
-            _recursive_key_stack.pop()
-        return type(args)(transform_result)
-    elif isinstance(args, dict):
-        transform_dict = {}
-        for k, arg in args.items():
-            _recursive_key_stack.append(str(k))
-            transform_dict[k] = recursive_apply_transform(arg, transform)
-            _recursive_key_stack.pop()
-        return transform_dict
-    elif args is not None:
-        print_warn_log(f"Data type {type(args)} is not supported.")
diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py
index aa4a321407..0c3f0c7a8b 100644
--- a/debug/accuracy_tools/atat/pytorch/common/utils.py
+++ b/debug/accuracy_tools/atat/pytorch/common/utils.py
@@ -15,14 +15,13 @@
 # limitations under the License.
""" import os -import re import random import stat import torch import numpy as np from functools import wraps -from .exceptions import DistributedNotInitializedError +from ...core.common.exceptions import DistributedNotInitializedError try: import torch_npu diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index 12f790fbeb..bb5c5b63f7 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -18,11 +18,9 @@ import json import multiprocessing import os.path -import stat import sys -import math -import torch +import torch import numpy as np import pandas as pd import openpyxl @@ -34,10 +32,10 @@ from .match import graph_mapping from .highlight import HighlightRules, get_header_index from .npy_compare import compare_ops_apply, get_error_type, reshape_value, get_relative_err, get_error_message from ..advisor.advisor import Advisor -from ...core.utils import check_compare_param, add_time_with_xlsx, CompareException, CompareConst, \ - format_value, check_file_not_exists, check_configuration_param, task_dumppath_get, print_info_log, \ - print_warn_log, print_error_log, Const -from ...core.file_check_util import FileChecker, FileCheckConst, change_mode, FileOpen, create_directory +from ..common.log import logger +from ...core.common.utils import check_compare_param, add_time_with_xlsx, CompareException, CompareConst, \ + format_value, check_file_not_exists, check_configuration_param, task_dumppath_get, Const +from ...core.common.file_check import FileChecker, FileCheckConst, change_mode, FileOpen, create_directory def check_graph_mode(a_op_name, b_op_name): @@ -61,7 +59,7 @@ def check_op(npu_dict, bench_dict, fuzzy_match): try: is_match = fuzzy_check_op(a_op_name, b_op_name) except Exception as err: - print_warn_log("%s and %s can not fuzzy match." % (a_op_name, b_op_name)) + logger.warning("%s and %s can not fuzzy match." % (a_op_name, b_op_name)) is_match = False return is_match and struct_match @@ -309,7 +307,7 @@ def _do_multi_process(input_parma, result_df): result_df = _handle_multi_process(compare_ops, input_parma, result_df, multiprocessing.Manager().RLock()) return result_df except ValueError as e: - print_error_log('result dataframe is not found.') + logger.error('result dataframe is not found.') raise CompareException(CompareException.INVALID_DATA_ERROR) from e @@ -324,10 +322,10 @@ def read_dump_data(result_df): op_name_mapping_dict[npu_dump_name] = [npu_dump_tensor, npu_dump_tensor] return op_name_mapping_dict except ValueError as e: - print_error_log('result dataframe is not found.') + logger.error('result dataframe is not found.') raise CompareException(CompareException.INVALID_DATA_ERROR) from e except IndexError as e: - print_error_log('result dataframe elements can not be access.') + logger.error('result dataframe elements can not be access.') raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e @@ -345,11 +343,11 @@ def _handle_multi_process(func, input_parma, result_df, lock): pool = multiprocessing.Pool(process_num) def err_call(args): - print_error_log('multiprocess compare failed! Reason: {}'.format(args)) + logger.error('multiprocess compare failed! 
Reason: {}'.format(args)) try: pool.terminate() except OSError as e: - print_error_log("pool terminate failed") + logger.error("pool terminate failed") for process_idx, df_chunk in enumerate(df_chunks): idx = df_chunk_size * process_idx @@ -374,11 +372,11 @@ def compare_ops(idx, dump_path_dict, result_df, lock, input_parma): for i in range(len(result_df)): op_name = result_df.iloc[i, 0] if is_print_compare_log: - print_info_log("start compare: {}".format(op_name)) + logger.info("start compare: {}".format(op_name)) cos_sim, max_abs_err, max_relative_err, one_thousand_err_ratio, five_thousand_err_ratio, err_msg = compare_by_op( op_name, dump_path_dict, input_parma) if is_print_compare_log: - print_info_log( + logger.info( "[{}] Compare result: cosine {}, max_abs_err {}, max_relative_err {}, {}, one_thousand_err_ratio {}, " "five_thousand_err_ratio {}".format(op_name, cos_sim, max_abs_err, max_relative_err, err_msg, one_thousand_err_ratio, five_thousand_err_ratio)) @@ -437,10 +435,10 @@ def _save_cmp_result(offset, result: ComparisonResult, result_df, lock): result_df.loc[process_index, CompareConst.FIVE_THOUSANDTHS_ERR_RATIO] = result.five_thousand_err_ratio_result[i] return result_df except ValueError as e: - print_error_log('result dataframe is not found.') + logger.error('result dataframe is not found.') raise CompareException(CompareException.INVALID_DATA_ERROR) from e except IndexError as e: - print_error_log('result dataframe elements can not be access.') + logger.error('result dataframe elements can not be access.') raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e finally: lock.release() @@ -456,7 +454,7 @@ def check_accuracy(cos, max_abs_err): try: cos, max_abs_err = float(cos), float(max_abs_err) except ValueError: - print_warn_log("Cosine or MaxAbsErr can not get float value.") + logger.warning("Cosine or MaxAbsErr can not get float value.") return CompareConst.NONE if cos < CompareConst.COS_THRESHOLD and max_abs_err > CompareConst.MAX_ABS_ERR_THRESHOLD: return CompareConst.ACCURACY_CHECK_NO @@ -615,7 +613,7 @@ def find_compare_result_error_rows(result_df, highlight_dict, summary_compare): def highlight_rows_xlsx(result_df, highlight_dict, file_path): """Write and highlight results in Excel""" - print_info_log('Compare result is %s' % file_path) + logger.info('Compare result is %s' % file_path) wb = openpyxl.Workbook() ws = wb.active @@ -648,7 +646,7 @@ def compare(input_parma, output_path, stack_mode=False, auto_analyze=True, create_directory(output_path) check_compare_param(input_parma, output_path, stack_mode, summary_compare, md5_compare) except CompareException as error: - print_error_log('Compare failed. Please check the arguments and do it again!') + logger.error('Compare failed. Please check the arguments and do it again!') sys.exit(error.code) compare_core(input_parma, output_path, stack_mode=stack_mode, auto_analyze=auto_analyze, fuzzy_match=fuzzy_match, summary_compare=summary_compare, @@ -681,7 +679,7 @@ def compare_core(input_parma, output_path, **kwargs): summary_compare = kwargs.get('summary_compare', False) md5_compare = kwargs.get('md5_compare', False) - print_info_log("Please check whether the input data belongs to you. If not, there may be security risks.") + logger.info("Please check whether the input data belongs to you. 
If not, there may be security risks.") file_name = add_time_with_xlsx("compare_result" + suffix) file_path = os.path.join(os.path.realpath(output_path), file_name) check_file_not_exists(file_path) @@ -704,7 +702,7 @@ def compare_core(input_parma, output_path, **kwargs): def parse(pkl_file, module_name_prefix): if not isinstance(module_name_prefix, str): - print_error_log("The parameter:module_name_prefix is not a string.") + logger.error("The parameter:module_name_prefix is not a string.") raise CompareException(CompareException.INVALID_PARAM_ERROR) with FileOpen(pkl_file, "r") as f: done = False @@ -723,18 +721,18 @@ def parse(pkl_file, module_name_prefix): continue if info_prefix.find("stack_info") != -1: - print_info_log("\nTrace back({}):".format(msg[0])) + logger.info("\nTrace back({}):".format(msg[0])) for item in reversed(msg[1]): - print_info_log(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) - print_info_log(" {}".format(item[3])) + logger.info(" File \"{}\", line {}, in {}".format(item[0], item[1], item[2])) + logger.info(" {}".format(item[3])) continue if len(msg) > 5: summary_info = " [{}][dtype: {}][shape: {}][max: {}][min: {}][mean: {}]" \ .format(msg[0], msg[3], msg[4], msg[5][0], msg[5][1], msg[5][2]) if not title_printed: - print_info_log("\nStatistic Info:") + logger.info("\nStatistic Info:") title_printed = True - print_info_log(summary_info) + logger.info(summary_info) def op_item_parse(item, op_name, index, item_list=[], top_bool=True): @@ -880,7 +878,7 @@ def compare_process(file_handles, stack_mode, fuzzy_match, summary_compare=False stack_json_data = json.load(stack_json_handle) if fuzzy_match: - print_warn_log("This task uses fuzzy matching, which may affect the accuracy of the comparison.") + logger.warning("This task uses fuzzy matching, which may affect the accuracy of the comparison.") npu_ops_queue = [] bench_ops_queue = [] diff --git a/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py b/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py index 09d40b214d..fe226ad0ec 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/distributed_compare.py @@ -17,10 +17,11 @@ import os import sys import re -from ...core.utils import print_error_log, CompareException, check_compare_param, \ +from ...core.common.utils import CompareException, check_compare_param, \ check_configuration_param, task_dumppath_get, check_file_or_directory_path, check_regex_prefix_format_valid from .acc_compare import compare_core -from ...core.file_check_util import create_directory +from ...core.common.file_check import create_directory +from ..common.log import logger def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): @@ -46,7 +47,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): pattern = re.compile(rf'^{prefix}(?:0|[0-9][1-9]*)?$') for name in contents: if not pattern.match(name): - print_error_log( + logger.error( f"dump_dir contains '{name}'. Expected '{prefix}'. This name is not in the format of dump " f"output. Please check and delete irrelevant files in {dump_dir} and try again." ) @@ -66,12 +67,12 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): # Provide robustness on invalid directory inputs if not json_path: - print_error_log(f'No file is found in dump dir {dirname}. ') + logger.error(f'No file is found in dump dir {dirname}. 
') raise CompareException(CompareException.NO_DUMP_FILE_ERROR) return json_path if kwargs.get('suffix'): - print_error_log("Argument 'suffix' is not supported for compare_distributed.") + logger.error("Argument 'suffix' is not supported for compare_distributed.") raise CompareException(CompareException.INVALID_PARAM_ERROR) stack_mode = kwargs.get('stack_mode', False) auto_analyze = kwargs.get('auto_analyze', True) @@ -80,7 +81,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): npu_ranks = sorted(check_and_return_dir_contents(npu_dump_dir, 'rank')) bench_ranks = sorted(check_and_return_dir_contents(bench_dump_dir, 'rank')) if len(npu_ranks) != len(bench_ranks): - print_error_log('The number of ranks in the two runs are different. ' + logger.error('The number of ranks in the two runs are different. ' 'Unable to match the ranks. Please use another folder to compare ' 'or use compare() api and manually match the ranks.') raise CompareException(CompareException.INVALID_PATH_ERROR) @@ -104,7 +105,7 @@ def compare_distributed(npu_dump_dir, bench_dump_dir, output_path, **kwargs): create_directory(output_path) check_compare_param(dump_result_param, output_path, stack_mode=stack_mode, summary_compare=summary_compare) except CompareException as error: - print_error_log('Compare failed. Please check the arguments and do it again!') + logger.error('Compare failed. Please check the arguments and do it again!') sys.exit(error.code) compare_core(dump_result_param, output_path, suffix=f'_{nr}-{br}', summary_compare=summary_compare, md5_compare=md5_compare, **kwargs) diff --git a/debug/accuracy_tools/atat/pytorch/compare/highlight.py b/debug/accuracy_tools/atat/pytorch/compare/highlight.py index fdc1130300..141d19f070 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/highlight.py +++ b/debug/accuracy_tools/atat/pytorch/compare/highlight.py @@ -1,7 +1,7 @@ import math import abc import numpy as np -from ...core.utils import CompareConst, get_header_index +from ...core.common.utils import CompareConst, get_header_index class HighlightCheck(abc.ABC): diff --git a/debug/accuracy_tools/atat/pytorch/compare/match.py b/debug/accuracy_tools/atat/pytorch/compare/match.py index 51fb2fb666..48ac4eee4a 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/match.py +++ b/debug/accuracy_tools/atat/pytorch/compare/match.py @@ -1,7 +1,7 @@ import os import yaml -from ...core.file_check_util import FileOpen -from ...core.utils import CompareException +from ...core.common.file_check import FileOpen +from ...core.common.utils import CompareException class AtenIrMapping(): diff --git a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py index b94a83f134..a36a7c8fcc 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/npy_compare.py @@ -1,6 +1,7 @@ import abc import numpy as np -from ...core.utils import CompareConst, Const, print_warn_log, format_value +from ...core.common.utils import CompareConst, Const, format_value +from ..common.log import logger def handle_inf_nan(n_value, b_value): @@ -69,7 +70,7 @@ def get_error_message(n_value, b_value, op_name, error_flag, error_file=None): if not n_value.shape: return "This is type of scalar data, can not compare." 
if n_value.dtype != b_value.dtype: - print_warn_log("Dtype of NPU and bench Tensor do not match: {}".format(op_name)) + logger.warning("Dtype of NPU and bench Tensor do not match: {}".format(op_name)) return "Dtype of NPU and bench Tensor do not match." return "" diff --git a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py index 451410dc96..551c40a8aa 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/debugger_config.py @@ -1,5 +1,7 @@ -from ..common import print_warn_log_rank_0, seed_all -from ...core.utils import Const +from ..common import seed_all +from ...core.common.utils import Const +from ..common.log import logger + class DebuggerConfig: def __init__(self, common_config, task_config, task, dump_path, level): @@ -20,12 +22,8 @@ class DebuggerConfig: self.is_forward_acl_dump = True self.summary_mode = task_config.summary_mode if task_config.summary_mode else Const.STATISTICS self.overflow_num = task_config.overflow_num if task_config.overflow_num else 1 - self.repair_scope = None - self.repair_api_str = None - self.on_step_end = None - self.repair_type = None - if self.task == "free_benchmark": + if self.task == Const.FREE_BENCHMARK: self.fuzz_device = task_config.fuzz_device if task_config.fuzz_device else 'npu' self.handler_type = task_config.handler_type if task_config.handler_type else 'check' self.pert_mode = task_config.pert_mode if task_config.pert_mode else 'improve_precision' @@ -79,7 +77,7 @@ class DebuggerConfig: if not isinstance(rank_id, int) or rank_id < 0: raise ValueError(f"rank {self.rank} must be an integer and greater than or equal to 0.") else: - print_warn_log_rank_0(f"Rank argument is provided. Only rank {self.rank} data will be dumpped.") + logger.warning_on_rank_0(f"Rank argument is provided. 
Only rank {self.rank} data will be dumpped.") def _check_step(self): if self.step: diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py index 8d67ae9ba6..26bf853532 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py @@ -2,9 +2,9 @@ import torch from torch.utils.data import dataloader from .debugger_config import DebuggerConfig from ..service import Service -from ..common import print_warn_log_rank_0 +from ..common.log import logger from ..pt_config import parse_json_config -from ..common.exceptions import MsaccException +from ...core.common.exceptions import MsaccException class PrecisionDebugger: @@ -39,7 +39,7 @@ class PrecisionDebugger: self.service = Service(self.config) self.enable_dataloader = self.config.enable_dataloader if self.enable_dataloader: - print_warn_log_rank_0("The enable_dataloader feature will be deprecated in the future.") + logger.warning_on_rank_0("The enable_dataloader feature will be deprecated in the future.") dataloader._BaseDataLoaderIter.__next__ = iter_tracer(dataloader._BaseDataLoaderIter.__next__) @property @@ -52,7 +52,7 @@ class PrecisionDebugger: if not instance: raise Exception("No instance of PrecisionDebugger found.") if instance.enable_dataloader: - print_warn_log_rank_0("DataLoader is enabled, start() skipped.") + logger.warning_on_rank_0("DataLoader is enabled, start() skipped.") else: instance.service.start(instance.model) @@ -62,7 +62,7 @@ class PrecisionDebugger: if not instance: raise Exception("PrecisionDebugger instance is not created.") if instance.enable_dataloader: - print_warn_log_rank_0("DataLoader is enabled, stop() skipped.") + logger.warning_on_rank_0("DataLoader is enabled, stop() skipped.") else: instance.service.stop() diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py index 3ffe161cba..f86fc41d55 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/__init__.py @@ -1,5 +1,5 @@ -from atat.pytorch.common import print_warn_log_rank_0, print_info_log_rank_0 -from atat.pytorch.common.exceptions import FreeBenchmarkException +from atat.core.common.log import logger +from atat.core.common.exceptions import FreeBenchmarkException from atat.pytorch.common.utils import Const from .main import FreeBenchmarkCheck diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py index c5dfefb43f..440348d78c 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/common/params.py @@ -1,9 +1,8 @@ -from abc import ABC from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple import torch -from atat.pytorch.free_benchmark import Const, print_warn_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.enums import ( DeviceType, FuzzLevel, @@ -78,7 +77,7 @@ def data_pre_deal(name, func, args, kwargs): index = check_args_type(args) data_params.valid_input_index = index if index == -1: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: 无标杆工具不支持当前算子的输入类型 {name}." 
) return data_params diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py index 5094da3e2a..2497e2e869 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/grad_saver.py @@ -1,6 +1,6 @@ import torch from atat.pytorch.common.exceptions import FreeBenchmarkException -from atat.pytorch.free_benchmark import print_warn_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import CommonField from atat.pytorch.free_benchmark.common.params import DataParams, HandlerParams from atat.pytorch.free_benchmark.perturbed_layers.layer_factory import LayerFactory @@ -40,18 +40,18 @@ class GradSaver: ) data_processor.update_unequal_rows(handler.get_unequal_rows()) except IndexError: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: grad index out of range. api:{self.handler_params.api_name}." f"index:{new_grad_index}, perturbation grad len {len(self.perturbed_grad_input)}" ) return grad except FreeBenchmarkException as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: grad input check error: {e}" ) return grad except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: grad compare error: {e}" ) return grad @@ -76,7 +76,7 @@ class GradSaver: self.data_params.original_result = self.origin_grad_input handler.handle(self.data_params) except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: compare two vjp failed: api:{self.handler_params.api_name}." f"{e}" ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py index 80c526be91..85aa68f13b 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/compare/single_benchmark.py @@ -1,7 +1,7 @@ import math import torch -from atat.pytorch.free_benchmark import print_warn_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.utils import TorchC @@ -61,7 +61,7 @@ class SingleCompare: actual.dtype, ThresholdConfig.BENCHMARK_THD_DICT.get(torch.float32) ) if self.filter_overflow(golden) > 0: - print_warn_log_rank_0("[atat] Free Benchmark: inf and nan" + logger.warning_on_rank_0("[atat] Free Benchmark: inf and nan" "in golden tensor is not supported.") return True actual = self.replace_inf_or_nan(actual) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py index c2e0005181..ba3e9a6b25 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/main.py @@ -1,9 +1,7 @@ -import importlib from abc import ABC import torch -from atat.pytorch.free_benchmark import Const, print_warn_log_rank_0 - +from atat.pytorch.free_benchmark import Const, logger from atat.pytorch.free_benchmark.common.params import data_pre_deal, make_handler_params from atat.pytorch.free_benchmark.common.enums import ( PerturbationMode, @@ -80,7 +78,7 @@ class FreeBenchmarkCheck(ABC): try: grad_saver = getattr(module, "grad_saver") except AttributeError: - 
print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: get grad saver failed. api_name:{name}" ) return @@ -96,7 +94,7 @@ class FreeBenchmarkCheck(ABC): _new_grad_output, need_grad_tensors, _inner_args ) except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: grad vjp calculate failed. api_name:{name} error: {e}" ) return diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py index d5ba63c6a9..af8a93f7d4 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/add_noise.py @@ -1,8 +1,5 @@ import torch -from atat.pytorch.free_benchmark import ( - print_info_log_rank_0, - print_warn_log_rank_0, -) +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams @@ -39,7 +36,7 @@ class AddNoiseLayer(NpuBaseLayer): """ 对输入添加扰动并返回 """ - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: Perturbation is " f"{PerturbationMode.ADD_NOISE} of {self.api_name}." ) @@ -62,13 +59,13 @@ class AddNoiseLayer(NpuBaseLayer): 判断是否需要添加扰动 """ if not self.perturbed_value: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"dtype unsupported. Cancel perturbation." ) return False if tensor_obj.numel() == 0: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: For {self.api_name}, tensor shape must > 0." f" Cancel adding noise." ) @@ -79,13 +76,13 @@ class AddNoiseLayer(NpuBaseLayer): try: max_val = TorchC.max(TorchC.abs(tensor_obj)).item() except Exception: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"when calculate maximun value, tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"Maximun value is less than the minimun threshold. Cancel add noise." ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py index 2c1ed9a3e1..40b99acf41 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/bit_noise.py @@ -1,8 +1,5 @@ import torch -from atat.pytorch.free_benchmark import ( - print_info_log_rank_0, - print_warn_log_rank_0, -) +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams @@ -55,7 +52,7 @@ class BitNoiseLayer(NpuBaseLayer): """ 对输入添加扰动并返回 """ - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: Perturbation is " f"{PerturbationMode.BIT_NOISE} of {self.api_name}." 
) @@ -67,13 +64,13 @@ class BitNoiseLayer(NpuBaseLayer): 判断是否需要添加扰动, bit翻转 """ if not self.bit_type: - print_warn_log_rank_0( + logger.info_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"dtype unsupported. Cancel perturbation." ) return False if tensor_obj.numel() == 0: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free benchmark: For {self.api_name}, tensor shape must > 0" f" Cancel adding noise." ) @@ -84,13 +81,13 @@ class BitNoiseLayer(NpuBaseLayer): try: max_val = TorchC.max(TorchC.abs(tensor_obj)).item() except Exception: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"when calculate maximun value, tensor is changed to float32." ) max_val = TorchC.max(TorchC.abs(tensor_obj.to(torch.float32))).item() if max_val < abs_tol: - print_warn_log_rank_0( + logger.info_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"Maximun value is less than the minimun threshold. Cancel add noise." ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py index b4ee673841..b7a967e18b 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/change_value.py @@ -1,5 +1,5 @@ import torch -from atat.pytorch.free_benchmark import print_warn_log_rank_0, print_info_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import TorchC @@ -43,7 +43,7 @@ class ChangeValueLayer(NpuBaseLayer): """ 对输入添加扰动并返回 """ - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: Perturbation is " f"{PerturbationMode.CHANGE_VALUE} of {self.api_name}." ) @@ -55,7 +55,7 @@ class ChangeValueLayer(NpuBaseLayer): 判断是否需要添加扰动, 首尾值交换 """ if tensor_obj.size(0) < 2: - print_warn_log_rank_0( + logger.info_on_rank_0( f"[atat] Free Benchmark: For {self.api_name}, " f"size 0 must greater than 1. Cancel change value." ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py index e18b303a60..d5b85b949c 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/improve_precision.py @@ -1,5 +1,5 @@ import torch -from atat.pytorch.free_benchmark import Const, print_info_log_rank_0 +from atat.pytorch.free_benchmark import Const, logger from atat.pytorch.free_benchmark.common.constant import CommonField from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams @@ -31,7 +31,7 @@ class ImprovePrecisionLayer(NpuBaseLayer): return tensor_obj def handle(self, params: DataParams) -> torch.Any: - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: Perturbation is " f"{PerturbationMode.IMPROVE_PRECISION} of {self.api_name}." 
) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py index 204e649d80..bb065385c6 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/npu/no_change.py @@ -1,5 +1,5 @@ import torch -from atat.pytorch.free_benchmark import print_info_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.enums import PerturbationMode from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.perturbed_layers.npu.npu_base_layser import ( @@ -20,7 +20,7 @@ class NoChangeLayer(NpuBaseLayer): """ 对输入添加扰动并返回 """ - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: Perturbation is " f"{PerturbationMode.NO_CHANGE} of {self.api_name}." ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py index 387f9447fd..024958ffbe 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/perturbed_layers/run_cpu.py @@ -1,5 +1,5 @@ import torch -from atat.pytorch.free_benchmark import print_info_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import Tools from atat.pytorch.free_benchmark.common.enums import DeviceType @@ -10,7 +10,7 @@ class CpuLayer(BaseLayer): def handle(self, params: DataParams) -> torch.Any: - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: Perturbation is to_cpu of {self.api_name}." 
) new_args = Tools.convert_device_and_dtype(params.args, DeviceType.CPU, change_dtype=True) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py index 0d1f6ec5d0..1f1f8e1cba 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/base_handler.py @@ -5,7 +5,7 @@ from typing import Any, Optional, Tuple import torch from atat.pytorch.free_benchmark import ( Const, - print_warn_log_rank_0, + logger, ) from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.enums import ( @@ -101,7 +101,7 @@ class FuzzHandler(ABC): origin_output, perturbed_output ) except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name}, " f"when computing ratio," f" y1 or y2 dtype is not supported {e}" @@ -130,7 +130,7 @@ class FuzzHandler(ABC): origin_output / perturbed_output, ) elif not isinstance(perturbed_output, torch.Tensor): - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name} " f"The compare for output type {type(perturbed_output)} is not supported" ) @@ -182,7 +182,7 @@ class FuzzHandler(ABC): ) ) except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name}, " f"when campare the result exception raise {e}" ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py index 2f590855f1..7444c855eb 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/check_handler.py @@ -1,7 +1,6 @@ from typing import Any -import torch -from atat.pytorch.free_benchmark import print_warn_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.enums import DeviceType from atat.pytorch.free_benchmark.compare.single_benchmark import SingleCompare from atat.pytorch.free_benchmark.common.params import DataParams, make_unequal_row @@ -34,7 +33,7 @@ class CheckerHandler(FuzzHandler): else: self.other_compare(data_params) except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name}, " f"when campare the result exception raise {e}" ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py index 789e2653aa..fa5c6f3749 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/fix_handler.py @@ -3,7 +3,7 @@ from typing import Any from atat.pytorch.free_benchmark.common.params import DataParams from atat.pytorch.free_benchmark.common.utils import Tools from atat.pytorch.free_benchmark.result_handlers.base_handler import FuzzHandler -from atat.pytorch.free_benchmark import print_warn_log_rank_0 +from atat.pytorch.free_benchmark import logger class FixHandler(FuzzHandler): @@ -17,7 +17,7 @@ class FixHandler(FuzzHandler): data_params.original_result, data_params.perturbed_result ) except Exception as e: - 
print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name} " f"Fix output failed. " ) diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py index 50f791d81e..cff629854d 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/handler_factory.py @@ -1,6 +1,5 @@ from atat.pytorch.free_benchmark import FreeBenchmarkException from atat.pytorch.free_benchmark.common.constant import PreheatConfig -from atat.pytorch.free_benchmark.common.utils import Tools from atat.pytorch.free_benchmark.common.enums import HandlerType from atat.pytorch.free_benchmark.common.params import HandlerParams from atat.pytorch.free_benchmark.result_handlers.check_handler import CheckerHandler diff --git a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py index 1e70067b93..ee2ee11a79 100644 --- a/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py +++ b/debug/accuracy_tools/atat/pytorch/free_benchmark/result_handlers/preheat_handler.py @@ -1,7 +1,7 @@ import math from typing import Any -from atat.pytorch.free_benchmark import print_info_log_rank_0, print_warn_log_rank_0 +from atat.pytorch.free_benchmark import logger from atat.pytorch.free_benchmark.common.constant import ThresholdConfig from atat.pytorch.free_benchmark.common.counter import preheat_counter from atat.pytorch.free_benchmark.common.enums import DeviceType @@ -74,14 +74,14 @@ class PreheatHandler(FuzzHandler): try: cpu_consistent = self.compare_npu_and_cpu(data_params) except Exception as e: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name}, " f"when campare to cpu exception raise {e}" ) try: first_dtype = Tools.get_first_tensor_dtype(data_params.perturbed_result) except RuntimeError: - print_warn_log_rank_0( + logger.warning_on_rank_0( f"[atat] Free Benchmark: For {self.params.api_name}, " f"the output sequence does not contain tensors." 
) @@ -96,7 +96,7 @@ class PreheatHandler(FuzzHandler): res = curr_called_seq in need_sample_set if res: total_count = preheat_counter.get_one_step_used_api(self.pure_name) - print_info_log_rank_0( + logger.info_on_rank_0( f"[atat] Free benchmark: preheat sample in step{self.params.step}" f"api_name {self.params.api_name}, " f"curr_called_seq: {curr_called_seq}/{total_count}" diff --git a/debug/accuracy_tools/atat/pytorch/functional/__init__.py b/debug/accuracy_tools/atat/pytorch/functional/__init__.py index 12e530d4c9..e69de29bb2 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/functional/__init__.py @@ -1,4 +0,0 @@ -from .repair import build_repair -from .scope import build_scope -from .step_post_process import build_step_post_process -from .data_collector import build_data_collector diff --git a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py index fed73ad537..8a2f57b5e2 100644 --- a/debug/accuracy_tools/atat/pytorch/functional/dump_module.py +++ b/debug/accuracy_tools/atat/pytorch/functional/dump_module.py @@ -1,20 +1,21 @@ import torch.nn as nn -from atat.core.utils import print_error_log, DumpException -from .scope import BaseScope +from ..common.log import logger from ..common.utils import Const from ..hook_module.api_registry import api_register from ..debugger.precision_debugger import PrecisionDebugger +from ...core.common.exceptions import MsaccException +from ...core.data_dump.scope import BaseScope module_count = {} def module_dump(module, dump_name): if not isinstance(module, nn.Module): - print_error_log("The parameter:module in module_dump is not a Module subclass.") - raise DumpException(DumpException.INVALID_PARAM_ERROR) + logger.error("The parameter:module in module_dump is not a Module subclass.") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) if not isinstance(dump_name, str): - print_error_log("The parameter:dump_name in module_dump is not a str type.") - raise DumpException(DumpException.INVALID_PARAM_ERROR) + logger.error("The parameter:dump_name in module_dump is not a str type.") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) api_register.api_originality() if dump_name not in module_count: module_count[dump_name] = 0 diff --git a/debug/accuracy_tools/atat/pytorch/functional/repair.py b/debug/accuracy_tools/atat/pytorch/functional/repair.py deleted file mode 100644 index aed8326424..0000000000 --- a/debug/accuracy_tools/atat/pytorch/functional/repair.py +++ /dev/null @@ -1,90 +0,0 @@ -from abc import ABC, abstractmethod - -import torch - -from .scope import build_scope, ListScope, BaseScope -from ..common.exceptions import RepairException -from ..common import recursive_apply_transform, print_info_log_rank_0 - - -def build_repair(config): - if config.repair_type is None: - return None - elif config.repair_type == RepairAPI.ToCPU: - return RepairAPI_toCPU(config) - elif config.repair_type == RepairAPI.RaisePrecision: - return RepairAPI_raise(config) - else: - raise RepairException(RepairException.InvalidRepairType, f"精度修复类型" - f"须配置为'{RepairAPI.ToCPU}'或'{RepairAPI.RaisePrecision}," - f"实际配置为{config.repair_type}") - - -class RepairAPI(ABC): - ToCPU = "cpu" - RaisePrecision = "raise" - - def __init__(self, config): - self.config = config - self.scope = build_scope(ListScope, config.repair_scope, config.repair_api_str) - self.saved, self.towards = "None", "None" - - def check_name_and_module_type(self, name, 
module_type): - if module_type == BaseScope.Module_Type_Module: - return False - if not self.scope.check(name): - return False - return True - - def convert(self, name, module_type, args, kwargs): - is_target = self.check_name_and_module_type(name, module_type) - if is_target: - args = recursive_apply_transform(args, self.fx) - kwargs = recursive_apply_transform(kwargs, self.fx) - print_info_log_rank_0(f"[msProbe] convert inputs of {name} to " - f"{self.towards}.") - return args, kwargs - - def invert(self, name, module_type, out_feat): - is_target = self.check_name_and_module_type(name, module_type) - if is_target: - out_feat = recursive_apply_transform(out_feat, self.inv_fx) - print_info_log_rank_0(f"[msProbe] convert outputs of {name} back to "\ - f"{self.saved}.") - return out_feat - - -class RepairAPI_toCPU(RepairAPI): - def fx(self, arg, _): - if isinstance(arg, torch.Tensor): - self.saved = arg.device - self.towards = torch.device("cpu") - return arg.cpu() - return arg - - def inv_fx(self, arg, _): - if isinstance(arg, torch.Tensor): - return arg.to(self.saved) - return arg - - -class RepairAPI_raise(RepairAPI): - raise_dtype_map = { - torch.bfloat16: torch.float32, - torch.float16: torch.float32 - } - - def fx(self, arg, _): - if isinstance(arg, torch.Tensor): - self.saved = arg.dtype - self.towards = RepairAPI_raise.raise_dtype_map.get(self.saved) - # bug: nested input may be of various dtypes. which to save and invert? - return arg.to(self.towards) - return arg - - def inv_fx(self, arg, _): - if isinstance(arg, torch.Tensor): - return arg.to(self.saved) - return arg - - diff --git a/debug/accuracy_tools/atat/pytorch/functional/step_post_process.py b/debug/accuracy_tools/atat/pytorch/functional/step_post_process.py deleted file mode 100644 index 7f0d345932..0000000000 --- a/debug/accuracy_tools/atat/pytorch/functional/step_post_process.py +++ /dev/null @@ -1,43 +0,0 @@ -from abc import ABC, abstractmethod -from ..common.exceptions import StepException - - -def run_parallel_ut(config): - pass - - -def compare_distrbuted(config): - pass - - -def build_step_post_process(config): - if not config.on_step_end: - return None - if config.on_step_end == StepPostProcess.SingleAPICheck: - return SingleAPICheck(config) - elif config.on_step_end == StepPostProcess.Compare: - return AutoCompare(config) - else: - raise StepException(StepException.InvalidPostProcess, f"step后处理须配置为" - f"'{StepPostProcess.SingleAPICheck}'或'{StepPostProcess.Compare}'," - f"实际配置为{config.on_step_end}") - - -class StepPostProcess(ABC): - SingleAPICheck = 'single_api_check' - Compare = 'compare' - - -class SingleAPICheck: - def __init__(self, config): - self.config = config - - def run(self): - run_parallel_ut(self.config) - -class AutoCompare: - def __init__(self, config): - self.config = config - - def run(self): - compare_distrbuted(self.config.bench_dump_path, self.config.dump_path) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py index ae4a7abdab..09d45927be 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py @@ -22,6 +22,7 @@ import torch.nn as nn import torch.utils.hooks as full_hooks from ..common.utils import Const + class HOOKModule(nn.Module): module_count = {} inner_stop_hook = {} diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/utils.py b/debug/accuracy_tools/atat/pytorch/hook_module/utils.py index 96883072eb..6c2651d42f 
100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/utils.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/utils.py @@ -18,7 +18,7 @@ import os import yaml -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py index 8666287095..a2707eb58e 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py @@ -22,7 +22,7 @@ import yaml from .hook_module import HOOKModule from ..common.utils import torch_device_guard, Const -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py index 68ce83c16b..711a8a10b2 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py @@ -22,7 +22,7 @@ import yaml from .hook_module import HOOKModule from ..common.utils import torch_device_guard, Const -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py index f5dde41b12..930e47be5d 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py @@ -22,13 +22,13 @@ import yaml from .hook_module import HOOKModule from ..common.utils import torch_device_guard, Const -from ..common.log import print_info_log_rank_0 -from ..common.file_check import FileOpen +from ..common.log import logger +from ...core.common.file_check import FileOpen def remove_dropout(): if torch.__version__ > "1.8": - print_info_log_rank_0("For precision comparison, the probability p in the dropout method is set to 0.") + logger.info_on_rank_0("For precision comparison, the probability p in the dropout method is set to 0.") import torch.nn.functional as F from torch import _VF from torch.overrides import has_torch_function_unary, handle_torch_function diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py index e910e609c8..f72cfff502 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py @@ -22,7 +22,7 @@ import yaml from .hook_module import HOOKModule from ..common.utils import torch_device_guard, torch_without_guard_version, Const -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py index bf4a889524..80d30e39e8 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py @@ -22,7 +22,7 @@ import yaml from 
.hook_module import HOOKModule from ..common.utils import torch_device_guard, parameter_adapter, Const -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py index 09fa67166c..5bc03ae575 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py @@ -22,7 +22,7 @@ import yaml from .hook_module import HOOKModule from ..common.utils import torch_device_guard, Const -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py index 9303aec6e0..c7330679d9 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py @@ -21,7 +21,7 @@ import torch import yaml from .hook_module import HOOKModule -from ..common.file_check import FileOpen +from ...core.common.file_check import FileOpen from ..common.utils import torch_device_guard, Const cur_path = os.path.dirname(os.path.realpath(__file__)) diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index fda3d37bc9..87f6217132 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -1,8 +1,8 @@ from functools import wraps import torch from torch.utils.hooks import BackwardHook -from .functional.scope import ModuleRangeScope from .common.utils import Const +from ..core.data_dump.scope import ModuleRangeScope class ModuleProcesser: diff --git a/debug/accuracy_tools/atat/pytorch/pt_config.py b/debug/accuracy_tools/atat/pytorch/pt_config.py index 46d9b70cc9..ff0d919588 100644 --- a/debug/accuracy_tools/atat/pytorch/pt_config.py +++ b/debug/accuracy_tools/atat/pytorch/pt_config.py @@ -2,11 +2,10 @@ import json import os from ..core.common_config import CommonConfig, BaseConfig -from ..core.file_check_util import FileOpen -from ..core.utils import Const +from ..core.common.file_check import FileOpen +from ..core.common.utils import Const -# 特定任务配置类 class TensorConfig(BaseConfig): def __init__(self, json_config): super().__init__(json_config) diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py index 8b7f2a1d87..be33a6092f 100644 --- a/debug/accuracy_tools/atat/pytorch/service.py +++ b/debug/accuracy_tools/atat/pytorch/service.py @@ -2,55 +2,45 @@ import functools import os from pathlib import Path -from .common import print_info_log_rank_0 -from .common.file_check import FileChecker, FileCheckConst, check_path_before_create -from .common.utils import get_rank_if_initialized, is_gpu, Const, DistributedNotInitializedError -from .functional import build_repair, build_data_collector, build_step_post_process -from .functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs -from .functional.scope import BaseScope +from ..core.common.log import logger +from ..core.common.file_check import FileChecker, FileCheckConst, check_path_before_create +from 
..core.common.utils import Const +from ..core.common.exceptions import DistributedNotInitializedError, MsaccException +from ..core.data_dump.data_collector import build_data_collector +from ..core.data_dump.scope import BaseScope +from ..core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from .common.utils import get_rank_if_initialized +from .module_processer import ModuleProcesser from .hook_module import remove_dropout from .hook_module.api_registry import api_register -from .module_processer import ModuleProcesser - -from ..core.utils import DumpException class Service: - make_dir_flag = True - REGISTER_HOOK_KWARGS = ["overflow_nums", "dump_mode", "dump_config"] - def __init__(self, config): self.model = None self.config = config self.data_collector = build_data_collector(config) self.module_processor = ModuleProcesser(self.data_collector.scope) - self.repair = build_repair(config) - self.step_post_process = build_step_post_process(config) self.switch = False self.current_iter = 0 self.first_start = True self.current_rank = None - self.first_touch_dir = True self.dump_iter_dir = None def build_hook(self, module_type, name): - def pre_hook(repair, api_or_module_name, module, args, kwargs): - nonlocal module_type, pid + def pre_hook(api_or_module_name, module, args, kwargs): if module_type == BaseScope.Module_Type_Module: api_or_module_name = module.mindstudio_reserved_name self.data_collector.visit_and_clear_overflow_status(api_or_module_name) if not self.switch: return args, kwargs - if repair: - args, kwargs = repair.convert(api_or_module_name, module_type, args, kwargs) if self.data_collector: module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=None) self.data_collector.pre_forward_data_collect(api_or_module_name, module, pid, module_input_output) return args, kwargs - def forward_hook(repair, api_or_module_name, module, args, kwargs, output): - nonlocal module_type, pid + def forward_hook(api_or_module_name, module, args, kwargs, output): if module_type == BaseScope.Module_Type_Module: api_or_module_name = module.mindstudio_reserved_name self.data_collector.visit_and_clear_overflow_status(api_or_module_name) @@ -62,13 +52,9 @@ class Service: self.data_collector.forward_data_collect(api_or_module_name, module, pid, module_input_output) if self.data_collector.if_return_forward_new_output(): return self.data_collector.get_forward_new_output() - if repair: - output = repair.invert(api_or_module_name, module_type, output) - return output - def backward_hook(repair, api_or_module_name, module, grad_input, grad_output): - nonlocal module_type, pid + def backward_hook(api_or_module_name, module, grad_input, grad_output): if module_type == BaseScope.Module_Type_Module: api_or_module_name = module.mindstudio_reserved_name self.data_collector.visit_and_clear_overflow_status(api_or_module_name) @@ -82,15 +68,13 @@ class Service: pid = os.getpid() forward_name_template = name + Const.FORWARD backward_name_template = name + Const.BACKWARD - pre_forward_hook = functools.partial(pre_hook, self.repair, forward_name_template) - forward_hook = functools.partial(forward_hook, self.repair, forward_name_template) - backward_hook = functools.partial(backward_hook, None, backward_name_template) + pre_forward_hook = functools.partial(pre_hook, forward_name_template) + forward_hook = functools.partial(forward_hook, forward_name_template) + backward_hook = functools.partial(backward_hook, backward_name_template) return 
pre_forward_hook, forward_hook, backward_hook def step(self): self.current_iter += 1 - if self.step_post_process: - self.step_post_process() self.data_collector.update_iter(self.current_iter) def start(self, model): @@ -111,10 +95,10 @@ class Service: self.register_hook_new() self.first_start = False self.switch = True - print_info_log_rank_0(f"Dump switch is turned on at step {self.current_iter}. ") + logger.info_on_rank_0(f"Dump switch is turned on at step {self.current_iter}. ") if self.config.level != "L2": self.create_dirs() - print_info_log_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") + logger.info_on_rank_0(f"Dump data will be saved in {self.dump_iter_dir}.") def stop(self): if self.config.level == "L2": @@ -151,16 +135,12 @@ class Service: dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path) def register_hook_new(self): - hook_name = self.config.task - - if "overflow_check" in hook_name and not is_gpu: - pass - - print_info_log_rank_0("The {} hook function is successfully mounted to the model.".format(hook_name)) + logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task)) if self.config.level in ["L0", "mix"]: if self.model is None: - raise DumpException("Model is None") - print_info_log_rank_0("The init dump mode is enabled, and the module dump function will not be available") + logger.error_on_rank_0("The model is None.") + raise MsaccException(MsaccException.INVALID_PARAM_ERROR) + logger.info_on_rank_0("The init dump mode is enabled, and the module dump function will not be available") for name, module in self.model.named_modules(): if module == self.model: continue @@ -184,5 +164,5 @@ class Service: api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) api_register.api_modularity() - if Const.STATISTICS in hook_name or Const.TENSOR in hook_name: + if Const.STATISTICS == self.config.task or Const.TENSOR == self.config.task: remove_dropout() diff --git a/debug/accuracy_tools/atat/test/core_ut/test_utils.py b/debug/accuracy_tools/atat/test/core_ut/test_utils.py index 9492bbc9f9..7dc9eb8934 100644 --- a/debug/accuracy_tools/atat/test/core_ut/test_utils.py +++ b/debug/accuracy_tools/atat/test/core_ut/test_utils.py @@ -1,11 +1,11 @@ from unittest import TestCase from unittest.mock import patch -from atat.core.utils import check_seed_all, Const, CompareException - +from debug.accuracy_tools.atat.core.common.utils import check_seed_all, Const, CompareException +from atat.core.common.log import logger class TestUtils(TestCase): - @patch("atat.core.utils.print_error_log") + @patch.object(logger, "error") def test_check_seed_all(self, mock_print_error_log): self.assertIsNone(check_seed_all(1234, True)) self.assertIsNone(check_seed_all(0, True)) diff --git a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py index 0029e24bda..f0d6109cd8 100644 --- a/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py +++ b/debug/accuracy_tools/atat/test/mindspore_ut/test_ms_config.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.utils import Const +from debug.accuracy_tools.atat.core.common.utils import Const from atat.mindspore.ms_config import parse_json_config diff --git a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py index 
8279c20776..e7cbece2d0 100644 --- a/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py +++ b/debug/accuracy_tools/atat/test/pytorch_ut/test_pt_config.py @@ -1,7 +1,7 @@ from unittest import TestCase from unittest.mock import patch, mock_open -from atat.core.utils import Const +from debug.accuracy_tools.atat.core.common.utils import Const from atat.pytorch.pt_config import parse_json_config diff --git a/debug/accuracy_tools/atat/test/run_ut.py b/debug/accuracy_tools/atat/test/run_ut.py index 7f51d266c2..91e4e31709 100644 --- a/debug/accuracy_tools/atat/test/run_ut.py +++ b/debug/accuracy_tools/atat/test/run_ut.py @@ -3,7 +3,7 @@ import shutil import subprocess import sys -from atat.core.log import print_info_log, print_error_log +from atat.core.common.log import logger def get_ignore_dirs(cur_dir): @@ -12,13 +12,13 @@ def get_ignore_dirs(cur_dir): import torch import torch_npu except ImportError: - print_info_log(f"Skipping the {cur_dir}/pytorch_ut directory") + logger.info(f"Skipping the {cur_dir}/pytorch_ut directory") ignore_dirs.extend(["--ignore", f"{cur_dir}/pytorch_ut"]) try: import mindspore except ImportError: - print_info_log(f"Skipping the {cur_dir}/mindspore_ut directory") + logger.info(f"Skipping the {cur_dir}/mindspore_ut directory") ignore_dirs.extend(["--ignore", f"{cur_dir}/mindspore_ut"]) return ignore_dirs @@ -43,14 +43,14 @@ def run_ut(): while result_ut.poll() is None: line = result_ut.stdout.readline().strip() if line: - print_info_log(str(line)) + logger.info(str(line)) ut_flag = False if result_ut.returncode == 0: ut_flag = True - print_info_log("run ut successfully.") + logger.info("run ut successfully.") else: - print_error_log("run ut failed.") + logger.error("run ut failed.") return ut_flag -- Gitee
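
Note on the logging refactor driving most of these hunks: call sites that previously used print_info_log, print_error_log and print_info_log_rank_0 now go through a shared logger object exposing info, error and info_on_rank_0. The new atat/core/common/log.py is not included in this excerpt, so the sketch below is only a hypothetical illustration of what such a rank-aware logger could look like; the class name RankAwareLogger and the torch.distributed-based rank check are assumptions, and only the method names used by the rewritten call sites (logger.info, logger.error, logger.info_on_rank_0) are taken from the patch itself.

    # Hypothetical sketch only -- not the actual atat/core/common/log.py from this patch.
    import logging

    try:
        import torch.distributed as dist  # assumed backend for detecting the process rank
    except ImportError:
        dist = None


    class RankAwareLogger:
        """Wraps a standard logger and adds *_on_rank_0 helpers for distributed runs."""

        def __init__(self, name="atat"):
            self._logger = logging.getLogger(name)
            if not self._logger.handlers:
                handler = logging.StreamHandler()
                handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(message)s"))
                self._logger.addHandler(handler)
            self._logger.setLevel(logging.INFO)

        @staticmethod
        def _rank():
            # Rank 0 is assumed when torch.distributed is unavailable or uninitialized.
            if dist is not None and dist.is_available() and dist.is_initialized():
                return dist.get_rank()
            return 0

        def info(self, msg):
            self._logger.info(msg)

        def error(self, msg):
            self._logger.error(msg)

        def info_on_rank_0(self, msg):
            # Only the rank-0 process emits the message, mirroring the old print_info_log_rank_0.
            if self._rank() == 0:
                self._logger.info(msg)

        def error_on_rank_0(self, msg):
            if self._rank() == 0:
                self._logger.error(msg)


    logger = RankAwareLogger()

Under these assumptions, call sites match the ones rewritten in the hunks above, e.g. `from atat.core.common.log import logger` followed by `logger.info_on_rank_0(f"Dump data will be saved in {path}.")`, and unit tests can stub the output with `patch.object(logger, "error")` as test_utils.py does in this patch.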