diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/__init__.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/__init__.py index 52fa7d5e74a7163989f2bd3f2231cf00666a489f..6c7bd08931e41e4e4cf6d0f5a0ce7959f95862f1 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/__init__.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/__init__.py @@ -27,7 +27,6 @@ from .overflow_check.utils import set_overflow_check_switch from .dump.utils import set_dump_path, set_dump_switch, set_backward_input from .hook_module.register_hook import register_hook from .common.utils import seed_all, torch_without_guard_version, print_info_log -from .common.version import __version__ from .debugger.debugger_config import DebuggerConfig from .debugger.precision_debugger import PrecisionDebugger seed_all() diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 3f3775cde4d31994d03103ed85ac84baa341296a..ecfc420fdf580e393c31c7d477a05bbde3c61f95 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -99,6 +99,9 @@ class Const: ASCEND_WORK_PATH = "ASCEND_WORK_PATH" DUMP_DIR = "dump_data" + MAX_SEED_VALUE = 2**32 - 1 + + class CompareConst: """ Class for compare module const @@ -251,7 +254,7 @@ def print_warn_log(warn_msg): def check_mode_valid(mode, scope=[], api_list=[]): if not isinstance(scope, list): raise ValueError("scope param set invalid, it's must be a list.") - elif not isinstance(api_list, list): + if not isinstance(api_list, list): raise ValueError("api_list param set invalid, it's must be a list.") mode_check = { Const.ALL: lambda: None, @@ -259,7 +262,7 @@ def check_mode_valid(mode, scope=[], api_list=[]): Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None, Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None, Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None, - Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if not isinstance(api_list, list) or len(api_list) < 1 else None, + Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(api_list) < 1 else None, Const.API_STACK: lambda: None, } if mode not in Const.DUMP_MODE: @@ -270,9 +273,11 @@ def check_mode_valid(mode, scope=[], api_list=[]): if mode_check[mode]() is not None: raise mode_check[mode]() + def check_switch_valid(switch): if switch not in ["ON", "OFF"]: - raise ValueError("Please set switch with 'ON' or 'OFF'.") + print_error_log("Please set switch with 'ON' or 'OFF'.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) def check_dump_mode_valid(dump_mode): if not isinstance(dump_mode, list): @@ -290,7 +295,7 @@ def check_dump_mode_valid(dump_mode): def check_summary_only_valid(summary_only): if not isinstance(summary_only, bool): - print_error_log("Params auto_analyze only support True or False.") + print_error_log("Params summary_only only support True or False.") raise CompareException(CompareException.INVALID_PARAM_ERROR) return summary_only @@ -545,6 +550,7 @@ def torch_device_guard(func): def seed_all(seed=1234, mode=False): + check_seed_all(seed, mode) random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) @@ -561,6 +567,19 @@ def seed_all(seed=1234, mode=False): torch_npu.npu.manual_seed(seed) +def check_seed_all(seed, mode): + if isinstance(seed, int): + if seed < 0 or seed > Const.MAX_SEED_VALUE: + print_error_log(f"Seed must be between 0 and {Const.MAX_SEED_VALUE}.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + else: + print_error_log(f"Seed must be integer.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + if not isinstance(mode, bool): + print_error_log(f"seed_all mode must be bool.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + + def get_process_rank(model): print_info_log("Rank id is not provided. Trying to get the rank id of the model.") try: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index 3dadd540dc4dcc47aec7138adc78c8c56991889c..b55f5b27157a3ed9d66a86ea3bd2dbdb9956b266 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -1,12 +1,13 @@ import os import torch -from ..common.utils import Const, check_switch_valid, generate_compare_script, check_is_npu +from ..common.utils import Const, check_switch_valid, generate_compare_script, check_is_npu, print_error_log, \ + CompareException from ..dump.dump import DumpUtil, acc_cmp_dump, write_to_disk, get_pkl_file_path from ..dump.utils import set_dump_path, set_dump_switch_print_info, generate_dump_path_str, \ set_dump_switch_config, set_backward_input from ..overflow_check.utils import OverFlowUtil from ..overflow_check.overflow_check import overflow_check -from ..hook_module.register_hook import register_hook_core +from ..hook_module.register_hook import register_hook_core, init_overflow_nums from ..hook_module.hook_module import HOOKModule from .debugger_config import DebuggerConfig @@ -28,6 +29,9 @@ class PrecisionDebugger: DumpUtil.target_rank = self.config.rank set_dump_path(self.config.dump_path) PrecisionDebugger.hook_func = overflow_check if self.config.hook_name == "overflow_check" else acc_cmp_dump + if not isinstance(enable_dataloader, bool): + print_error_log("Params auto_analyze only support True or False.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) if enable_dataloader: DumpUtil.iter_num -= 1 torch.utils.data.dataloader._BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) @@ -57,10 +61,7 @@ class PrecisionDebugger: DumpUtil.dump_config = acl_config if acl_config is None: raise ValueError("acl_config must be configured when mode is 'acl'") - if isinstance(overflow_nums, int) and overflow_nums >= -1: - OverFlowUtil.overflow_nums = overflow_nums - else: - raise ValueError("overflow_nums must be int") + init_overflow_nums(overflow_nums) check_switch_valid(filter_switch) OverFlowUtil.overflow_filter_switch = filter_switch diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index 0657c7b7cc383e21927a5534fd2675a125b16e47..41d8502ab662a063f8eee921b3aa54eff7b972d0 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -1,4 +1,5 @@ import os +import re import shutil import sys from pathlib import Path @@ -132,6 +133,9 @@ class DumpUtil(object): def set_dump_path(fpath=None, dump_tag='ptdbg_dump'): fpath = load_env_dump_path(fpath) check_file_valid(fpath) + if not re.match(Const.FILE_PATTERN, dump_tag): + print_error_log('The file path {} contains special characters.'.format(dump_tag)) + raise CompareException(CompareException.INVALID_PATH_ERROR) real_path = os.path.realpath(fpath) make_dump_path_if_not_exists(real_path) DumpUtil.set_dump_path(real_path) @@ -184,11 +188,7 @@ def generate_dump_path_str(): def set_dump_switch(switch, mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.ON, dump_mode=[Const.ALL], summary_only=False): - try: - check_switch_valid(switch) - except (CompareException, AssertionError) as err: - print_error_log(str(err)) - sys.exit() + check_switch_valid(switch) if not DumpUtil.dump_path: set_dump_path() DumpUtil.set_dump_switch(switch, summary_only=summary_only) @@ -209,7 +209,7 @@ def set_dump_switch_config(mode=Const.ALL, scope=[], api_list=[], filter_switch= summary_only = check_summary_only_valid(summary_only) except (CompareException, AssertionError) as err: print_error_log(str(err)) - sys.exit() + raise CompareException(CompareException.INVALID_PARAM_ERROR) switch = DumpUtil.dump_switch DumpUtil.set_dump_switch("OFF", mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode, summary_only=summary_only) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py index 78de5b5e6b448b597e798310a1bd7d25d26484be..88299c288440183609c1b2bb9ebe91c159240684 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/register_hook.py @@ -18,6 +18,7 @@ import functools import os +from inspect import isfunction import torch import torch.distributed as dist @@ -38,6 +39,7 @@ else: from . import wrap_npu_custom make_dir_flag = True +REGISTER_HOOK_KWARGS = ["overflow_nums", "dump_mode", "dump_config"] def initialize_hook(hook): @@ -70,7 +72,7 @@ def initialize_hook(hook): for attr_name in dir(wrap_vf.HOOKVfOP): if attr_name.startswith("wrap_"): setattr(torch._VF, attr_name[5:], getattr(wrap_vf.HOOKVfOP, attr_name)) - + if not is_gpu: wrap_npu_custom.wrap_npu_ops_and_bind(hook) for attr_name in dir(wrap_npu_custom.HOOKNpuOP): @@ -92,8 +94,10 @@ def add_clear_overflow(func, pid): def register_hook(model, hook, **kwargs): + check_register_hook(hook, **kwargs) print_info_log("Please disable dataloader shuffle before running the program.") - OverFlowUtil.overflow_nums = kwargs.get('overflow_nums', 1) + overflow_nums = kwargs.get('overflow_nums', 1) + init_overflow_nums(overflow_nums) dump_mode, dump_config_file = init_dump_config(kwargs) if dump_mode == 'acl': DumpUtil.dump_switch_mode = dump_mode @@ -101,6 +105,24 @@ def register_hook(model, hook, **kwargs): register_hook_core(hook, **kwargs) +def init_overflow_nums(overflow_nums): + if isinstance(overflow_nums, int) and overflow_nums > 0 or overflow_nums == -1: + OverFlowUtil.overflow_nums = overflow_nums + else: + print_error_log("overflow_nums must be an integer greater than 0 or set -1.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + + +def check_register_hook(hook, **kwargs): + if not isfunction(hook) or hook.__name__ not in ["overflow_check", "acc_cmp_dump"]: + print_error_log("hook function must be set overflow_check or acc_cmp_dump") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + for item in kwargs.keys(): + if item not in REGISTER_HOOK_KWARGS: + print_error_log(f"{item} not a valid keyword arguments in register_hook.") + raise CompareException(CompareException.INVALID_PARAM_ERROR) + + def register_hook_core(hook, **kwargs): global make_dir_flag @@ -114,7 +136,7 @@ def register_hook_core(hook, **kwargs): if "overflow_check" in hook_name and not is_gpu: if hasattr(torch_npu._C, "_enable_overflow_npu"): torch_npu._C._enable_overflow_npu() - print_info_log("Enable overflow function success.") + print_info_log("Enable overflow function success.") else: print_warn_log("Api '_enable_overflow_npu' is not exist, " "the overflow detection function on milan platform maybe not work! " diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/debugger/test_precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/debugger/test_precision_debugger.py index f2c82a938c2126fb31f448e70b24e900dffbe11a..1ae90307873b5e225c091d981ef565ffe1aeb324 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/debugger/test_precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/debugger/test_precision_debugger.py @@ -24,7 +24,7 @@ class TestPrecisionDebugger(unittest.TestCase): self.assertRaises(ValueError, self.precision_debugger.configure_full_dump, mode='acl', acl_config=None) def test_configure_overflow_dump(self): - self.assertRaises(ValueError, self.precision_debugger.configure_overflow_dump, overflow_nums='invalid') + self.assertRaises(Exception, self.precision_debugger.configure_overflow_dump, overflow_nums='invalid') @patch('ptdbg_ascend.debugger.precision_debugger.register_hook_core') def test_start(self, mock_register_hook_core): diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_register_hook.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_register_hook.py index 7f91d6bd9cf21a17f626d09dbcc0584afa7f1bfb..c3b1a39bc662c3f36b329e267aaa4b8d33f4d0e3 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_register_hook.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/hook_module/test_register_hook.py @@ -1,12 +1,13 @@ import unittest from unittest.mock import patch, MagicMock from ptdbg_ascend.hook_module import register_hook +from ptdbg_ascend.dump.dump import acc_cmp_dump class TestRegisterHook(unittest.TestCase): def setUp(self): self.model = MagicMock() - self.hook = MagicMock() + self.hook = acc_cmp_dump def test_register_hook(self): with patch('ptdbg_ascend.hook_module.register_hook.register_hook_core') as mock_core: