diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 3f3775cde4d31994d03103ed85ac84baa341296a..ef5ccb61d0e6a4bff8e25824662c4f6ffcfd4df2 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -99,6 +99,10 @@ class Const: ASCEND_WORK_PATH = "ASCEND_WORK_PATH" DUMP_DIR = "dump_data" + OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" + ENV_ENABLE = "1" + ENV_DISABLE = "0" + class CompareConst: """ Class for compare module const diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py index 78de5b5e6b448b597e798310a1bd7d25d26484be..d551dfa71ea39325f0470914370cacfebd5436bb 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py @@ -27,7 +27,7 @@ from .wrap_functional import remove_dropout from ..common.utils import check_file_or_directory_path, print_error_log, CompareException, Const, \ print_info_log, print_warn_log, get_process_rank, torch_without_guard_version from ..dump.utils import make_dump_dirs, DumpUtil -from ..overflow_check.utils import OverFlowUtil +from ..overflow_check.utils import OverFlowUtil, clear_overflow_npu try: import torch_npu @@ -85,7 +85,7 @@ def add_clear_overflow(func, pid): return func(*args, **kwargs) nonlocal first_module if first_module: - torch_npu._C._clear_overflow_npu() + clear_overflow_npu() first_module = False return func(*args, **kwargs) return clear_overflow_wrapper diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py index 561b118e8541bafd4ffb171eb7329c94401636c6..14eac850907d88ea80fe8c5034417abaf39682bb 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py @@ -3,7 +3,7 @@ import torch from pathlib import Path from ..common.utils import print_warn_log, get_time, print_info_log from ..dump.dump import forward_init_status, forward_acl_dump -from .utils import OverFlowUtil, dump_overflow +from .utils import OverFlowUtil, dump_overflow, check_overflow_npu, clear_overflow_npu from ..dump.utils import DumpUtil, Const, get_tensor_rank, create_dirs_if_not_exist from .info_dump import write_api_info_json, ForwardAPIInfo, BackwardAPIInfo from ..dump import dump @@ -118,7 +118,7 @@ def overflow_check(name, **kwargs): check_feat = out_feat module.has_overflow = check_data_overflow(check_feat) else: - module.has_overflow = torch_npu._C._check_overflow_npu() + module.has_overflow = check_overflow_npu() if not module.has_overflow: if hasattr(module, 'input_args'): del module.input_args @@ -146,7 +146,7 @@ def overflow_check(name, **kwargs): acl_dump(module, module_name) dump.write_to_disk() # clear overflow flag for the next check - torch_npu._C._clear_overflow_npu() + clear_overflow_npu() if not OverFlowUtil.check_overflow_dump_times(overflow_nums): for key in forward_api_info: write_api_info_json(forward_api_info[key]) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py index 1a26c5ac04e770e76fd7a3b01bb88b34d4bcaa20..a70a178e977e79cbacf5000284d7ef56a52d12cd 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py @@ -1,5 +1,13 @@ +import os import torch +try: + import torch_npu +except ImportError: + is_gpu = True +else: + is_gpu = False + from ..common.utils import Const, check_switch_valid from ..dump.dump import dump_stack_info, get_scalar_data_info, dump_data, \ get_not_float_tensor_info, get_float_tensor_info @@ -71,3 +79,28 @@ def _dump_tensor_completely(x, prefix, dump_file_name): if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float): data_info = get_scalar_data_info(x) dump_data(dump_file_name, dump_flag, prefix, data_info) + + +def overflow_debug_mode_enalbe(): + overflow_mode = os.getenv(Const.OVERFLOW_DEBUG_MODE_ENABLE, Const.ENV_DISABLE) + return overflow_mode == Const.ENV_ENABLE + + +def check_overflow_npu(): + if overflow_debug_mode_enalbe(): + float_status = torch.zeros(8).npu() + result = torch_npu.npu_get_float_debug_status(float_status) + if (result.cpu()[0] != 0): + return True + else: + return False + else: + return torch_npu._C._check_overflow_npu() + + +def clear_overflow_npu(): + if overflow_debug_mode_enalbe(): + float_status = torch.zeros(8).npu() + torch_npu.npu_clear_float_debug_status(float_status) + else: + torch_npu._C._clear_overflow_npu()