From 54437f68940f32098eaa7d9fb003e11342518773 Mon Sep 17 00:00:00 2001 From: l00816237 Date: Thu, 7 Dec 2023 10:41:11 +0800 Subject: [PATCH] use the parameter scope instead of api_list when mode is api_list --- .../src/python/ptdbg_ascend/common/utils.py | 8 ++---- .../debugger/precision_debugger.py | 9 +++---- .../src/python/ptdbg_ascend/dump/dump.py | 4 +-- .../src/python/ptdbg_ascend/dump/utils.py | 27 +++++++------------ .../ptdbg_ascend/test/ut/test_common_util.py | 2 +- .../ptdbg_ascend/test/ut/test_hooks.py | 6 ++--- 6 files changed, 21 insertions(+), 35 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index a78eb595f..dc6744ef3 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -288,22 +288,18 @@ def print_warn_log(warn_msg): _print_log("WARNING", warn_msg) -def check_mode_valid(mode, scope=None, api_list=None): +def check_mode_valid(mode, scope=None): if scope is None: scope = [] - if api_list is None: - api_list = [] if not isinstance(scope, list): raise ValueError("scope param set invalid, it's must be a list.") - if not isinstance(api_list, list): - raise ValueError("api_list param set invalid, it's must be a list.") mode_check = { Const.ALL: lambda: None, Const.RANGE: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end].") if len(scope) != 2 else None, Const.LIST: lambda: ValueError("set_dump_switch, scope param set invalid, it's should not be an empty list.") if len(scope) == 0 else None, Const.STACK: lambda: ValueError("set_dump_switch, scope param set invalid, it's must be [start, end] or [].") if len(scope) > 2 else None, Const.ACL: lambda: ValueError("set_dump_switch, scope param set invalid, only one api name is supported in acl mode.") if len(scope) != 1 else None, - Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(api_list) < 1 else None, + Const.API_LIST: lambda: ValueError("Current dump mode is 'api_list', but the content of api_list parameter is empty or valid.") if len(scope) == 0 else None, Const.API_STACK: lambda: None, } if mode not in Const.DUMP_MODE: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index 693fa06ee..499577f64 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -43,13 +43,12 @@ class PrecisionDebugger: hook_dict = {"dump": self.configure_full_dump, "overflow_check": self.configure_overflow_dump} return hook_dict.get(hook_name, lambda: ValueError("hook name {} is not in ['dump', 'overflow_check']".format(hook_name))) - def configure_full_dump(self, mode='api_stack', scope=None, api_list=None, filter_switch=Const.OFF, + def configure_full_dump(self, mode='api_stack', scope=None, filter_switch=Const.OFF, input_output_mode=[Const.ALL], acl_config=None, backward_input=None, summary_only=False): - scope = scope or [] - api_list = api_list or [] + scope = scope or [] backward_input = backward_input or [] - set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, - filter_switch=filter_switch, dump_mode=input_output_mode, summary_only=summary_only) + set_dump_switch_config(mode=mode, scope=scope, filter_switch=filter_switch, dump_mode=input_output_mode, + summary_only=summary_only) if mode == 'acl': DumpUtil.set_acl_config(acl_config) if not scope or not isinstance(scope, list) or len(scope) != 1: diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 17fe015ac..c4c85fed9 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -32,7 +32,7 @@ except ImportError: else: is_gpu = False -from .utils import DumpUtil, check_if_in_api_list, make_dump_data_dir, get_tensor_rank, create_dirs_if_not_exist +from .utils import DumpUtil, check_if_name_in_api_list, make_dump_data_dir, get_tensor_rank, create_dirs_if_not_exist from ..common.utils import print_warn_log, Const, print_info_log, modify_dump_path, check_inplace_op, CompareConst from ..dump.utils import check_writable from ..common.file_check_util import FileOpen, change_mode, FileCheckConst, check_path_pattern_vaild, check_path_length @@ -223,7 +223,7 @@ def rename_(): def dump_acc_cmp(name, in_feat, out_feat, dump_step, module): dump_file = DumpUtil.get_dump_path() dump_file = modify_dump_path(dump_file, DumpUtil.dump_switch_mode) - if DumpUtil.dump_switch_mode == Const.API_LIST and not check_if_in_api_list(name): + if DumpUtil.dump_switch_mode == Const.API_LIST and not check_if_name_in_api_list(name): return if DumpUtil.get_dump_switch(): global rank diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index b619d0cf4..fa1cd9648 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -61,7 +61,6 @@ class DumpUtil(object): dump_switch_mode = Const.ALL # all, api_stack, list, stack... dump_switch_scope = [] dump_init_enable = False - dump_api_list = [] dump_filter_switch = None dump_mode = ['forward', 'backward', 'input', 'output'] backward_input = {} @@ -89,15 +88,13 @@ class DumpUtil(object): DumpUtil.dump_config = acl_config @staticmethod - def set_dump_switch(switch, mode=None, scope=None, api_list=None, filter_switch=None, dump_mode=None, summary_only=False): + def set_dump_switch(switch, mode=None, scope=None, filter_switch=None, dump_mode=None, summary_only=False): DumpUtil.dump_switch = switch if mode is not None: DumpUtil.dump_switch_mode = mode DumpUtil.dump_init_enable = True if scope is not None: DumpUtil.dump_switch_scope = scope - if api_list is not None: - DumpUtil.dump_api_list = [api.lower() for api in api_list] if filter_switch is not None: DumpUtil.dump_filter_switch = filter_switch if dump_mode is not None: @@ -200,12 +197,10 @@ def generate_dump_path_str(): return dump_path -def set_dump_switch(switch, mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.OFF, dump_mode=None, +def set_dump_switch(switch, mode=Const.ALL, scope=None, filter_switch=Const.OFF, dump_mode=None, summary_only=False): if scope is None: scope = [] - if api_list is None: - api_list = [] if dump_mode is None: dump_mode = [Const.ALL] check_switch_valid(switch) @@ -218,20 +213,18 @@ def set_dump_switch(switch, mode=Const.ALL, scope=None, api_list=None, filter_sw if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE]: generate_compare_script(DumpUtil.dump_data_dir, dump.get_pkl_file_path(), DumpUtil.dump_switch_mode) set_dump_switch_print_info(switch, mode, dump_path_str) - set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode, + set_dump_switch_config(mode=mode, scope=scope, filter_switch=filter_switch, dump_mode=dump_mode, summary_only=summary_only) -def set_dump_switch_config(mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.OFF, dump_mode=None, +def set_dump_switch_config(mode=Const.ALL, scope=None, filter_switch=Const.OFF, dump_mode=None, summary_only=False): if scope is None: scope = [] - if api_list is None: - api_list = [] if dump_mode is None: dump_mode = [Const.ALL] try: - check_mode_valid(mode, scope, api_list) + check_mode_valid(mode, scope) check_switch_valid(filter_switch) dump_mode = check_dump_mode_valid(dump_mode) summary_only = check_summary_only_valid(summary_only) @@ -239,7 +232,7 @@ def set_dump_switch_config(mode=Const.ALL, scope=None, api_list=None, filter_swi print_error_log(str(err)) raise CompareException(CompareException.INVALID_PARAM_ERROR) from err switch = DumpUtil.dump_switch - DumpUtil.set_dump_switch("OFF", mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, + DumpUtil.set_dump_switch("OFF", mode=mode, scope=scope, filter_switch=filter_switch, dump_mode=dump_mode, summary_only=summary_only) DumpUtil.dump_switch = switch @@ -256,10 +249,10 @@ def set_dump_switch_print_info(switch, mode, dump_path_str): print_info_log("The number of matched dump is {}".format(dump_count)) -def check_if_in_api_list(name): - if not DumpUtil.dump_api_list: - return False - for api in DumpUtil.dump_api_list: +def check_if_name_in_api_list(name): + if not DumpUtil.dump_switch_scope: + raise DumpException(DumpException.INVALID_PARAM_ERROR) + for api in DumpUtil.dump_switch_scope: if api.lower() in name.lower(): return True return False diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_common_util.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_common_util.py index 61cbf29f2..bd0d40133 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_common_util.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_common_util.py @@ -24,7 +24,7 @@ class TestCommonUtilsMethods(unittest.TestCase): self.assertEqual(mode_check("range", scope=["Tensor_abs_1_forward", "Tensor_transpose_3_forward"]), None) self.assertEqual(mode_check("stack",scope=["Tensor_abs_1_forward", "Tensor_transpose_3_forward"]), None) self.assertEqual(mode_check("acl",scope=["Tensor_permute_1_forward"]), None) - self.assertEqual(mode_check("api_list",api_list=["relu"]), None) + self.assertEqual(mode_check("api_list",scope=["relu"]), None) self.assertEqual(mode_check("api_stack"), None) self.assertRaises(common.CompareException, mode_check, "api_stack_123") diff --git a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py index 82f3d8dfe..4aee5defb 100644 --- a/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py +++ b/debug/accuracy_tools/ptdbg_ascend/test/ut/test_hooks.py @@ -16,7 +16,6 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(dump_util.dump_switch_mode, mode_all) self.assertTrue(dump_util.dump_init_enable) self.assertEqual(dump_util.dump_switch_scope, []) - self.assertEqual(dump_util.dump_api_list, []) self.assertEqual(dump_util.dump_filter_switch, "OFF") self.assertEqual(dump_count, 0) @@ -42,12 +41,11 @@ class TestUtilsMethods(unittest.TestCase): self.assertEqual(dump_util.dump_switch_scope, scope_list) def test_set_dump_switch_mode_is_api_list(self): - api_list = ["Transpose", "Relu", "triu"] + scope = ["Transpose", "Relu", "triu"] lower_api_list = ["transpose", "relu", "triu"] dump_util = hooks.DumpUtil - hooks.set_dump_switch("ON", mode="api_list", api_list=api_list) + hooks.set_dump_switch("ON", mode="api_list", scope=scope) self.assertEqual(dump_util.dump_switch_mode, "api_list") - self.assertEqual(dump_util.dump_api_list, lower_api_list) def test_set_dump_switch_mode_is_acl(self): scope_list = ["Tensor_transpose_3_backward"] -- Gitee