diff --git a/debug/accuracy_tools/atat/pytorch/__init__.py b/debug/accuracy_tools/atat/pytorch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1f1172219d3a701bfecc392ea4d0a689f425845 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/__init__.py @@ -0,0 +1,2 @@ +from .debugger.precision_debugger import PrecisionDebugger +from .common.utils import seed_all diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py index e3e59fdcc5955043a54ee5dfdef4483bba2a7c4a..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/__init__.py @@ -1,21 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== -""" - -from api_accuracy_checker.common.utils import seed_all -seed_all() -__all__ = [] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py index c273b5f5fcd84a23f5f5966305dbf3f8a373495b..dd6607a81ec00ce635ffae6e41b4b9d18e090827 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/config.py @@ -1,8 +1,8 @@ import os import yaml -from api_accuracy_checker.common.utils import check_file_or_directory_path -from api_accuracy_checker.hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from ..common.utils import check_file_or_directory_path +from ..hook_module.utils import WrapFunctionalOps, WrapTensorOps, WrapTorchOps +from ...common.file_check import FileOpen WrapApi = set(WrapFunctionalOps) | set(WrapTensorOps) | set(WrapTorchOps) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py index 1152d19c655948bde1f98b5c60a08c666ef1e9b1..d01646740e2736e1df7716ca0aa1f0beb8bc2d63 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/common/utils.py @@ -36,8 +36,8 @@ except ImportError: else: IS_GPU = False -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, FileOpen -from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util +from 
...common.file_check import FileCheckConst, FileChecker, FileOpen +from ...common import file_check as file_check_util torch_without_guard_version_list = ['2.1'] for version in torch_without_guard_version_list: @@ -54,6 +54,7 @@ class Const: """ Class for const """ + SEP = '.' DIRECTORY_LENGTH = 4096 FILE_NAME_LENGTH = 255 FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py index c92eff25a701c5f0c228d3225fbbb22959d5f929..7983709f14bcca72a0cb29c453198396561681b1 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/algorithm.py @@ -1,8 +1,7 @@ # 定义比对算法及比对标准 import torch import numpy as np -from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable -from api_accuracy_checker.common.utils import Const +from .compare_utils import CompareConst #cos diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py index cd334733cc0951efb5c62cc4a391af0093a6abb0..f7f61a23e601c5ce5087ac88b44aa7242d282dd3 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_compare.py @@ -1,22 +1,20 @@ import argparse import os import sys -import csv import math from collections import namedtuple import pandas as pd -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, print_error_log, write_csv, \ +from ..common.utils import print_info_log, print_warn_log, print_error_log, write_csv, \ CompareException, create_directory -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ +from ..common.config import msCheckerConfig +from ..compare.compare_utils import CompareConst, API_PRECISION_COMPARE_RESULT_FILE_NAME, \ API_PRECISION_COMPARE_DETAILS_FILE_NAME, BENCHMARK_COMPARE_SUPPORT_LIST, API_PRECISION_COMPARE_UNSUPPORT_LIST, \ ApiPrecisionCompareColumn, AbsoluteStandardApi, BinaryStandardApi, BINARY_COMPARE_UNSUPPORT_LIST, \ convert_str_to_float, CompareMessage -from api_accuracy_checker.compare.compare_column import ApiPrecisionOutputColumn -from api_accuracy_checker.run_ut.run_ut import get_validated_result_csv_path -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, change_mode -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create +from ..compare.compare_column import ApiPrecisionOutputColumn +from ..run_ut.run_ut import get_validated_result_csv_path +from ...common.file_check import FileCheckConst, FileChecker, change_mode, check_path_before_create CompareConfig = namedtuple('CompareConfig', ['npu_csv_path', 'gpu_csv_path', 'result_csv_path', 'details_csv_path']) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml index 92ff8b0676cba6de65f1c1a8d6ecd3ee731264e8..efba9c5c02bbcc094b75ce2497d830789744b143 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml +++ 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/api_precision_standard.yaml @@ -106,4 +106,3 @@ BinaryCompareStandard: - tril_ - triu - triu_ - - type_as diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py index bd10f77976642331fa8e7bce28a703f0922c1411..39b13f34ff735d999e673f05a818e546b1e10068 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare.py @@ -5,16 +5,16 @@ import torch import numpy as np from rich.table import Table from rich.console import Console -from api_accuracy_checker.common.utils import get_json_contents, write_csv, print_warn_log -from api_accuracy_checker.compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ +from ..common.utils import get_json_contents, write_csv, print_warn_log, Const +from ..compare.compare_utils import CompareConst, check_dtype_comparable, DETAIL_TEST_ROWS, \ precision_configs, BENCHMARK_COMPARE_SUPPORT_LIST, AbsoluteStandardApi, BinaryStandardApi, apis_threshold -from api_accuracy_checker.compare.compare_column import CompareColumn -from api_accuracy_checker.compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err, \ +from ..compare.compare_column import CompareColumn +from ..compare.algorithm import get_rmse, get_error_balance, get_max_rel_err, get_mean_rel_err, \ get_rel_err, get_abs_err, get_max_abs_err, get_rel_err_ratio, cosine_sim, get_rel_err_origin, \ get_small_value_err_ratio, get_finite_and_infinite_mask, get_small_value_mask, check_inf_nan_value, \ check_small_value, check_norm_value, get_abs_bench_with_eps -from api_accuracy_checker.common.config import msCheckerConfig -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from ..common.config import msCheckerConfig +from ...common.file_check import FileOpen class Comparator: @@ -159,7 +159,7 @@ class Comparator: self.write_detail_csv(args) def compare_output(self, full_api_name, bench_output, device_output, bench_grad=None, npu_grad=None): - _, api_name, _ = full_api_name.split("*") + _, api_name, _ = full_api_name.split(Const.SEP) compare_func = self._compare_dropout if "dropout" in full_api_name else self._compare_core_wrapper fwd_success_status, fwd_compare_alg_results = compare_func(api_name, bench_output, device_output) if not (bench_grad and npu_grad): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py index 961fce6811efd34789cb06f19d894244da681c33..97cf8226bd1ea6c9a668abd91719fd2662b5183b 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_column.py @@ -1,4 +1,4 @@ -from api_accuracy_checker.compare.compare_utils import CompareConst +from .compare_utils import CompareConst class CompareColumn: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py index bce185a9cb9427610022443bfb48dc2e9803089a..5511da724446187e2dd886448bf6b26ea7b7b369 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py +++ 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/compare/compare_utils.py @@ -3,8 +3,8 @@ import os import numpy as np import torch import yaml -from api_accuracy_checker.common.utils import Const, print_warn_log, CompareException -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from ..common.utils import Const, print_warn_log, CompareException +from ...common.file_check import FileOpen current_time = time.strftime("%Y%m%d%H%M%S") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/__init__.py index f3e3fe66364169f8d1617acfd378905e225a52d2..c9602292b85f753fd132634b98c74c76460997b0 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/__init__.py @@ -1,5 +1 @@ -from api_accuracy_checker.dump.dump import set_dump_switch -import api_accuracy_checker.dump.dump_scope -from api_accuracy_checker.common.config import msCheckerConfig - __all__ = ['set_dump_switch'] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py index 4f198d10ec34f4a738250bea675cd0cc09c16941..7452cec74e80c812902341ef2af13d3f29c5f10c 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/api_info.py @@ -3,10 +3,10 @@ import os import inspect import torch import numpy as np -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory, DumpException, \ +from ..common.config import msCheckerConfig +from ..common.utils import print_error_log, write_pt, create_directory, DumpException, \ get_real_data_path -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create +from ...common.file_check import check_path_before_create def get_tensor_extremum(data, operator): @@ -97,7 +97,7 @@ class APIInfo: @staticmethod def get_full_save_path(save_path, dir_name, contain_step=False): if contain_step: - from api_accuracy_checker.dump.dump import DumpUtil + from ..dump.dump import DumpUtil step_dir = "step" + str(DumpUtil.call_num - 1 if msCheckerConfig.enable_dataloader else DumpUtil.call_num) rank_dir = f"rank{os.getpid()}" return os.path.join(save_path, step_dir, dir_name, rank_dir) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump.py index b03cd423c3e1468eb036f35c8218783066d5b442..b20378fd45d322e1e2e4a61031c8c1fa240ca5a0 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump.py @@ -15,11 +15,11 @@ # limitations under the License.
""" -from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json -from api_accuracy_checker.common.utils import print_error_log, CompareException, print_info_log -from api_accuracy_checker.hook_module.register_hook import initialize_hook -from api_accuracy_checker.common.config import msCheckerConfig +from .api_info import ForwardAPIInfo, BackwardAPIInfo +from .info_dump import write_api_info_json, initialize_output_json +from ..common.utils import print_error_log, CompareException, print_info_log +from ..hook_module.register_hook import initialize_hook +from ..common.config import msCheckerConfig def set_dump_switch(switch): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump_scope.py index 1f65dbc9c8a7e482d8ac85e3d06cffc3b11b406a..ac78fa8ccae9f5935d919b62ec72ed588b290a9f 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/dump_scope.py @@ -1,8 +1,8 @@ # dump范围控制 import torch from torch.utils.data.dataloader import _BaseDataLoaderIter -from api_accuracy_checker.dump.dump import DumpUtil -from api_accuracy_checker.common.config import msCheckerConfig +from ..dump.dump import DumpUtil +from ..common.config import msCheckerConfig def iter_tracer(original_next): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/info_dump.py index c73058e4f3058a9d1bf10b0a14046845f78440ee..31165077165c724f0e10ad0e279f5a59593cfd48 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/dump/info_dump.py @@ -4,20 +4,19 @@ import os import threading import multiprocessing -from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo -from api_accuracy_checker.common.utils import check_file_or_directory_path, initialize_save_path, create_directory -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create -from api_accuracy_checker.common.config import msCheckerConfig +from ..dump.api_info import ForwardAPIInfo, BackwardAPIInfo +from ..common.utils import check_file_or_directory_path, create_directory +from ...common.file_check import check_path_before_create +from ...common.file_check import FileOpen, FileCheckConst, FileChecker, change_mode +from ..common.config import msCheckerConfig -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, change_mode - lock = threading.Lock() proc_lock = multiprocessing.Lock() def write_api_info_json(api_info): - from api_accuracy_checker.dump.dump import DumpUtil + from ..dump.dump import DumpUtil dump_path = msCheckerConfig.dump_path dump_path = os.path.join(msCheckerConfig.dump_path, "step" + str((DumpUtil.call_num - 1) if msCheckerConfig.enable_dataloader else DumpUtil.call_num)) check_path_before_create(dump_path) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/utils.py 
b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/utils.py index d0bbc7f4a9350ba166a959e36fc2452a2bb52b8b..6641807f929babeed3af30cf14b043d1e4f7913c 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/utils.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/utils.py @@ -18,7 +18,7 @@ import os import yaml -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from ...common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_functional.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_functional.py index 0fef8a71cda3238bbde9e51963c7c7d22ee373a3..056c1d047eb592f0006e3632eaa5597eba5630da 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_functional.py @@ -15,16 +15,11 @@ # limitations under the License. """ -import os - import torch -import yaml -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.hook_module.utils import WrapFunctionalOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard +from ..common.config import msCheckerConfig for f in dir(torch.nn.functional): locals().update({f: getattr(torch.nn.functional, f)}) diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py index dae3f8253e3b01ee9de99836dc6884923698188c..f7791cdc9ac8e2084fc63d76e3819e137f4ea9d7 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_tensor.py @@ -15,17 +15,12 @@ # limitations under the License. """ -import os - import torch -import yaml - -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.hook_module.utils import WrapTensorOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import parameter_adapter + +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard +from ..common.config import msCheckerConfig +from ...common.utils import parameter_adapter def get_tensor_ops(): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_torch.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_torch.py index 7d9f1611117a37fb4cdf7033772599fa286fbd0d..aab245b5d21daff0e0ea44e4073333c6854f95ac 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/hook_module/wrap_torch.py @@ -15,16 +15,11 @@ # limitations under the License. 
""" -import os - import torch -import yaml -from api_accuracy_checker.hook_module.hook_module import HOOKModule -from api_accuracy_checker.common.utils import torch_device_guard -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.hook_module.utils import WrapTorchOps -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen +from .hook_module import HOOKModule +from ..common.utils import torch_device_guard +from ..common.config import msCheckerConfig def get_torch_ops(): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py index 51fcfedefbb8a6ef079a2d26cdc7d9ca841092bf..3e4b16ed3849211ac418633dfa1843c969ad337f 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/data_generate.py @@ -20,7 +20,7 @@ import math import torch import numpy -from api_accuracy_checker.common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, \ +from ..common.utils import Const, check_file_or_directory_path, check_object_type, print_warn_log, \ print_error_log, get_full_data_path, CompareException TORCH_TYPE = ["torch.device", "torch.dtype"] @@ -226,6 +226,8 @@ def gen_args(args_info, need_grad=True, convert_type=None, real_data_path=None): data = gen_args(arg, need_grad, convert_type, real_data_path) elif isinstance(arg, dict): data = gen_data(arg, need_grad, convert_type, real_data_path) + elif arg is None: + data = None else: print_warn_log(f'Warning: {arg} is not supported') raise NotImplementedError() @@ -243,10 +245,12 @@ def gen_kwargs(api_info, convert_type=None, real_data_path=None): real_data_path: the root directory for storing real data. """ check_object_type(api_info, dict) - kwargs_params = api_info.get("kwargs") + kwargs_params = api_info.get("input_kwargs") for key, value in kwargs_params.items(): if isinstance(value, (list, tuple)): kwargs_params[key] = gen_list_kwargs(value, convert_type, real_data_path) + elif value is None: + kwargs_params[key] = None elif value.get('type') in TENSOR_DATA_LIST or value.get('type').startswith("numpy"): kwargs_params[key] = gen_data(value, True, convert_type, real_data_path) elif value.get('type') in TORCH_TYPE: @@ -293,8 +297,8 @@ def gen_api_params(api_info, need_grad=True, convert_type=None, real_data_path=N error_info = f"convert_type params not support {convert_type}." 
raise CompareException(CompareException.INVALID_PARAM_ERROR, error_info) kwargs_params = gen_kwargs(api_info, convert_type, real_data_path) - if api_info.get("args"): - args_params = gen_args(api_info.get("args"), need_grad, convert_type, real_data_path) + if api_info.get("input_args"): + args_params = gen_args(api_info.get("input_args"), need_grad, convert_type, real_data_path) else: print_warn_log(f'Warning: No args in {api_info} ') args_params = [] diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py index cccb2dc0f5a1893604faf472076dacc20126503d..fdcf8926b73d7e13444e5ed09488ebc62f2b5d9d 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/multi_run_ut.py @@ -9,12 +9,12 @@ import threading from collections import namedtuple from itertools import cycle from tqdm import tqdm -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, \ +from ...common.file_check import FileCheckConst, FileChecker, \ check_file_suffix, check_link, FileOpen -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.run_ut.run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path, preprocess_forward_content -from api_accuracy_checker.common.utils import print_error_log, print_warn_log, print_info_log, create_directory -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create +from ..compare.compare import Comparator +from .run_ut import _run_ut_parser, get_validated_result_csv_path, get_validated_details_csv_path, preprocess_forward_content +from ..common.utils import print_error_log, print_warn_log, print_info_log, create_directory +from ...common.file_check import check_path_before_create def split_json_file(input_file, num_splits, filter_api): diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py index 2e8a12231ed31c499648f12ee93486dfed47e00c..09fbf306d346456db07628aa3cdc13ba30a32557 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_overflow_check.py @@ -4,9 +4,9 @@ import sys import torch_npu import torch from tqdm import tqdm -from api_accuracy_checker.run_ut.run_ut import exec_api, generate_device_params, get_api_info -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, print_error_log -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import check_link +from ..run_ut.run_ut import exec_api, generate_device_params, get_api_info +from ..common.utils import print_info_log, print_warn_log, get_json_contents, print_error_log +from ...common.file_check import check_link def check_tensor_overflow(x): @@ -64,8 +64,8 @@ def run_overflow_check(forward_file): def run_torch_api(api_full_name, api_info_dict): torch.npu.clear_npu_overflow_flag() - api_type = api_full_name.split("_")[0] - api_name = api_full_name.split("_", 1)[1].rsplit("_", 2)[0] + api_type = api_full_name.split(".")[0] + api_name = api_full_name.split(".", 1)[1].rsplit(".", 2)[0] args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path='') if not need_grad: 
print_warn_log("%s function with out=... arguments don't support automatic differentiation, skip backward." diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py index e9e8b161ea3d1e2740c0badfaf1c110ed0f2e64a..64002dbba592718a2955dc5948e22c06804cbacd 100644 --- a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/run_ut/run_ut.py @@ -16,19 +16,18 @@ else: current_device = "npu" import torch from tqdm import tqdm -from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args -from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ +from .data_generate import gen_api_params, gen_args +from ..common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ print_error_log, initialize_save_path, Const, create_directory -from api_accuracy_checker.compare.compare import Comparator -from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate -from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate -from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.dump.api_info import APIInfo -from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create - - -from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, FileCheckConst, FileChecker, \ +from ..compare.compare import Comparator +from ..hook_module.wrap_tensor import TensorOPTemplate +from ..hook_module.wrap_functional import FunctionalOPTemplate +from ..hook_module.wrap_torch import TorchOPTemplate +from ..common.config import msCheckerConfig +from ..dump.api_info import APIInfo +from ...common.parse_json import parse_json_info_forward_backward +from ...common.file_check import check_path_before_create +from ...common.file_check import FileOpen, FileCheckConst, FileChecker, \ change_mode, check_file_suffix, check_link current_time = time.strftime("%Y%m%d%H%M%S") @@ -39,7 +38,6 @@ RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'save_error_data', 'is_continue_run_ut', 'real_data_path']) not_backward_list = ['repeat_interleave'] not_detach_set = {'resize_', 'resize_as_', 'set_', 'transpose_', 't_', 'squeeze_', 'unsqueeze_'} -not_raise_dtype_set = {'type_as'} tqdm_params = { 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 @@ -143,7 +141,6 @@ def generate_cpu_params(input_args, input_kwargs, need_backward, api_name): elif len(need_raise_dtypes) >= 2: raise_dtype = torch.float32 - raise_dtype = None if api_name in not_raise_dtype_set else raise_dtype is_detach = api_name not in not_detach_set cpu_args = recursive_arg_to_cpu(input_args, is_detach, raise_dtype=raise_dtype) cpu_kwargs = {key: recursive_arg_to_cpu(value, key != "out" and is_detach, raise_dtype=raise_dtype) for key, value in input_kwargs.items()} @@ -167,7 +164,7 @@ def run_ut(config): continue try: if msCheckerConfig.white_list: - [_, api_name, _] = api_full_name.split("*") + [_, api_name, _] = api_full_name.split(Const.SEP) if api_name not in set(msCheckerConfig.white_list): continue data_info = run_torch_api(api_full_name, config.real_data_path, config.backward_content, api_info_dict) @@ -179,7 +176,7 @@ def run_ut(config): if config.save_error_data: 
do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: - [_, api_name, _] = api_full_name.split("*") + [_, api_name, _] = api_full_name.split(Const.SEP) if "expected scalar type Long" in str(err): print_warn_log(f"API {api_name} not support int32 tensor in CPU, please add {api_name} to CONVERT_API " f"'int32_to_int64' list in accuracy_tools/api_accuracy_check/common/utils.py file.") @@ -199,7 +196,6 @@ def run_ut(config): def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: - api_full_name = api_full_name.replace("*", ".") for element in data_info.in_fwd_data_list: UtAPIInfo(api_full_name + '.forward.input', element) UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) @@ -211,7 +207,7 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict): in_fwd_data_list = [] - [api_type, api_name, _] = api_full_name.split("*") + [api_type, api_name, _] = api_full_name.split(Const.SEP) args, kwargs, need_grad = get_api_info(api_info_dict, api_name, real_data_path) in_fwd_data_list.append(args) in_fwd_data_list.append(kwargs) @@ -241,27 +237,29 @@ def run_torch_api(api_full_name, real_data_path, backward_content, api_info_dict grad_index = grad_input_index.get('grad_index') if need_backward: - backward_args = backward_content[api_full_name] + backward_args = backward_content[api_full_name].get("grad_output") grad = gen_args(backward_args, real_data_path=real_data_path)[0] bench_grad, _ = generate_cpu_params(grad, {}, False, api_name) bench_grad_out = run_backward(cpu_args, bench_grad, grad_index, out) device_grad = grad.clone().detach().to(current_device) device_grad_out = run_backward(device_args, device_grad, grad_index, device_out) + if grad_index is not None: + return UtDataInfo(bench_grad_out, device_grad_out, device_out[grad_index], out[grad_index], bench_grad, + in_fwd_data_list) return UtDataInfo(bench_grad_out, device_grad_out, device_out, out, bench_grad, in_fwd_data_list) def get_api_info(api_info_dict, api_name, real_data_path): convert_type, api_info_dict = api_info_preprocess(api_name, api_info_dict) need_grad = True - if api_info_dict.get("kwargs") and "out" in api_info_dict.get("kwargs"): + if api_info_dict.get("input_kwargs") and "out" in api_info_dict.get("input_kwargs"): need_grad = False args, kwargs = gen_api_params(api_info_dict, need_grad, convert_type, real_data_path) return args, kwargs, need_grad def run_backward(args, grad, grad_index, out): - if grad_index is not None: out[grad_index].backward(grad) elif isinstance(out, (list, tuple)): @@ -357,7 +355,7 @@ def preprocess_forward_content(forward_content): processed_content = {} base_keys_variants = {} for key, value in forward_content.items(): - base_key = key.rsplit('*', 1)[0] + base_key = key.rsplit(Const.SEP, 1)[0] new_args = value['args'] new_kwargs = value['kwargs'] filtered_new_args = [{k: v for k, v in arg.items() if k not in ['Max', 'Min']} for arg in new_args if isinstance(arg, dict)] @@ -411,15 +409,10 @@ def run_ut_command(args): out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) out_path = out_path_checker.common_check() save_error_data = args.save_error_data - forward_content = get_json_contents(forward_file) + forward_content, backward_content, real_data_path = parse_json_info_forward_backward(forward_file) if 
args.filter_api: forward_content = preprocess_forward_content(forward_content) - backward_content = {} - if args.backward_input_file: - check_link(args.backward_input_file) - backward_file = os.path.realpath(args.backward_input_file) - check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) - backward_content = get_json_contents(backward_file) + result_csv_path = os.path.join(out_path, RESULT_FILE_NAME) details_csv_path = os.path.join(out_path, DETAILS_FILE_NAME) if args.result_csv_path: diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/common/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/compare/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/compare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/dump/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/dump/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/hook_module/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/hook_module/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/__init__.py b/debug/accuracy_tools/atat/pytorch/api_accuracy_checker/test/ut/run_ut/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/debug/accuracy_tools/atat/pytorch/common/__init__.py b/debug/accuracy_tools/atat/pytorch/common/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b391e103115498a2c2cf8b78f48168822517be73 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/common/__init__.py @@ -0,0 +1,4 @@ +from .recursive import recursive_apply_transform +from .log import print_error_log_rank_0, print_info_log_rank_0, print_warn_log_rank_0 +from .parse_json import parse_json_info_forward_backward +from .utils import seed_all diff --git a/debug/accuracy_tools/atat/pytorch/common/exceptions.py b/debug/accuracy_tools/atat/pytorch/common/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..607f6eb2be4f573df4a472b7e0774b069422856a --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/common/exceptions.py @@ -0,0 +1,67 @@ + +class CodedException(Exception): + def __init__(self, code, error_info=''): + self.error_info = self.err_strs.get(code) + error_info + + def __str__(self): + return self.error_info + + 
+class MsaccException(CodedException): + INVALID_PARAM_ERROR = 0 + + err_strs = { + INVALID_PARAM_ERROR: "[msacc] 无效参数: " + } + + +class FileCheckException(CodedException): + INVALID_FILE_ERROR = 0 + FILE_PERMISSION_ERROR = 1 + SOFT_LINK_ERROR = 2 + ILLEGAL_PATH_ERROR = 3 + ILLEGAL_PARAM_ERROR = 4 + FILE_TOO_LARGE_ERROR = 5 + + err_strs = { + SOFT_LINK_ERROR: "[msacc] 检测到软链接: ", + FILE_PERMISSION_ERROR: "[msacc] 文件权限错误: ", + INVALID_FILE_ERROR: "[msacc] 无效文件: ", + ILLEGAL_PATH_ERROR: "[msacc] 非法文件路径: ", + ILLEGAL_PARAM_ERROR: "[msacc] 非法打开方式: ", + FILE_TOO_LARGE_ERROR: "[msacc] 文件过大: " + } + + +class ParseJsonException(CodedException): + UnexpectedNameStruct = 0 + InvalidDumpJson = 1 + err_strs = { + UnexpectedNameStruct: "[msacc] Unexpected name in json: ", + InvalidDumpJson: "[msacc] json格式不正确: ", + } + + +class ScopeException(CodedException): + InvalidApiStr = 0 + InvalidScope = 1 + ArgConflict = 2 + err_strs = { + InvalidApiStr: "[msacc] Invalid api_list: ", + InvalidScope: "[msacc] Invalid scope: ", + ArgConflict: "[msacc] Scope and api_list conflict: ", + } + + +class RepairException(CodedException): + InvalidRepairType = 0 + err_strs = { + InvalidRepairType: "[msacc] Invalid repair_type: " + } + + +class StepException(CodedException): + InvalidPostProcess = 0 + err_strs = { + InvalidPostProcess: "[msacc] 错误的step后处理配置: ", + } diff --git a/debug/accuracy_tools/atat/pytorch/common/file_check_util.py b/debug/accuracy_tools/atat/pytorch/common/file_check.py similarity index 80% rename from debug/accuracy_tools/atat/pytorch/common/file_check_util.py rename to debug/accuracy_tools/atat/pytorch/common/file_check.py index 61fc4ddf94c8e295b08c395f21776ac0f05f5c61..d47869921ec60d0ac1b72163c577c53cd410d8e9 100644 --- a/debug/accuracy_tools/atat/pytorch/common/file_check_util.py +++ b/debug/accuracy_tools/atat/pytorch/common/file_check.py @@ -17,7 +17,9 @@ import os import re -from .log import print_warn_log, print_error_log +from .log import print_error_log +from .exceptions import FileCheckException +from .utils import Const class FileCheckConst: @@ -56,25 +58,6 @@ class FileCheckConst: } -class FileCheckException(Exception): - """ - Class for File Check Exception - """ - NONE_ERROR = 0 - INVALID_PATH_ERROR = 1 - INVALID_FILE_TYPE_ERROR = 2 - INVALID_PARAM_ERROR = 3 - INVALID_PERMISSION_ERROR = 3 - - def __init__(self, code, error_info: str = ""): - super(FileCheckException, self).__init__() - self.code = code - self.error_info = error_info - - def __str__(self): - return self.error_info - - class FileChecker: """ The class for check file. 
@@ -96,7 +79,7 @@ class FileChecker: def _check_path_type(path_type): if path_type not in [FileCheckConst.DIR, FileCheckConst.FILE]: print_error_log(f'The path_type must be {FileCheckConst.DIR} or {FileCheckConst.FILE}.') - raise FileCheckException(FileCheckException.INVALID_PARAM_ERROR) + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) return path_type def common_check(self): @@ -162,7 +145,7 @@ class FileOpen: support_mode = self.SUPPORT_READ_MODE + self.SUPPORT_WRITE_MODE + self.SUPPORT_READ_WRITE_MODE if self.mode not in support_mode: print_error_log("File open not support %s mode" % self.mode) - raise FileCheckException(FileCheckException.INVALID_PARAM_ERROR) + raise FileCheckException(FileCheckException.ILLEGAL_PARAM_ERROR) check_link(self.file_path) self.file_path = os.path.realpath(self.file_path) check_path_length(self.file_path) @@ -189,7 +172,7 @@ def check_link(path): abs_path = os.path.abspath(path) if os.path.islink(abs_path): print_error_log('The file path {} is a soft link.'.format(path)) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + raise FileCheckException(FileCheckException.SOFT_LINK_ERROR) def check_path_length(path, name_length=None): @@ -197,71 +180,58 @@ def check_path_length(path, name_length=None): if len(path) > FileCheckConst.DIRECTORY_LENGTH or \ len(os.path.basename(path)) > file_max_name_length: print_error_log('The file path length exceeds limit.') - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_exists(path): if not os.path.exists(path): print_error_log('The file path %s does not exist.' % path) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_path_readability(path): if not os.access(path, os.R_OK): print_error_log('The file path %s is not readable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_writability(path): if not os.access(path, os.W_OK): print_error_log('The file path %s is not writable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_executable(path): if not os.access(path, os.X_OK): print_error_log('The file path %s is not executable.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_other_user_writable(path): st = os.stat(path) if st.st_mode & 0o002: - _user_interactive_confirm( - 'The file path %s may be insecure because other users have write permissions. ' - 'Do you want to continue?' % path) - - -def _user_interactive_confirm(message): - while True: - check_message = input(message + " Enter 'c' to continue or enter 'e' to exit: ") - if check_message == "c": - break - elif check_message == "e": - print_warn_log("User canceled.") - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) - else: - print("Input is error, please enter 'c' or 'e'.") + print_error_log('The file path %s may be insecure because other users have write permissions. 
' % path) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_owner_consistent(path): file_owner = os.stat(path).st_uid if file_owner != os.getuid(): print_error_log('The file path %s may be insecure because is does not belong to you.' % path) - raise FileCheckException(FileCheckException.INVALID_PERMISSION_ERROR) + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) def check_path_pattern_vaild(path): if not re.match(FileCheckConst.FILE_VALID_PATTERN, path): print_error_log('The file path {} contains special characters.'.format(path)) - raise FileCheckException(FileCheckException.INVALID_PATH_ERROR) + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR) def check_file_size(file_path, max_size): file_size = os.path.getsize(file_path) if file_size >= max_size: - _user_interactive_confirm(f'The size of file path {file_path} exceeds {max_size} bytes.' - f'Do you want to continue?') + print_error_log(f'The size of file path {file_path} exceeds {max_size} bytes.') + raise FileCheckException(FileCheckException.FILE_TOO_LARGE_ERROR) def check_common_file_size(file_path): @@ -276,18 +246,18 @@ def check_file_suffix(file_path, file_suffix): if file_suffix: if not file_path.endswith(file_suffix): print_error_log(f"The {file_path} should be a {file_suffix} file!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def check_path_type(file_path, file_type): if file_type == FileCheckConst.FILE: if not os.path.isfile(file_path): print_error_log(f"The {file_path} should be a file!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) if file_type == FileCheckConst.DIR: if not os.path.isdir(file_path): print_error_log(f"The {file_path} should be a dictionary!") - raise FileCheckException(FileCheckException.INVALID_FILE_TYPE_ERROR) + raise FileCheckException(FileCheckException.INVALID_FILE_ERROR) def create_directory(dir_path): @@ -303,9 +273,18 @@ def create_directory(dir_path): try: os.makedirs(dir_path, mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True) except OSError as ex: - print_error_log( - 'Failed to create {}.Please check the path permission or disk space .{}'.format(dir_path, str(ex))) - raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + 'Failed to create {}. Please check the path permission or disk space .{}'.format(dir_path, str(ex))) from ex + + +def check_path_before_create(path): + if len(os.path.realpath(path)) > Const.DIRECTORY_LENGTH or len(os.path.basename(path)) > \ + Const.FILE_NAME_LENGTH: + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, 'The file path length exceeds limit.') + + if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): + raise FileCheckException(FileCheckException.ILLEGAL_PATH_ERROR, + 'The file path {} contains special characters.'.format(path)) def change_mode(path, mode): @@ -314,6 +293,6 @@ def change_mode(path, mode): try: os.chmod(path, mode) except PermissionError as ex: - print_error_log('Failed to change {} authority. {}'.format(path, str(ex))) - raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR) from ex + raise FileCheckException(FileCheckException.FILE_PERMISSION_ERROR, + 'Failed to change {} authority. 
{}'.format(path, str(ex))) from ex diff --git a/debug/accuracy_tools/atat/pytorch/common/log.py b/debug/accuracy_tools/atat/pytorch/common/log.py index 32c3423551febda3358fc51c0aacdc6164b71d2e..fab5aca45c08af7253dedf8ee13db10b271683da 100644 --- a/debug/accuracy_tools/atat/pytorch/common/log.py +++ b/debug/accuracy_tools/atat/pytorch/common/log.py @@ -1,12 +1,26 @@ import os import time import sys +from .utils import get_rank_if_initialized + + +def on_rank_0(func): + def func_rank_0(*args, **kwargs): + current_rank = get_rank_if_initialized() + if current_rank is None or current_rank == 0: + return func(*args, **kwargs) + + return func_rank_0 def _print_log(level, msg, end='\n'): current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time()))) - pid = os.getgid() - print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end) + pid = os.getpid() + full_msg = current_time + "(" + str(pid) + ")-[" + level + "]" + msg + current_rank = get_rank_if_initialized() + if current_rank is not None: + full_msg = f"[rank {current_rank}]-" + full_msg + print(full_msg, end=end) sys.stdout.flush() @@ -37,4 +51,9 @@ def print_warn_log(warn_msg): Parameter: warn_msg: the warning message. """ - _print_log("WARNING", warn_msg) \ No newline at end of file + _print_log("WARNING", warn_msg) + + +print_info_log_rank_0 = on_rank_0(print_info_log) +print_warn_log_rank_0 = on_rank_0(print_warn_log) +print_error_log_rank_0 = on_rank_0(print_error_log) diff --git a/debug/accuracy_tools/atat/pytorch/common/parse_json.py b/debug/accuracy_tools/atat/pytorch/common/parse_json.py new file mode 100644 index 0000000000000000000000000000000000000000..2dddb185c14abb7e3b6e560322aa6169708a122d --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/common/parse_json.py @@ -0,0 +1,37 @@ +import json +from .exceptions import ParseJsonException + + +def parse_json_info_forward_backward(json_path): + def parse_data_name_with_pattern(data_name, pattern): + name_struct = data_name.split('.') + if not name_struct[-1] == pattern: + raise ParseJsonException(ParseJsonException.UnexpectedNameStruct, + f"{data_name} in file {json_path}") + api_name = '.'.join(name_struct[:-1]) + return api_name + + with open(json_path, 'r') as f: + dump_json = json.load(f) + + real_data_path = dump_json.get("dump_path") + dump_data = dump_json.get("data") + if not dump_data: + raise ParseJsonException(ParseJsonException.InvalidDumpJson, "dump数据中没有data字段") + + forward_data = {} + backward_data = {} + for data_name, data_item in dump_data.items(): + if "Module" in data_name: + continue + if "forward" in data_name: + api_name = parse_data_name_with_pattern(data_name, "forward") + forward_data.update({api_name: data_item}) + elif "backward" in data_name: + api_name = parse_data_name_with_pattern(data_name, "backward") + backward_data.update({api_name: data_item}) + else: + raise ParseJsonException(ParseJsonException.UnexpectedNameStruct, + f"{data_name} in file {json_path}.") + + return forward_data, backward_data, real_data_path diff --git a/debug/accuracy_tools/atat/pytorch/common/recursive.py b/debug/accuracy_tools/atat/pytorch/common/recursive.py new file mode 100644 index 0000000000000000000000000000000000000000..3745a33f9eac6c1c7e8e5437ca375dc4e0f8f22a --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/common/recursive.py @@ -0,0 +1,23 @@ +import torch + +_recursive_key_stack = [] +def recursive_apply_transform(args, transform): + global _recursive_key_stack + if isinstance(args, (list, tuple)): + transform_result = [] + 
for i, arg in enumerate(args): + _recursive_key_stack.append(str(i)) + transform_result.append(recursive_apply_transform(arg, transform)) + _recursive_key_stack.pop() + return type(args)(transform_result) + elif isinstance(args, dict): + transform_result = {} + for k, arg in args.items(): + _recursive_key_stack.append(str(k)) + transform_result[k] = recursive_apply_transform(arg, transform) + _recursive_key_stack.pop() + return transform_result + else: + arg_transform = transform(args, _recursive_key_stack) + return arg_transform + diff --git a/debug/accuracy_tools/atat/pytorch/common/utils.py b/debug/accuracy_tools/atat/pytorch/common/utils.py index 8f530db180ad1288b4d3c9615687db9ca9fb02d9..821ce4a7a30f32bd3930ebf9972d635bef4bd908 100644 --- a/debug/accuracy_tools/atat/pytorch/common/utils.py +++ b/debug/accuracy_tools/atat/pytorch/common/utils.py @@ -15,19 +15,12 @@ # limitations under the License. """ import os +from pathlib import Path import random -import zlib -from functools import wraps - +import stat import torch import numpy as np - -from atat.core.utils import print_error_log -from atat.core.utils import Const -from atat.core.utils import CompareException - - - +from functools import wraps try: import torch_npu except ImportError: @@ -35,8 +28,8 @@ except ImportError: else: is_gpu = False -torch_without_guard_version_list = ['2.1', '2.2'] -npu_distributed_api = ['isend', 'irecv'] + +torch_without_guard_version_list = ['2.1'] for version in torch_without_guard_version_list: if torch.__version__.startswith(version): torch_without_guard_version = True @@ -47,51 +40,7 @@ for version in torch_without_guard_version_list: if not is_gpu and not torch_without_guard_version: from torch_npu.utils.device_guard import torch_device_guard as torch_npu_device_guard - -def check_is_npu(): - return not is_gpu - - -def torch_device_guard(func): - if is_gpu or torch_without_guard_version: - return func - # Parse args/kwargs matched torch.device objects - - @torch_npu_device_guard - def wrapper(*args, **kwargs): - return func(*args, **kwargs) - return wrapper - - -def seed_all(seed=1234, mode=False): - check_seed_all(seed, mode) - random.seed(seed) - os.environ['PYTHONHASHSEED'] = str(seed) - np.random.seed(seed) - torch.manual_seed(seed) - torch.use_deterministic_algorithms(mode) - if is_gpu: - torch.cuda.manual_seed_all(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.enable = False - torch.backends.cudnn.benchmark = False - else: - torch_npu.npu.manual_seed_all(seed) - torch_npu.npu.manual_seed(seed) - - -def check_seed_all(seed, mode): - if isinstance(seed, int): - if seed < 0 or seed > Const.MAX_SEED_VALUE: - print_error_log(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - else: - print_error_log(f"Seed must be integer.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if not isinstance(mode, bool): - print_error_log(f"seed_all mode must be bool.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) +npu_distributed_api = ['isend', 'irecv'] def parameter_adapter(func): @@ -124,9 +73,98 @@ def parameter_adapter(func): return inner -def get_md5_for_tensor(x): - if x.dtype == torch.bfloat16: - x = x.float() - tensor_bytes = x.cpu().detach().numpy().tobytes() - crc32_hash = zlib.crc32(tensor_bytes) - return f"{crc32_hash:08x}" +def torch_device_guard(func): + if is_gpu or torch_without_guard_version: + return func + # Parse args/kwargs matched 
torch.device objects + + @torch_npu_device_guard + def wrapper(*args, **kwargs): + return func(*args, **kwargs) + return wrapper + + +def get_rank_if_initialized(): + if torch.distributed.is_initialized(): + return torch.distributed.get_rank() + return None + + +def seed_all(seed=1234, mode=False): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.use_deterministic_algorithms(mode) + if is_gpu: + torch.cuda.manual_seed_all(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.enable = False + torch.backends.cudnn.benchmark = False + else: + torch_npu.npu.manual_seed_all(seed) + torch_npu.npu.manual_seed(seed) + + +class Const: + """ + Class for const + """ + SEP = "." + MODEL_TYPE = ['.onnx', '.pb', '.om'] + DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*" + SEMICOLON = ";" + COLON = ":" + EQUAL = "=" + COMMA = "," + DOT = "." + DUMP_RATIO_MAX = 100 + SUMMERY_DATA_NUMS = 256 + FLOAT_EPSILON = np.finfo(float).eps + SUPPORT_DUMP_MODE = ['api', 'acl'] + ON = 'ON' + OFF = 'OFF' + BACKWARD = 'backward' + FORWARD = 'forward' + PRE_FORWARD = "pre_forward" + + # dump mode + ALL = "all" + LIST = "list" + RANGE = "range" + STACK = "stack" + ACL = "acl" + API_LIST = "api_list" + API_STACK = "api_stack" + DUMP_MODE = [ALL, LIST, RANGE, STACK, ACL, API_LIST, API_STACK] + AUTO = "auto" + ONLINE_DUMP_MODE = [ALL, LIST, AUTO, OFF] + SUMMARY = "summary" + MD5 = "md5" + SUMMARY_MODE = [ALL, SUMMARY, MD5] + + WRITE_FLAGS = os.O_WRONLY | os.O_CREAT + WRITE_MODES = stat.S_IWUSR | stat.S_IRUSR + + PKL_SUFFIX = ".pkl" + NUMPY_SUFFIX = ".npy" + ONE_GB = 1 * 1024 * 1024 * 1024 + TEN_GB = 10 * 1024 * 1024 * 1024 + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' + FILE_NAME_LENGTH = 255 + DIRECTORY_LENGTH = 4096 + DISTRIBUTED_PREFIX_LENGTH = 60 + SUMMARY_COLUMN_NUM = 6 + STACK_COLUMN_NUM = 2 + # env dump path + ASCEND_WORK_PATH = "ASCEND_WORK_PATH" + DUMP_DIR = "dump_data" + + ENV_ENABLE = "1" + ENV_DISABLE = "0" + + MAX_SEED_VALUE = 2**32 - 1 + + INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", + "_reduce_scatter_base", "_all_gather_base"] \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/common/version.py b/debug/accuracy_tools/atat/pytorch/common/version.py deleted file mode 100644 index f7d2e869417b85069ceed6e72b4a0c28f153ce64..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/atat/pytorch/common/version.py +++ /dev/null @@ -1 +0,0 @@ -__version__ = '3.0' diff --git a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py index 95accc67c10adf09a2005d4df96d553c5189da44..65325999834ddb0526502d8b758fa6feaa1ada05 100644 --- a/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/atat/pytorch/debugger/precision_debugger.py @@ -7,7 +7,7 @@ from ..pt_config import parse_json_config class PrecisionDebugger: _instance = None - def __new__(cls): + def __new__(cls, config_path=None, task=None, dump_path=None, level=None): if cls._instance is None: cls._instance = super(PrecisionDebugger, cls).__new__(cls) cls._instance.config = None diff --git a/debug/accuracy_tools/atat/pytorch/functional/__init__.py b/debug/accuracy_tools/atat/pytorch/functional/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d7282af08f0aadc803f7554602614359e7689e14 --- /dev/null +++ 
b/debug/accuracy_tools/atat/pytorch/functional/__init__.py @@ -0,0 +1,4 @@ +from .repair import build_repair +from .scope import build_scope +from .step_post_process import build_step_post_process +from .data_collector import build_collect_data \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_collector.py b/debug/accuracy_tools/atat/pytorch/functional/data_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..6bf5ab6b3e03a87892a194c5c38a6406de7dd132 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/functional/data_collector.py @@ -0,0 +1,89 @@ + +import os +from ..module_processer import ModuleProcesser +from .scope import BaseScope, build_scope +from .json_writer import DataWriter +from ..common.log import print_info_log, print_info_log_rank_0, print_error_log_rank_0 +from ..common.utils import Const +from ..common.file_check import FileOpen +from .data_processor import build_data_processor, DataProcessor + + +def build_collect_data(config): + return DataCollector(config) + + +class DataCollector: + overflow_task = "overflow_check" + tasks_need_tensor_data = ["overflow_check", "tensor"] + level_without_construct = "L1" + + def __init__(self, config): + self.config = config + self.data_writer = DataWriter() + self.data_processor = build_data_processor(config, self.data_writer) + self.module_count = {} + self.scope = build_scope(None, self.config.scope, self.config.list) + + @property + def dump_data_dir(self): + return self.data_writer.dump_tensor_data_dir + + @property + def dump_file_path(self): + return self.data_writer.dump_file_path + + def write_json(self): + self.data_writer.write_json() + + def __call__(self, name_template, module_type, module, pid, module_input_output): + if module_type == BaseScope.Module_Type_Module: + name = module.mindstudio_reserved_name + else: + name = name_template + + if self.config.level != DataCollector.level_without_construct: + self.data_writer.update_construct({name: ModuleProcesser.api_parent_node}) + self.data_writer.update_construct(ModuleProcesser.module_node) + if not self.scope or self.scope.check(name): + msg = f"Calibrator is collecting data on {name}. " + if pid == os.getpid(): + if "forward" in name: + data_info = self.data_processor.analyze_forward(name, module_input_output) + self.data_writer.update_stack(self.data_processor.analyze_api_call_stack(name)) + else: + data_info = self.data_processor.analyze_backward(name, module_input_output) + if self.config.task == DataProcessor.overflow: + if data_info: + self.data_writer.update_data(data_info) + msg += "Overflow detected." + else: + msg += "No Overflow, OK." 
+ else: + self.data_writer.update_data(data_info) + print_info_log(msg) + + + def module_count_func(self, name, name_template): + module_name = name.split(Const.SEP)[-3] + if "forward" in name_template: + if module_name not in self.module_count: + self.module_count[module_name] = [0, [0]] + else: + if self.module_count[module_name][-1] and \ + self.module_count[module_name][0] != self.module_count[module_name][-1][-1]: + self.module_count[module_name][-1].pop() + self.module_count[module_name][0] += 1 + self.module_count[module_name][-1].append(self.module_count[module_name][0]) + index = self.module_count[module_name][0] + else: + backward_stack = self.module_count[module_name][-1] if module_name in self.module_count else [] + if not backward_stack: + index = "abnormal" + else: + index = backward_stack.pop() + return index + + def update_dump_paths(self, *args): + self.data_writer.update_dump_paths(*args) + self.data_writer.initialize_json_file(task=self.config.task, level=self.config.level) diff --git a/debug/accuracy_tools/atat/pytorch/functional/data_processor.py b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py new file mode 100644 index 0000000000000000000000000000000000000000..79433ffff5ebabd99fde8fc5f4b24511efd59c1c --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/functional/data_processor.py @@ -0,0 +1,318 @@ +import torch +import zlib +import numpy as np +import os +import inspect +from dataclasses import dataclass +from typing import Tuple, List, Dict, Optional, Union +from ..common.exceptions import MsaccException +from ..common.utils import Const +from ..common import recursive_apply_transform + + +def build_data_processor(config, data_writer): + if config.task == DataProcessor.full: + return FullTensorDataProcessor(config, data_writer) + elif config.task == DataProcessor.summary: + return DataProcessor(config, data_writer) + elif config.task == DataProcessor.overflow: + return OverflowTensorDataProcessor(config, data_writer) + else: + raise MsaccException(MsaccException.INVALID_PARAM_ERROR, + "task should be in [{}, {}, {}]".format( + DataProcessor.full, + DataProcessor.summary, + DataProcessor.overflow + )) + + +@dataclass +class ModuleForwardInputsOutputs: + args: Optional[Tuple] + kwargs: Optional[Dict] + output: Union[Tuple, torch.Tensor] + + def __init__(self, args, kwargs, output): + if not isinstance(args, tuple): + args = (args, ) + if not isinstance(output, tuple): + output = (output, ) + self.args = args + self.kwargs = kwargs + self.output = output + + +@dataclass +class ModuleBackwardInputsOutputs: + grad_output: Optional[Tuple] + grad_input: Optional[Tuple] + + def __init__(self, grad_input, grad_output): + if not isinstance(grad_input, tuple): + grad_input = (grad_input, ) + if not isinstance(grad_output, tuple): + grad_output = (grad_output,) + self.grad_input = grad_input + self.grad_output = grad_output + + +class DataProcessor: + full = "tensor" + summary = "statistics" + overflow = "overflow_check" + + def __init__(self, config, data_writer): + self.data_writer = data_writer + self.api_info_struct = {} + self.stack_info_struct = {} + self.torch_object_key = { + "device": self.analyze_device_in_kwargs, + "dtype": self.analyze_dtype_in_kwargs + } + self.api_name = None + self.config = config + self.api_data_category = None + self.has_overflow = False + + @staticmethod + def get_md5_for_tensor(x): + if x.dtype == torch.bfloat16: + x = x.float() + tensor_bytes = x.cpu().detach().numpy().tobytes() + crc32_hash = zlib.crc32(tensor_bytes) + 
return f"{crc32_hash:08x}" + + @staticmethod + def analyze_device_in_kwargs(element): + single_arg = {} + single_arg.update({'type': "torch.device"}) + if not isinstance(element, str): + if hasattr(element, "index"): + device_value = element.type + ":" + str(element.index) + else: + device_value = element.type + single_arg.update({"value": device_value}) + else: + single_arg.update({"value": element}) + return single_arg + + @staticmethod + def analyze_dtype_in_kwargs(element): + single_arg = {} + single_arg.update({"type": "torch.dtype"}) + single_arg.update({"value": str(element)}) + return single_arg + + @staticmethod + def _convert_numpy_to_builtin(arg): + type_mapping = { + np.integer: int, + np.floating: float, + np.bool_: bool, + np.complexfloating: complex, + np.str_: str, + np.byte: bytes, + np.unicode_: str + } + for numpy_type, builtin_type in type_mapping.items(): + if isinstance(arg, numpy_type): + return builtin_type(arg), type(arg).__name__ + return arg, '' + + def _analyze_numpy(self, value, numpy_type): + single_arg = {} + single_arg.update({"type": numpy_type}) + single_arg.update({"value": value}) + return single_arg + + def get_stat_info(self, data): + if data.is_meta: + return + data_clone = data.detach() + if data_clone.numel() == 0: + tensor_max = None + tensor_min = None + tensor_mean = None + tensor_norm = None + elif data_clone.dtype == torch.bool: + tensor_max = True in data_clone + tensor_min = False not in data_clone + tensor_mean = None + tensor_norm = None + elif not len(data_clone.shape): + tensor_max = data_clone.item() + tensor_min = tensor_max + tensor_mean = tensor_max + tensor_norm = tensor_max + else: + if not data_clone.is_floating_point(): + data_clone = data_clone.float() + tensor_max = torch._C._VariableFunctionsClass.max(data_clone).item() + tensor_min = torch._C._VariableFunctionsClass.min(data_clone).item() + tensor_mean = torch._C._VariableFunctionsClass.mean(data_clone).item() + tensor_norm = torch._C._VariableFunctionsClass.norm(data_clone).item() + + return tensor_max, tensor_min, tensor_mean, tensor_norm + + def _analyze_builtin(self, arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({"type": "slice"}) + single_arg.update({"value": [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({"type": type(arg).__name__}) + single_arg.update({"value": arg}) + return single_arg + + @staticmethod + def handle_tensor_extremum_nan_inf(data_clone, operator): + data_nan = torch._C._VariableFunctionsClass.isnan(data_clone) + if int(torch._C._VariableFunctionsClass.sum(data_nan)) == data_clone.numel(): + return float('nan') + finite_mask = torch._C._VariableFunctionsClass.isfinite(data_clone) + if int(torch._C._VariableFunctionsClass.sum(finite_mask)) > 0: + finite_values = data_clone[finite_mask] + return torch._C._VariableFunctionsClass.max(finite_values).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(finite_values).item() + else: + data_no_nan = data_clone[~data_nan] + return torch._C._VariableFunctionsClass.max(data_no_nan).item() if operator == 'max' else \ + torch._C._VariableFunctionsClass.min(data_no_nan).item() + + def _analyze_maybe_overflow_tensor(self, tensor_json, tensor): + if np.isinf(tensor_json['Max']) or np.isnan(tensor_json['Max']): + tensor_json['Max_except_inf_nan'] = self.handle_tensor_extremum_nan_inf(tensor, "max") + self.has_overflow = True + if np.isinf(tensor_json['Min']) or np.isnan(tensor_json['Min']): + tensor_json['Min_except_inf_nan'] = 
self.handle_tensor_extremum_nan_inf(tensor, "min") + self.has_overflow = True + + def _analyze_tensor(self, tensor, suffix): + tensor_max, tensor_min, tensor_mean, tensor_norm = self.get_stat_info(tensor) + + tensor_json = {} + tensor_json.update({'type': 'torch.Tensor'}) + tensor_json.update({'dtype': str(tensor.dtype)}) + tensor_json.update({"shape": tensor.shape}) + tensor_json.update({"Max": tensor_max}) + tensor_json.update({"Min": tensor_min}) + self._analyze_maybe_overflow_tensor(tensor_json, tensor) + tensor_json.update({"Mean": tensor_mean}) + tensor_json.update({"Norm": tensor_norm}) + tensor_json.update({"requires_grad": tensor.requires_grad}) + if self.config.summary_mode == "md5": + tensor_md5 = self.get_md5_for_tensor(tensor) + tensor_json.update({"md5": tensor_md5}) + + return tensor_json + + def analyze_single_element(self, element, suffix_stack): + if suffix_stack and suffix_stack[-1] in self.torch_object_key: + return self.torch_object_key[suffix_stack[-1]](element) + + converted_numpy, numpy_type = self._convert_numpy_to_builtin(element) + if converted_numpy is not element: + return self._analyze_numpy(converted_numpy, numpy_type) + + if isinstance(element, torch.Tensor): + return self._analyze_tensor(element, Const.SEP.join(suffix_stack)) + + if isinstance(element, (bool, int, float, str, slice)): + return self._analyze_builtin(element) + + def analyze_element(self, element): + return recursive_apply_transform(element, self.analyze_single_element) + + @staticmethod + def analyze_api_call_stack(name): + stack_str = [] + for (_, path, line, func, code, _) in inspect.stack()[5:]: + if not code: + continue + stack_line = " ".join([ + "File", ", ".join([ + path, + " ".join(["line", str(line)]), + " ".join(["in", func]), + " ".join(["\n", code[0].strip()]) + ]) + ]) + stack_str.append(stack_line) + stack_info_struct = {name: stack_str} + return stack_info_struct + + def analyze_forward(self, name, + module_input_output: ModuleForwardInputsOutputs): + self.api_name = name + self.api_data_category = "input" + args_info_list = self.analyze_element(module_input_output.args) + self.api_data_category = "kwargs" + kwargs_info_list = self.analyze_element(module_input_output.kwargs) + self.api_data_category = "output" + output_info_list = self.analyze_element(module_input_output.output) + api_info_struct = {name: {"input_args": args_info_list, + "input_kwargs": kwargs_info_list, + "output": output_info_list}} + return api_info_struct + + def analyze_backward(self, name, + module_input_output: ModuleBackwardInputsOutputs): + self.api_name = name + self.api_data_category = "output" + input_info_list = self.analyze_element(module_input_output.grad_input) + self.api_data_category = "input" + output_info_list = self.analyze_element(module_input_output.grad_output) + api_info_struct = {name: {"grad_input": input_info_list, "grad_output": output_info_list}} # TODO: magic str + return api_info_struct + + +class FullTensorDataProcessor(DataProcessor): + def _analyze_tensor(self, tensor, suffix): + self.data_path = self.data_writer.dump_tensor_data_dir + dump_data_name = (self.api_name + Const.SEP + self.api_data_category + Const.SEP + + suffix + ".pt") + file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + torch.save(tensor, file_path) + single_arg = super()._analyze_tensor(tensor, suffix) + single_arg.update({"data_name": dump_data_name}) + return single_arg + + +class OverflowTensorDataProcessor(FullTensorDataProcessor): + __slots__ = 
["cached_tensors_and_file_paths"] + + def __init__(self, config, data_writer): + super().__init__(config, data_writer) + self.cached_tensors_and_file_paths = {} + + def _analyze_tensor(self, tensor, suffix): + self.data_path = self.data_writer.dump_tensor_data_dir + dump_data_name = (self.api_name + Const.SEP + self.api_data_category + Const.SEP + + suffix + ".pt") + file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + self.cached_tensors_and_file_paths.update({file_path: tensor}) + single_arg = super()._analyze_tensor(tensor, suffix) + single_arg.update({"data_name": dump_data_name}) + + def analyze_forward(self, name, + module_input_output: ModuleForwardInputsOutputs): + self.has_overflow = False + api_info_struct = super().analyze_forward(name, module_input_output) + if self.has_overflow: + self.save_overflow_data() + return api_info_struct + return None + + def analyze_backward(self, name, + module_input_output: ModuleBackwardInputsOutputs): + self.has_overflow = False + api_info_struct = super().analyze_backward(name, module_input_output) + if self.has_overflow: + self.save_overflow_data() + return api_info_struct + return None + + def save_overflow_data(self): + for file_path, tensor in self.cached_tensors_and_file_paths.items(): + torch.save(tensor, file_path) + self.cached_tensors_and_file_paths = {} diff --git a/debug/accuracy_tools/atat/pytorch/functional/json_writer.py b/debug/accuracy_tools/atat/pytorch/functional/json_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..760f4f2d71afd7210155d1de5f930842dcd3dc80 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/functional/json_writer.py @@ -0,0 +1,90 @@ +import os +from pathlib import Path +import json +from ..common.log import print_info_log_rank_0 + + +class DataWriter: # TODO: UT + # dump_json_name = "dump.json" + # stack_json_name = "stack.json" + # construct_json_name = "construct.json" + + def __init__(self, init_json=None) -> None: + self.dump_count = 0 + self.init_json = init_json + self.dump_file_path = None # os.path.join(dump_dir, DataWriter.dump_json_name) + self.stack_file_path = None # os.path.join(dump_dir, DataWriter.stack_json_name) + self.construct_file_path = None # os.path.join(dump_dir, DataWriter.construct_json_name) + self.dump_tensor_data_dir = None + self.batch_size = 1000 + self.cache_data = {"data": {}} + self.cache_stack = {} + self.cache_construct = {} + + def initialize_json_file(self, **kwargs): + kwargs.update({"dump_data_dir": self.dump_tensor_data_dir, "data": {}}) + with open(self.dump_file_path, 'w') as f: + json.dump(kwargs, f) + + if os.path.exists(self.stack_file_path): + os.remove(self.stack_file_path) + Path(self.stack_file_path).touch() + + if os.path.exists(self.construct_file_path): + os.remove(self.construct_file_path) + Path(self.construct_file_path).touch() + + def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir): + self.dump_file_path = dump_file_path + self.stack_file_path = stack_file_path + self.construct_file_path = construct_file_path + self.dump_tensor_data_dir = dump_data_dir + + def update_data(self, new_data): + self.cache_data["data"].update(new_data) + if len(self.cache_data["data"]) >= self.batch_size: + self.write_data_json(self.dump_file_path) + + def update_stack(self, new_data): + self.cache_stack.update(new_data) + + def update_construct(self, new_data): + self.cache_construct.update(new_data) + + def write_data_json(self, file_path): + import fcntl + 
print_info_log_rank_0(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") + if Path(file_path).exists() and os.path.getsize(file_path) > 0: + with open(file_path, "r+") as f: + fcntl.flock(f, fcntl.LOCK_EX) + data_to_write = json.load(f) + fcntl.flock(f, fcntl.LOCK_UN) + else: + self.init_json['data_path'] = self.dump_tensor_data_dir + data_to_write = self.init_json + data_to_write['data'].update(self.cache_data['data']) + with open(file_path, 'w+') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(data_to_write, f, indent=1) + fcntl.flock(f, fcntl.LOCK_UN) + + self.cache_data["data"].clear() + + def write_stack_info_json(self, file_path): + import fcntl + with open(file_path, 'w+') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(self.cache_stack, f, indent=1) + fcntl.flock(f, fcntl.LOCK_UN) + + def write_construct_info_json(self, file_path): + import fcntl + with open(file_path, 'w+') as f: + fcntl.flock(f, fcntl.LOCK_EX) + json.dump(self.cache_construct, f, indent=1) + fcntl.flock(f, fcntl.LOCK_UN) + + def write_json(self): + self.write_data_json(self.dump_file_path) + self.write_stack_info_json(self.stack_file_path) + self.write_construct_info_json(self.construct_file_path) \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/functional/repair.py b/debug/accuracy_tools/atat/pytorch/functional/repair.py new file mode 100644 index 0000000000000000000000000000000000000000..3469db9da74de2e0fc8145631eb69e2d64d01558 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/functional/repair.py @@ -0,0 +1,90 @@ +from abc import ABC, abstractmethod + +import torch + +from .scope import build_scope, ListScope, BaseScope +from ..common.exceptions import RepairException +from ..common import recursive_apply_transform, print_info_log_rank_0 + + +def build_repair(config): + if config.repair_type is None: + return None + elif config.repair_type == RepairAPI.ToCPU: + return RepairAPI_toCPU(config) + elif config.repair_type == RepairAPI.RaisePrecision: + return RepairAPI_raise(config) + else: + raise RepairException(RepairException.InvalidRepairType, f"精度修复类型" + f"须配置为'{RepairAPI.ToCPU}'或'{RepairAPI.RaisePrecision}," + f"实际配置为{config.repair_type}") + + +class RepairAPI(ABC): + ToCPU = "cpu" + RaisePrecision = "raise" + + def __init__(self, config): + self.config = config + self.scope = build_scope(ListScope, config.repair_scope, config.repair_api_str) + self.saved, self.towards = "None", "None" + + def check_name_and_module_type(self, name, module_type): + if module_type == BaseScope.Module_Type_Module: + return False + if not self.scope.check(name): + return False + return True + + def convert(self, name, module_type, args, kwargs): + is_target = self.check_name_and_module_type(name, module_type) + if is_target: + args = recursive_apply_transform(args, self.fx) + kwargs = recursive_apply_transform(kwargs, self.fx) + print_info_log_rank_0(f"[calibrator] convert inputs of {name} to " + f"{self.towards}.") + return args, kwargs + + def invert(self, name, module_type, out_feat): + is_target = self.check_name_and_module_type(name, module_type) + if is_target: + out_feat = recursive_apply_transform(out_feat, self.inv_fx) + print_info_log_rank_0(f"[calibrator] convert outputs of {name} back to "\ + f"{self.saved}.") + return out_feat + + +class RepairAPI_toCPU(RepairAPI): + def fx(self, arg, _): + if isinstance(arg, torch.Tensor): + self.saved = arg.device + self.towards = torch.device("cpu") + return arg.cpu() + return arg + + def inv_fx(self, arg, _): + if 
isinstance(arg, torch.Tensor):
+            return arg.to(self.saved)
+        return arg
+
+
+class RepairAPI_raise(RepairAPI):
+    raise_dtype_map = {
+        torch.bfloat16: torch.float32,
+        torch.float16: torch.float32
+    }
+
+    def fx(self, arg, _):
+        if isinstance(arg, torch.Tensor):
+            self.saved = arg.dtype
+            self.towards = RepairAPI_raise.raise_dtype_map.get(self.saved)
+            # bug: nested input may be of various dtypes. which to save and invert?
+            return arg.to(self.towards)
+        return arg
+
+    def inv_fx(self, arg, _):
+        if isinstance(arg, torch.Tensor):
+            return arg.to(self.saved)
+        return arg
+
+
diff --git a/debug/accuracy_tools/atat/pytorch/functional/scope.py b/debug/accuracy_tools/atat/pytorch/functional/scope.py
new file mode 100644
index 0000000000000000000000000000000000000000..01ea607ac049cb3edbeba212d5f8c541571f1dd2
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/functional/scope.py
@@ -0,0 +1,174 @@
+from abc import ABC, abstractmethod
+from ..common.exceptions import ScopeException
+from ..common.utils import Const
+
+
+def build_scope(scope_class, scope=[], api_list=[]):
+    if not scope and not api_list:
+        return None
+    if scope_class:
+        return scope_class(scope, api_list)
+    return build_range_scope_according_to_scope_name(scope, api_list)
+
+
+def build_range_scope_according_to_scope_name(scope, api_list):
+    api_range_scope = APIRangeScope(scope, api_list)
+    module_range_scope = ModuleRangeScope(scope, api_list)
+    if not scope: # without a scope argument, either kind of scope behaves the same
+        return api_range_scope
+    if api_range_scope.is_valid and module_range_scope.is_valid:
+        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}.")
+    elif api_range_scope.is_valid:
+        return api_range_scope
+    elif module_range_scope.is_valid:
+        return module_range_scope
+    else:
+        raise ScopeException(ScopeException.InvalidScope, f"scope={scope}")
+
+
+class BaseScope(ABC):
+    Module_Type_Module = "Module"
+    Module_Type_API = "api"
+
+    @staticmethod
+    def rectify_args(scope, api_list):
+        if not isinstance(api_list, list):
+            raise ScopeException(ScopeException.InvalidApiStr,
+                                 f"api_list参数须配置为列表,实际类型为{type(api_list)}.")
+        for api in api_list:
+            if not isinstance(api, str):
+                raise ScopeException(ScopeException.InvalidApiStr,
+                                     f"api_list中的元素须配置为字符串,实际类型为{type(api)}.")
+        if isinstance(scope, str):
+            scope = [scope]
+            return scope, api_list
+        if not isinstance(scope, list):
+            raise ScopeException(ScopeException.InvalidScope,
+                                 f"scope参数须配置为字符串或列表,实际类型为{type(scope)}.")
+        for s in scope:
+            if not isinstance(s, str):
+                raise ScopeException(ScopeException.InvalidScope,
+                                     f"scope列表元素要求类型为字符串,实际类型为{type(s)}.")
+        return scope, api_list
+
+    def __init__(self, scope, api_list):
+        scope, api_list = self.rectify_args(scope, api_list)
+        self.scope = scope
+        self.api_list = api_list
+
+    def check_api_list(self, api_name):
+        if not self.api_list:
+            return True
+        for api_str in self.api_list:
+            if api_str in api_name:
+                return True
+
+    @abstractmethod
+    def check(self, name):
+        pass
+
+
+class ListScope(BaseScope):
+    @staticmethod
+    def rectify_args(scope, api_list):
+        if scope and api_list:
+            raise ScopeException(ScopeException.ArgConflict,
+                                 f"scope和api_list不可以同时配置,实际配置为scope={scope}, api_list={api_list}.")
+        return super(ListScope, ListScope).rectify_args(scope, api_list)
+
+    def check(self, module_name):
+        if not self.scope or module_name in self.scope:
+            return self.check_api_list(module_name)
+        return False
+
+
+class RangeScope(BaseScope, ABC):
+    @staticmethod
+    def rectify_args(scope, api_list):
+        scope, api_list = super(RangeScope, RangeScope).rectify_args(scope, api_list)
+        if isinstance(scope, list):
+            if len(scope) == 1:
+                scope.append(scope[0])
+            elif len(scope) > 2:
+                raise ScopeException(ScopeException.InvalidScope,
+                                     f"scope参数指定区间断点,须传入长度为1或2的列表,实际长度为{len(scope)}.")
+
+        return scope, api_list
+
+    @abstractmethod
+    def check_scope_is_valid(self):
+        pass
+
+    def __init__(self, *args):
+        super().__init__(*args)
+        self.in_scope = False
+        self.is_valid = self.check_scope_is_valid()
+
+    def begin_module(self, module_name):
+        pass
+
+    def end_module(self, module_name):
+        pass
+
+
+class APIRangeScope(RangeScope):
+    def check_scope_is_valid(self):
+        if not self.scope:
+            return True
+        scope_start_type = self.scope[0].split(Const.SEP)[0]
+        if scope_start_type == BaseScope.Module_Type_Module:
+            return False
+        scope_stop_type = self.scope[1].split(Const.SEP)[0]
+        if scope_stop_type == BaseScope.Module_Type_Module:
+            return False
+        return True
+
+    def check(self, api_name):
+        if self.scope and api_name == self.scope[0]:
+            self.in_scope = True
+
+        if not self.scope or self.in_scope:
+            result = self.check_api_list(api_name)
+        else:
+            result = False
+
+        if self.scope and api_name == self.scope[1]:
+            self.in_scope = False
+        return result
+
+
+class ModuleRangeScope(RangeScope):
+    """
+    Unlike an api, a module still contains sub-structures that need to be dumped,
+    so pre_hook and full_backward_hook are used to mark exactly where the module begins and ends;
+    begin_module and end_module are called from those hooks to control the dump range.
+    """
+    def check_scope_is_valid(self):
+        if not self.scope:
+            return True
+        scope_start_type = self.scope[0].split(Const.SEP)[0]
+        scope_stop_type = self.scope[1].split(Const.SEP)[0]
+        if scope_start_type == BaseScope.Module_Type_Module and \
+                scope_stop_type == BaseScope.Module_Type_Module:
+            return True
+        return False
+
+    def begin_module(self, module_name):
+        if not self.scope:
+            return
+        if module_name == self.scope[0]:
+            self.in_scope = True
+
+    def end_module(self, module_name):
+        if not self.scope:
+            return
+        if module_name == self.scope[1]:
+            self.in_scope = False
+
+    def check(self, module_name):
+        if not self.scope or self.in_scope:
+            return self.check_api_list(module_name)
+        return False
+
+
+
diff --git a/debug/accuracy_tools/atat/pytorch/functional/step_post_process.py b/debug/accuracy_tools/atat/pytorch/functional/step_post_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f0d3459326f04691a0041c120bf4efc676f8bc1
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/functional/step_post_process.py
@@ -0,0 +1,43 @@
+from abc import ABC, abstractmethod
+from ..common.exceptions import StepException
+
+
+def run_parallel_ut(config):
+    pass
+
+
+def compare_distributed(bench_dump_path, dump_path):
+    pass
+
+
+def build_step_post_process(config):
+    if not config.on_step_end:
+        return None
+    if config.on_step_end == StepPostProcess.SingleAPICheck:
+        return SingleAPICheck(config)
+    elif config.on_step_end == StepPostProcess.Compare:
+        return AutoCompare(config)
+    else:
+        raise StepException(StepException.InvalidPostProcess, f"step后处理须配置为"
+                            f"'{StepPostProcess.SingleAPICheck}'或'{StepPostProcess.Compare}',"
+                            f"实际配置为{config.on_step_end}")
+
+
+class StepPostProcess(ABC):
+    SingleAPICheck = 'single_api_check'
+    Compare = 'compare'
+
+
+class SingleAPICheck:
+    def __init__(self, config):
+        self.config = config
+
+    def run(self):
+        run_parallel_ut(self.config)
+
+class AutoCompare:
+    def __init__(self, config):
+        self.config = config
+
+    def run(self):
+        compare_distributed(self.config.bench_dump_path, self.config.dump_path)
diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/__init__.py 
b/debug/accuracy_tools/atat/pytorch/hook_module/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..4e7a5ca15e8d08d0bb886866bf413712796c9edd 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/__init__.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/__init__.py @@ -0,0 +1 @@ +from .wrap_functional import remove_dropout \ No newline at end of file diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py index cf21fe86bb541e64101dbdd360739a136f898d71..003a8699cd750a424bf989ae9d1b3fac78f76650 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/api_registry.py @@ -24,15 +24,11 @@ from .wrap_tensor import get_tensor_ops from .wrap_vf import get_vf_ops from .wrap_distributed import get_distributed_ops from .wrap_aten import get_aten_ops -from ..common.utils import torch_without_guard_version, npu_distributed_api +from ..common.utils import torch_without_guard_version, npu_distributed_api, is_gpu torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' -try: +if not is_gpu: import torch_npu -except ImportError: - is_gpu = True -else: - is_gpu = False from . import wrap_npu_custom from .wrap_npu_custom import get_npu_ops diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py index 6f23a8d42b48aca5e8a839bc8112b683995b595e..eb35de84b2da72a92532bc62c612bac1c29097f6 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/hook_module.py @@ -20,17 +20,15 @@ import threading import torch import torch.nn as nn import torch.utils.hooks as full_hooks - +from ..common.utils import Const class HOOKModule(nn.Module): module_count = {} inner_stop_hook = {} - def __init__(self, hook) -> None: + def __init__(self, build_hook) -> None: super(HOOKModule, self).__init__() self.has_overflow = False - self.input_args = tuple() - self.input_kwargs = dict() self.prefix = "" self.current_thread = threading.current_thread().ident if self.current_thread not in HOOKModule.inner_stop_hook: @@ -43,12 +41,14 @@ class HOOKModule(nn.Module): if self.prefix not in HOOKModule.module_count: HOOKModule.module_count[self.prefix] = 1 - self.prefix += '0_' + self.prefix += '0' + Const.SEP else: HOOKModule.module_count[self.prefix] += 1 - self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + '_' - self.register_forward_hook(hook(self.prefix + "forward")) - self.register_backward_hook(hook(self.prefix + "backward")) + self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + Const.SEP + forward_pre_hook, forward_hook, backward_hook = build_hook(self.prefix) + self.register_forward_pre_hook(forward_pre_hook, with_kwargs=True) + self.register_forward_hook(forward_hook, with_kwargs=True) + self.register_backward_hook(backward_hook) def __call__(self, *input, **kwargs): changed = False @@ -65,17 +65,17 @@ class HOOKModule(nn.Module): if len(self._backward_hooks) > 0: full_backward_hooks, non_full_backward_hooks = self._get_backward_hooks() for hook in self._forward_pre_hooks.values(): - result = hook(self, input) - if result is not None: - if not isinstance(result, tuple): - result = (result,) - input = result + result_input, result_kwargs = hook(self, input, kwargs) + if result_input is not None: + if not isinstance(result_input, tuple): + result_input = (result_input,) 
+ input = result_input + if result_kwargs is not None: + kwargs = result_kwargs bw_hook = None if len(full_backward_hooks) > 0: bw_hook = full_hooks.BackwardHook(self, full_backward_hooks) input = bw_hook.setup_input_hook(input) - self.input_args = input - self.input_kwargs = kwargs if torch._C._get_tracing_state(): result = self._slow_forward(*input, **kwargs) else: @@ -83,7 +83,7 @@ class HOOKModule(nn.Module): input_list = list(input) input_list.extend(kwargs.values()) for hook in self._forward_hooks.values(): - hook_result = hook(self, input_list, result) + hook_result = hook(self, input, kwargs, result) if hook_result is not None: result = hook_result if bw_hook: diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/register_hook.py b/debug/accuracy_tools/atat/pytorch/hook_module/register_hook.py deleted file mode 100644 index 7715bda67181adc372e216198e58172a6c94e4f8..0000000000000000000000000000000000000000 --- a/debug/accuracy_tools/atat/pytorch/hook_module/register_hook.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 -# -*- coding: utf-8 -*- -""" -# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -""" - -import functools -import os - -from inspect import isfunction -import torch -import torch.distributed as dist - -from atat.core.utils import check_file_or_directory_path, print_error_log, CompareException, Const, \ - print_info_log, print_warn_log, get_process_rank -from . import wrap_torch, wrap_functional, wrap_tensor, wrap_vf, wrap_distributed, wrap_aten -from .hook_module import HOOKModule -from .api_registry import api_register -from .wrap_functional import remove_dropout -from ..common.utils import torch_without_guard_version -from ..dump.utils import make_dump_dirs, DumpUtil -from ..overflow_check.utils import OverFlowUtil, clear_overflow_npu - -torch_version_above_2 = torch.__version__.split('+')[0] > '2.0' - -try: - import torch_npu -except ImportError: - is_gpu = True -else: - is_gpu = False - from . 
import wrap_npu_custom - -make_dir_flag = True -REGISTER_HOOK_KWARGS = ["overflow_nums", "dump_mode", "dump_config"] - - -def add_clear_overflow(func, pid): - first_module = True - - def clear_overflow_wrapper(*args, **kwargs): - child_pid = os.getpid() - if pid != child_pid: - return func(*args, **kwargs) - nonlocal first_module - if first_module: - clear_overflow_npu() - first_module = False - return func(*args, **kwargs) - - return clear_overflow_wrapper - - -def register_hook(model, hook, **kwargs): - check_register_hook(hook, **kwargs) - print_info_log("Please disable dataloader shuffle before running the program.") - overflow_nums = kwargs.get('overflow_nums', 1) - init_overflow_nums(overflow_nums) - dump_mode, dump_config_file = init_dump_config(kwargs) - if dump_mode == 'acl': - DumpUtil.dump_switch_mode = dump_mode - DumpUtil.set_acl_config(dump_config_file) - register_hook_core(hook) - - -def init_overflow_nums(overflow_nums): - if isinstance(overflow_nums, int) and overflow_nums > 0 or overflow_nums == -1: - OverFlowUtil.overflow_nums = overflow_nums - else: - print_error_log("overflow_nums must be an integer greater than 0 or set -1.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - - -def check_register_hook(hook, **kwargs): - if not isfunction(hook) or hook.__name__ not in ["overflow_check", "acc_cmp_dump"]: - print_error_log("hook function must be set overflow_check or acc_cmp_dump") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - for item in kwargs.keys(): - if item not in REGISTER_HOOK_KWARGS: - print_error_log(f"{item} not a valid keyword arguments in register_hook.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - - -def register_hook_core(hook, model=None): - global make_dir_flag - - pid = os.getpid() - need_clear = True - if make_dir_flag: - make_dump_dirs() - make_dir_flag = False - hook_name = hook.__name__ - - if "overflow_check" in hook_name and model is not None: - print_error_log("Overflow check does not support model dump mode") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if "overflow_check" in hook_name and not is_gpu: - if hasattr(torch_npu._C, "_enable_overflow_npu"): - torch_npu._C._enable_overflow_npu() - print_info_log("Enable overflow function success.") - else: - print_warn_log("Api '_enable_overflow_npu' is not exist, " - "the overflow detection function on milan platform maybe not work! 
" - "please check the version of software torch_npu.") - # In NPU scene, clear the overflow flag before overflow detection - if need_clear: - HOOKModule.__init__ = add_clear_overflow(HOOKModule.__init__, pid) - - print_info_log("Start mounting the {} hook function to the model.".format(hook_name)) - hook = functools.partial(hook, dump_step=0, pid=pid) - print_info_log("The {} hook function is successfully mounted to the model.".format(hook_name)) - - if model is not None: - print_info_log("The init dump mode is enabled, and the module dump function will not be available") - if not isinstance(model, torch.nn.Module): - print_error_log("The argument model must be an object of torch.nn.Module") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - for name, module in model.named_modules(): - if module == model: - continue - prefix = name + "_" + module.__class__.__name__ - module.register_forward_hook(hook(prefix + "_{}_" + "forward")) - module.register_backward_hook(hook(prefix + "_{}_" + "backward")) - else: - api_register.initialize_hook(hook) - api_register.api_modularity() - - if "acc_cmp_dump" in hook_name: - remove_dropout() - - -def init_dump_config(kwargs): - dump_mode = kwargs.get('dump_mode', "api") - dump_config = kwargs.get('dump_config') - dump_config_file = '' - if dump_mode not in Const.SUPPORT_DUMP_MODE: - print_error_log("dump_mode only support %s" % Const.SUPPORT_DUMP_MODE) - raise CompareException(CompareException.INVALID_PARAM_ERROR) - if dump_mode == "acl": - if dump_config is None: - print_error_log("dump_mode is acl mode, dump_config must be configured.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - dump_config_file = os.path.realpath(dump_config) - check_file_or_directory_path(dump_config_file) - if not dump_config.endswith(".json"): - print_error_log("dump_config must be configure json file.") - raise CompareException(CompareException.INVALID_PARAM_ERROR) - return dump_mode, dump_config_file diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py index 9b6c694bec91e15826f8d7841d5ca7b1e19ef50d..8666287095bbe12f7e9d5f314cff1db75d74a108 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_aten.py @@ -20,9 +20,9 @@ import torch import yaml -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule -from ..common.utils import torch_device_guard +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) @@ -56,7 +56,7 @@ class AtenOPTemplate(HOOKModule): if not '.' + overload_name in op_name_: op_name_ = op_name_ + '.' 
+ overload_name self.op = op - self.prefix_op_name_ = "Aten_" + str(op_name_) + "_" + self.prefix_op_name_ = "Aten" + Const.SEP + str(op_name_) + Const.SEP super().__init__(hook) @torch_device_guard diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py index 4579caaf19ba5c9fc2c8f3809d0dcb8302c3bf0a..64ce06c33e8fe45966b900bcb9748d798e1b6e84 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_distributed.py @@ -20,9 +20,9 @@ from functools import wraps import torch.distributed as dist import yaml -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) @@ -49,7 +49,7 @@ class HOOKDistributedOP(object): class DistributedOPTemplate(HOOKModule): def __init__(self, op_name, hook): self.op_name_ = op_name - self.prefix_op_name_ = "Distributed_" + str(op_name) + "_" + self.prefix_op_name_ = "Distributed" + Const.SEP + str(op_name) + Const.SEP super().__init__(hook) if self.op_name_ in Const.INPLACE_LIST: self.register_forward_pre_hook(hook(self.prefix + Const.PRE_FORWARD)) diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py index 23d90789e739b2332b34841d833a127fb998401d..46f25efe664fca2bff917b93e3e0632398bdc74e 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_functional.py @@ -20,15 +20,15 @@ import os import torch import yaml -from atat.core.utils import print_info_log -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule -from ..common.utils import torch_device_guard +from ..common.utils import torch_device_guard, Const +from ..common.log import print_info_log_rank_0 +from ..common.file_check import FileOpen def remove_dropout(): if torch.__version__ > "1.8": - print_info_log("For precision comparison, the probability p in the dropout method is set to 0.") + print_info_log_rank_0("For precision comparison, the probability p in the dropout method is set to 0.") import torch.nn.functional as F from torch import _VF from torch.overrides import has_torch_function_unary, handle_torch_function @@ -85,7 +85,7 @@ class HOOKFunctionalOP(object): class FunctionalOPTemplate(HOOKModule): def __init__(self, op_name, hook): self.op_name_ = op_name - self.prefix_op_name_ = "Functional_" + str(op_name) + "_" + self.prefix_op_name_ = "Functional" + Const.SEP + str(op_name) + Const.SEP super().__init__(hook) @torch_device_guard diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py index ea02096b1e11fadc12527f42455979aaa23c17d7..e910e609c8379e0c66239755c3ec2a44953ef1ec 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_npu_custom.py @@ -20,9 +20,9 @@ import torch import torch_npu import yaml -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule -from ..common.utils import torch_device_guard, torch_without_guard_version +from ..common.utils import torch_device_guard, torch_without_guard_version, Const +from ..common.file_check import FileOpen cur_path = 
os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") @@ -47,7 +47,7 @@ class NpuOPTemplate(HOOKModule): def __init__(self, op_name, hook): self.op_name_ = op_name - self.prefix_op_name_ = "NPU_" + str(op_name) + "_" + self.prefix_op_name_ = "NPU" + Const.SEP + str(op_name) + Const.SEP super().__init__(hook) @torch_device_guard diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py index bd62b593045d3614242a34a532a1cd76a559a57e..6b49826ab4712d440b4933651eb6b7eab950d023 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_tensor.py @@ -20,9 +20,9 @@ import os import torch import yaml -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule -from ..common.utils import torch_device_guard, parameter_adapter +from ..common.utils import torch_device_guard, parameter_adapter, Const +from ..common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") @@ -47,7 +47,7 @@ class TensorOPTemplate(HOOKModule): def __init__(self, op_name, hook): self.op_name_ = op_name - self.prefix_op_name_ = "Tensor_" + str(op_name) + "_" + self.prefix_op_name_ = "Tensor" + Const.SEP + str(op_name) + Const.SEP super().__init__(hook) @torch_device_guard diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py index 9fd71e90b88fc66f00b5e2d7d3a2934c7d66e373..889512e9c0c64d9d05dc19cbc30e542c6e5b577c 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_torch.py @@ -20,9 +20,9 @@ import os import torch import yaml -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule -from ..common.utils import torch_device_guard +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") @@ -64,7 +64,7 @@ class TorchOPTemplate(HOOKModule): def __init__(self, op_name, hook): self.op_name_ = op_name - self.prefix_op_name_ = "Torch_" + str(op_name) + "_" + self.prefix_op_name_ = "Torch" + Const.SEP + str(op_name) + Const.SEP super().__init__(hook) @torch_device_guard diff --git a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py index a4da7f17406240e5ca616576702c207536a7dd39..08d47308e077981e65193eea71874d4f9432c6c0 100644 --- a/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py +++ b/debug/accuracy_tools/atat/pytorch/hook_module/wrap_vf.py @@ -20,9 +20,9 @@ import os import torch import yaml -from atat.core.file_check_util import FileOpen from .hook_module import HOOKModule -from ..common.utils import torch_device_guard +from ..common.utils import torch_device_guard, Const +from ..common.file_check import FileOpen cur_path = os.path.dirname(os.path.realpath(__file__)) yaml_path = os.path.join(cur_path, "support_wrap_ops.yaml") @@ -44,7 +44,7 @@ class HOOKVfOP(object): class VfOPTemplate(HOOKModule): def __init__(self, op_name, hook): self.op_name_ = op_name - self.prefix_op_name_ = "VF_" + str(op_name) + "_" + self.prefix_op_name_ = "VF" + Const.SEP + str(op_name) + Const.SEP super().__init__(hook) @torch_device_guard diff 
--git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py new file mode 100644 index 0000000000000000000000000000000000000000..f7387fa3b08226f2c1d340719c09c9259b9314f0 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -0,0 +1,76 @@ +from functools import wraps +import torch +from torch.utils.hooks import BackwardHook +from .functional.scope import ModuleRangeScope +from .common.utils import Const + + +class ModuleProcesser: + module_stack = [] + api_parent_node = "" + module_node = {} + current_module_name = "" + + def __init__(self, scope): + if isinstance(scope, ModuleRangeScope): + self.scope = scope + else: + self.scope = None + BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) + BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + self.module_count = {} + + @staticmethod + def clone_return_value(func): + @wraps(func) + def clone_return_value_func(*args, **kwargs): + result = func(*args, **kwargs) + if isinstance(result, torch.Tensor): + result = result.clone() + elif isinstance(result, tuple): + result = tuple(r.clone() for r in result) + return result + + return clone_return_value_func + + def node_hook(self, name_prefix, start_or_stop, **kwargs): + + def pre_hook(module, input, output=None): + try: + index = self.module_count_func(name_prefix) + except IndexError as e: + index = None + pass + module.mindstudio_reserved_name = full_name = name_prefix + Const.SEP + str(index) + if self.module_stack: + ModuleProcesser.module_node[full_name] = self.module_stack[-1] + else: + ModuleProcesser.module_node[full_name] = None + + ModuleProcesser.module_stack.append(full_name) + if self.module_stack: + ModuleProcesser.api_parent_node = self.module_stack[-1] + if self.scope: + self.scope.begin_module(full_name) + + def end_hook(module, input, output=None): + if self.module_stack: + ModuleProcesser.module_stack.pop() + if self.module_stack: + ModuleProcesser.api_parent_node = self.module_stack[-1] + else: + ModuleProcesser.api_parent_node = None + if self.scope: + self.scope.end_module(module.mindstudio_reserved_name) + + if "start" in start_or_stop: + return pre_hook + else: + return end_hook + + def module_count_func(self, module_name): + if module_name not in self.module_count: + self.module_count[module_name] = 0 + else: + self.module_count[module_name] += 1 + return self.module_count[module_name] diff --git a/debug/accuracy_tools/atat/pytorch/service.py b/debug/accuracy_tools/atat/pytorch/service.py new file mode 100644 index 0000000000000000000000000000000000000000..d5bacb27aa332b4e31adb03fa9646ce5dbdba2f2 --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/service.py @@ -0,0 +1,161 @@ +import os +from pathlib import Path +import functools +import torch +from .functional import build_repair, build_collect_data, build_step_post_process +from .functional.scope import BaseScope +from .common.utils import get_rank_if_initialized, is_gpu, Const +from .common.file_check import FileChecker, FileCheckConst, check_path_before_create +from .common import print_info_log_rank_0 +from .common.exceptions import MsaccException +from .hook_module.api_registry import api_register +from .hook_module import remove_dropout +from .functional.data_processor import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs +from .module_processer import ModuleProcesser + + +class Service: + make_dir_flag = True + 
REGISTER_HOOK_KWARGS = ["overflow_nums", "dump_mode", "dump_config"] + + def __init__(self, config): + self.model = None + self.config = config + self.collect_data = build_collect_data(config) + self.module_processor = ModuleProcesser(self.collect_data.scope) + self.repair = build_repair(config) + self.step_post_process = build_step_post_process(config) + self.switch = False + self.current_iter = 0 + self.first_start = True + self.current_rank = None + self.first_touch_dir = True + + def build_hook(self, module_type, name): + def pre_hook(repair, name_template, module, args, kwargs): + if repair: + args, kwargs = repair.convert(name_template, module_type, args, kwargs) + return args, kwargs + + def forward_hook(repair, name_template, module, args, kwargs, output): + nonlocal module_type, pid + if not self.switch: + return + if self.collect_data: + module_input_output = ModuleForwardInputsOutputs(args=args, kwargs=kwargs, output=output) + self.collect_data(name_template, module_type, module, pid, module_input_output) + if repair: + output = repair.invert(name_template, module_type, output) + + return output + + def backward_hook(repair, name_template, module, grad_input, grad_output): + nonlocal module_type, pid + if not self.switch: + return + if self.collect_data: + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) + self.collect_data(name_template, module_type, module, pid, module_input_output) + + pid = os.getpid() + if module_type == BaseScope.Module_Type_Module: + forward_name_template = name + Const.SEP + "{}" + Const.SEP + "forward" + backward_name_template = name + Const.SEP + "{}" + Const.SEP + "backward" + else: + forward_name_template = name + "forward" + backward_name_template = name + "backward" + pre_forward_hook = functools.partial(pre_hook, self.repair, forward_name_template) + forward_hook = functools.partial(forward_hook, self.repair, forward_name_template) + backward_hook = functools.partial(backward_hook, None, backward_name_template) + return pre_forward_hook, forward_hook, backward_hook + + def step(self): + self.current_iter += 1 + if self.step_post_process: + self.step_post_process() + + @staticmethod + def check_model_valid(model): + if isinstance(model, torch.nn.Module): + return model + raise MsaccException(MsaccException.INVALID_PARAM_ERROR, "model 参数必须是torch.nn.Module类型。") + + def start(self, model): + if self.config.step and self.current_iter > max(self.config.step): + self.stop() + raise Exception("atat: exit after iteration {}".format(max(self.config.step))) + if self.config.step and self.current_iter not in self.config.step: + return + self.model = self.check_model_valid(model) + if self.first_start: + self.current_rank = get_rank_if_initialized() + if self.config.rank and self.current_rank not in self.config.rank: + return + self.register_hook_new() + self.first_start = False + self.switch = True + self.create_dirs() + print_info_log_rank_0(f"Dump switch is turned on at step {self.current_iter}. 
" + f"Dump data will be saved in {self.dump_iter_dir}.") + + def stop(self): + self.switch = False + self.collect_data.write_json() + + + def create_dirs(self): + check_path_before_create(self.config.dump_path) + if not os.path.exists(self.config.dump_path): + Path(self.config.dump_path).mkdir(mode=0o750, exist_ok=True) + file_check = FileChecker(self.config.dump_path, FileCheckConst.DIR) + file_check.common_check() + self.dump_iter_dir = os.path.join(self.config.dump_path, f"step{self.current_iter}") + dump_dir = os.path.join(self.dump_iter_dir, f"rank{self.current_rank}") + if not os.path.exists(dump_dir): + Path(dump_dir).mkdir(mode=0o750, parents=True, exist_ok=True) + if self.config.task in self.collect_data.tasks_need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + Path(dump_data_dir).mkdir(mode=0o750, exist_ok=True) + else: + dump_data_dir = None + + dump_file_path = os.path.join(dump_dir, "dump.json") + stack_file_path = os.path.join(dump_dir, "stack.json") + construct_file_path = os.path.join(dump_dir, "construct.json") + self.collect_data.update_dump_paths(dump_file_path, stack_file_path, construct_file_path, dump_data_dir) + + def register_hook_new(self): + hook_name = self.config.task + + if "overflow_check" in hook_name and not is_gpu: + pass + + print_info_log_rank_0("The {} hook function is successfully mounted to the model.".format(hook_name)) + if self.config.level in ["L0", "mix"]: + assert self.model is not None + print_info_log_rank_0("The init dump mode is enabled, and the module dump function will not be available") + if not isinstance(self.model, torch.nn.Module): + raise MsaccException(MsaccException.INVALID_PARAM_ERROR, + "The argument model must be an object of torch.nn.Module") + for name, module in self.model.named_modules(): + if module == self.model: + continue + prefix = BaseScope.Module_Type_Module + Const.SEP + name + Const.SEP +\ + module.__class__.__name__ + Const.SEP + + pre_forward_hook, forward_hook, backward_hook = self.build_hook(BaseScope.Module_Type_Module, prefix) + module.register_forward_hook(forward_hook, with_kwargs=True) + module.register_full_backward_hook(backward_hook) + + module.register_forward_pre_hook(self.module_processor.node_hook(prefix + "forward", "start")) + module.register_forward_hook(self.module_processor.node_hook(prefix + "forward", "stop")) + module.register_full_backward_pre_hook(self.module_processor.node_hook(prefix + "backward", "start")) + module.register_full_backward_hook(self.module_processor.node_hook(prefix + "backward", "stop")) + + if self.config.level in ["mix", "L1"]: + api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API)) + api_register.api_modularity() + + if "acc_cmp_dump" in hook_name: + remove_dropout() +