From 74c9e99b6b0e3a5982ae7ea9ea15922486e29e51 Mon Sep 17 00:00:00 2001 From: yang-minghai22 Date: Thu, 7 Dec 2023 11:11:36 +0800 Subject: [PATCH] [bugfix]update code for check input parameters --- .../ptdbg_ascend/online_dispatch/dispatch.py | 45 +++++++++++++------ .../ptdbg_ascend/online_dispatch/utils.py | 12 +++++ 2 files changed, 44 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dispatch.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dispatch.py index 3704b44b55..2bb4fa9ffb 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dispatch.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dispatch.py @@ -21,7 +21,8 @@ from ..common.utils import Const, CompareConst, add_time_as_suffix, check_file_o from ..common.version import __version__ from .dump_compare import dispatch_workflow, dispatch_multiprocess, error_call, TimeStatistics, \ DispatchRunParam, save_csv -from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info +from .utils import get_callstack, data_to_cpu, logger_debug, logger_error, logger_warn, logger_logo, get_sys_info, \ + DispatchException from ..common.file_check_util import FileOpen @@ -39,7 +40,7 @@ class PtdbgDispatch(TorchDispatchMode): self.device_id = torch_npu._C._npu_getDevice() self.dump_mode = dump_mode - self.dump_api_list = self.get_dump_api(api_list) + self.dump_api_list = api_list self.debug_flag = debug self.api_index = 0 self.single_api_index_dict = {} @@ -47,10 +48,14 @@ class PtdbgDispatch(TorchDispatchMode): self.device_dump_path_npu = None self.all_summery = [] self.call_stack_list = [] + self.process_num = process_num + self.check_param() + self.filter_dump_api() # guarantee file uniqueness time.sleep(1) time_now = time.strftime("%Y%m%d%H%M%S", time.localtime(time.time())) - if tag is None: + if tag is None or not isinstance(tag, str): + logger_warn('There is not tag or the type of tag is not string.') dir_name = f'ptdbg_v{__version__}_rank{self.device_id}_{time_now}' else: dir_name = f'ptdbg_v{__version__}_{tag}_rank{self.device_id}_{time_now}' @@ -72,7 +77,6 @@ class PtdbgDispatch(TorchDispatchMode): self.aten_ops_blacklist = yaml_file.get('aten_ops_blacklist') self.npu_adjust_autogard = yaml_file.get('npu_adjust_autogard') - self.process_num = process_num self.lock = None if process_num > 0: self.pool = Pool(process_num) @@ -83,17 +87,18 @@ class PtdbgDispatch(TorchDispatchMode): f'dump_mode:{self.dump_mode} cpu_path[{self.root_cpu_path}], npu_path[{self.root_npu_path}], ' f'process[{process_num}]') - @staticmethod - def get_dump_api(api_list): + def filter_dump_api(self): + if self.dump_mode != Const.LIST or not self.dump_api_list: + self.dump_api_list = [] + return aten_api_list = dir(torch.ops.aten) dump_api_list = [] - if api_list is not None: - for aten_api in api_list: - if aten_api in aten_api_list: - dump_api_list.append(aten_api) - else: - logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten') - return dump_api_list + for aten_api in self.dump_api_list: + if aten_api in aten_api_list: + dump_api_list.append(aten_api) + else: + logger_warn(f'{aten_api} is not aten api will not dump, please refer to torch.ops.aten') + self.dump_api_list = dump_api_list def get_dump_flag(self, aten_api): dump_flag = False @@ -117,6 +122,20 @@ class PtdbgDispatch(TorchDispatchMode): return True return False + def check_param(self): + if self.dump_mode not in [Const.OFF, Const.ALL, Const.LIST, Const.AUTO]: + logger_error('The parameter "dump mode" can only be "off", "all", "list", "auto".') + raise DispatchException(DispatchException.INVALID_PARAMETER) + if not isinstance(self.dump_api_list, list): + logger_error('The type of parameter "api_list" can only be list.') + raise DispatchException(DispatchException.INVALID_PARAMETER) + if not isinstance(self.debug_flag, bool): + logger_error('The type of parameter "debug" can only be bool.') + raise DispatchException(DispatchException.INVALID_PARAMETER) + if not isinstance(self.process_num, int) or self.process_num < 0: + logger_error('The type of parameter "process_num" can only be int and it should not be less than 0.') + raise DispatchException(DispatchException.INVALID_PARAMETER) + def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py index 5cb8e0f3e5..f62cfa91a6 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/utils.py @@ -165,3 +165,15 @@ def get_sys_info(): f'Used: {mem.used / 1024 / 1024:.2f} MB '\ f'CPU: {cpu_percent}% ' return sys_info + + +class DispatchException(Exception): + INVALID_PARAMETER = 0 + + def __init__(self, err_code, err_msg=""): + super(DispatchException, self).__init__() + self.err_code = err_code + self.err_msg = err_msg + + def __str__(self): + return self.err_msg -- Gitee