From 90a1947862a91fc47fde49f54ab915b62ad6c7f1 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Tue, 31 Oct 2023 11:44:27 +0800 Subject: [PATCH 1/4] clean code --- .../src/python/ptdbg_ascend/dump/utils.py | 39 +++++++++++++------ .../online_dispatch/dump_compare.py | 4 +- .../ptdbg_ascend/overflow_check/info_dump.py | 17 ++++---- .../overflow_check/overflow_check.py | 3 +- .../parse_tool/lib/visualization.py | 4 +- 5 files changed, 42 insertions(+), 25 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index 35f5ad4a905..db70bc7946a 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -71,14 +71,14 @@ class DumpUtil(object): DumpUtil.dump_switch_scope = [api_name.replace("backward", "forward") for api_name in scope] DumpUtil.summary_only = summary_only - def check_list_or_acl_mode(name_prefix): + def check_list_or_acl_mode(self, name_prefix): global dump_count for item in DumpUtil.dump_switch_scope: if name_prefix.startswith(item): dump_count = dump_count + 1 return True - def check_range_mode(name_prefix): + def check_range_mode(self, name_prefix): global range_begin_flag global range_end_flag if name_prefix.startswith(DumpUtil.dump_switch_scope[0]): @@ -91,7 +91,7 @@ class DumpUtil(object): return True return False - def check_stack_mode(name_prefix): + def check_stack_mode(self, name_prefix): if len(DumpUtil.dump_switch_scope) == 0: return True elif len(DumpUtil.dump_switch_scope) == 1: @@ -196,7 +196,14 @@ def generate_dump_path_str(): return dump_path -def set_dump_switch(switch, mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.ON, dump_mode=[Const.ALL], summary_only=False): +def set_dump_switch(switch, mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.ON, dump_mode=None, + summary_only=False): + if scope is None: + scope = [] + if api_list is None: + api_list = [] + if dump_mode is None: + dump_mode = [Const.ALL] check_switch_valid(switch) if not DumpUtil.dump_path: set_dump_path() @@ -207,10 +214,18 @@ def set_dump_switch(switch, mode=Const.ALL, scope=[], api_list=[], filter_switch if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE]: generate_compare_script(DumpUtil.dump_data_dir, dump.get_pkl_file_path(), DumpUtil.dump_switch_mode) set_dump_switch_print_info(switch, mode, dump_path_str) - set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode,summary_only=summary_only) - - -def set_dump_switch_config(mode=Const.ALL, scope=[], api_list=[], filter_switch=Const.ON, dump_mode=[Const.ALL], summary_only=False): + set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode, + summary_only=summary_only) + + +def set_dump_switch_config(mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.ON, dump_mode=None, + summary_only=False): + if scope is None: + scope = [] + if api_list is None: + api_list = [] + if dump_mode is None: + dump_mode = [Const.ALL] try: check_mode_valid(mode, scope, api_list) check_switch_valid(filter_switch) @@ -218,10 +233,10 @@ def set_dump_switch_config(mode=Const.ALL, scope=[], api_list=[], filter_switch= summary_only = check_summary_only_valid(summary_only) except (CompareException, AssertionError) as err: print_error_log(str(err)) - raise CompareException(CompareException.INVALID_PARAM_ERROR) + raise CompareException(CompareException.INVALID_PARAM_ERROR) from err switch = DumpUtil.dump_switch DumpUtil.set_dump_switch("OFF", mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, - dump_mode=dump_mode, summary_only=summary_only) + dump_mode=dump_mode, summary_only=summary_only) DumpUtil.dump_switch = switch @@ -290,9 +305,9 @@ def load_env_dump_path(dump_path): if dump_path: try: dump_path = os.path.join(str(dump_path), Const.DUMP_DIR) - except TypeError: + except TypeError as err: print_error_log("Generating dump path from environment variables ASCEND_WORK_PATH failed.") - raise DumpException(DumpException.INVALID_PATH_ERROR) + raise DumpException(DumpException.INVALID_PATH_ERROR) from err else: print_error_log("Dump path is None, you can configure it in the following ways:\n" "1. Configure set_dump_path function.\n" diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py index a4740fef884..03947fc9e4e 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/online_dispatch/dump_compare.py @@ -140,8 +140,8 @@ def save_summery(run_param, npu_data, cpu_data, prefix, summery_list, compute_fl data_dict[CompareConst.COSINE], data_dict[CompareConst.MAX_ABS_ERR], data_dict[CompareConst.MAX_RELATIVE_ERR], \ data_dict[CompareConst.ERROR_MESSAGE] = get_compare_result(npu_data, cpu_data) - data_dict[CompareConst.ACCURACY] = check_accuracy(data_dict[CompareConst.COSINE], - data_dict[CompareConst.MAX_ABS_ERR]) + data_dict[CompareConst.ACCURACY] = check_accuracy(data_dict.get(CompareConst.COSINE), + data_dict.get(CompareConst.MAX_ABS_ERR)) else: data_dict[CompareConst.COSINE] = 1 data_dict[CompareConst.MAX_ABS_ERR] = 0 diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py index 22040a62aa5..5e6e347fff9 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/info_dump.py @@ -42,28 +42,29 @@ class APIInfo: out = [] for item in element: out.append(self.analyze_element(item)) + return out elif isinstance(element, dict): - out = {} + out_dict = {} for key, value in element.items(): if key in self.torch_object_key.keys(): fun = self.torch_object_key[key] - out[key] = fun(value) + out_dict[key] = fun(value) elif key in special_torch_object: continue else: - out[key] = self.analyze_element(value) - + out_dict[key] = self.analyze_element(value) + return out_dict elif isinstance(element, torch.Tensor): - out = self.analyze_tensor(element, self.save_real_data) - + out_tensor = self.analyze_tensor(element, self.save_real_data) + return out_tensor elif self.is_builtin_class(element): - out = self.analyze_builtin(element) + out_builtin = self.analyze_builtin(element) + return out_builtin else: msg = f"Type {type(element)} is unsupported at analyze_element" print_error_log(msg) raise NotImplementedError(msg) - return out def analyze_tensor(self, arg, save_real_data): single_arg = {} diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py index f3890818070..735a36a092c 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/overflow_check.py @@ -22,6 +22,7 @@ backward_api_info = {} FORWARD_REAL_DATA_PATH = os.path.join('./', 'forward_real_data') BACKWARD_REAL_DATA_PATH = os.path.join('./', 'backward_real_data') rank = os.getpid() +pkl_name = '' def check_overflow_environment(pid): @@ -75,6 +76,7 @@ def check_data_overflow(x): def check_path(apis, path): return any(api in path for api in apis) + def overflow_check(name, **kwargs): overflow_nums = OverFlowUtil.overflow_nums pid = kwargs.get('pid') @@ -153,7 +155,6 @@ def overflow_check(name, **kwargs): write_api_info_json(backward_api_info[key]) raise ValueError("[overflow {} times]: dump file is saved in '{}'." .format(OverFlowUtil.real_overflow_dump_times, os.path.realpath(dump_file_name))) - return def overflow_type_judge(in_feat, out_feat, module_name): if module_name.endswith(Const.BACKWARD): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py index a46f7196668..10dde7e0894 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/visualization.py @@ -32,7 +32,7 @@ class Visualization: except UnicodeError as e: self.util.log.error("%s %s" % ("UnicodeError", str(e))) self.util.log.warning("Please check the npy file") - raise ParseException(ParseException.PARSE_UNICODE_ERROR) + raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e table = self.util.create_table('', ['Index', 'Data']) flatten_data = np_data.flatten() for i in range(min(16, int(np.ceil(flatten_data.size / 8)))): @@ -66,7 +66,7 @@ class Visualization: except json.JSONDecodeError as e: self.util.log.error("%s %s in line %s" % ("JSONDecodeError", str(e), pkl_line)) self.util.log.warning("Please check the pkl file") - raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) + raise ParseException(ParseException.PARSE_JSONDECODE_ERROR) from e info_prefix = msg[0] if not info_prefix.startswith(api_name): continue -- Gitee From 93dbc786d6d119efb3c1fd8c1befba399f4da208 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Tue, 31 Oct 2023 12:00:20 +0800 Subject: [PATCH 2/4] clean code --- .../src/python/ptdbg_ascend/common/utils.py | 23 ++++++++++++------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index 30fe40a95be..f2361a90b76 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -210,7 +210,7 @@ def make_dump_path_if_not_exists(dump_path): except OSError as ex: print_error_log( 'Failed to create {}.Please check the path permission or disk space .{}'.format(dump_path, str(ex))) - raise CompareException(CompareException.INVALID_PATH_ERROR) + raise CompareException(CompareException.INVALID_PATH_ERROR) from ex else: if not os.path.isdir(dump_path): print_error_log('{} already exists and is not a directory.'.format(dump_path)) @@ -253,7 +253,11 @@ def print_warn_log(warn_msg): _print_log("WARNING", warn_msg) -def check_mode_valid(mode, scope=[], api_list=[]): +def check_mode_valid(mode, scope=None, api_list=None): + if scope is None: + scope = [] + if api_list is None: + api_list = [] if not isinstance(scope, list): raise ValueError("scope param set invalid, it's must be a list.") if not isinstance(api_list, list): @@ -272,8 +276,8 @@ def check_mode_valid(mode, scope=[], api_list=[]): (mode, Const.DUMP_MODE) raise CompareException(CompareException.INVALID_DUMP_MODE, msg) - if mode_check[mode]() is not None: - raise mode_check[mode]() + if mode_check.get(mode)() is not None: + raise mode_check.get(mode)() def check_switch_valid(switch): @@ -281,6 +285,7 @@ def check_switch_valid(switch): print_error_log("Please set switch with 'ON' or 'OFF'.") raise CompareException(CompareException.INVALID_PARAM_ERROR) + def check_dump_mode_valid(dump_mode): if not isinstance(dump_mode, list): print_warn_log("Please set dump_mode as a list.") @@ -295,12 +300,14 @@ def check_dump_mode_valid(dump_mode): return ["forward", "backward", "input", "output"] return dump_mode + def check_summary_only_valid(summary_only): if not isinstance(summary_only, bool): print_error_log("Params summary_only only support True or False.") raise CompareException(CompareException.INVALID_PARAM_ERROR) return summary_only + def check_compare_param(input_parma, output_path, stack_mode=False, auto_analyze=True, fuzzy_match=False): # 添加默认值来让不传参时能通过参数检查 if not (isinstance(input_parma, dict) and isinstance(output_path, str) @@ -348,8 +355,8 @@ def _check_pkl(pkl_file_handle, file_name): pkl_file_handle.seek(0, 0) -def is_starts_with(string, prefixes): - return any(string.startswith(prefix) for prefix in prefixes) +def is_starts_with(string, prefix_list): + return any(string.startswith(prefix) for prefix in prefix_list) def check_file_mode(npu_pkl, bench_pkl, stack_mode): @@ -394,9 +401,9 @@ def remove_path(path): os.remove(path) else: shutil.rmtree(path) - except PermissionError: + except PermissionError as err: print_error_log("Failed to delete {}. Please check the permission.".format(path)) - raise CompareException(CompareException.INVALID_PATH_ERROR) + raise CompareException(CompareException.INVALID_PATH_ERROR) from err def get_dump_data_path(dump_dir): -- Gitee From 6cb7b2f1d0508bcd3694eb9eaa1c09d985813c8f Mon Sep 17 00:00:00 2001 From: l30044004 Date: Tue, 31 Oct 2023 15:21:14 +0800 Subject: [PATCH 3/4] clean code --- .../ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index db70bc7946a..ce472247d4e 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -71,14 +71,16 @@ class DumpUtil(object): DumpUtil.dump_switch_scope = [api_name.replace("backward", "forward") for api_name in scope] DumpUtil.summary_only = summary_only - def check_list_or_acl_mode(self, name_prefix): + @staticmethod + def check_list_or_acl_mode(name_prefix): global dump_count for item in DumpUtil.dump_switch_scope: if name_prefix.startswith(item): dump_count = dump_count + 1 return True - def check_range_mode(self, name_prefix): + @staticmethod + def check_range_mode(name_prefix): global range_begin_flag global range_end_flag if name_prefix.startswith(DumpUtil.dump_switch_scope[0]): @@ -91,7 +93,8 @@ class DumpUtil(object): return True return False - def check_stack_mode(self, name_prefix): + @staticmethod + def check_stack_mode(name_prefix): if len(DumpUtil.dump_switch_scope) == 0: return True elif len(DumpUtil.dump_switch_scope) == 1: -- Gitee From d2bb731d504ce6862cb0d36e60e67beb667f4509 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Tue, 31 Oct 2023 18:27:51 +0800 Subject: [PATCH 4/4] clean code --- .../src/python/ptdbg_ascend/dump/utils.py | 70 +++++++++---------- 1 file changed, 35 insertions(+), 35 deletions(-) diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py index ce472247d4e..84b9a4dc63c 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/utils.py @@ -17,6 +17,41 @@ dump_count = 0 range_begin_flag, range_end_flag = False, False +def check_list_or_acl_mode(name_prefix): + global dump_count + for item in DumpUtil.dump_switch_scope: + if name_prefix.startswith(item): + dump_count = dump_count + 1 + return True + + +def check_range_mode(name_prefix): + global range_begin_flag + global range_end_flag + if name_prefix.startswith(DumpUtil.dump_switch_scope[0]): + range_begin_flag = True + return True + if name_prefix.startswith(DumpUtil.dump_switch_scope[1]): + range_end_flag = True + return True + if range_begin_flag and not range_end_flag: + return True + return False + + +def check_stack_mode(name_prefix): + if len(DumpUtil.dump_switch_scope) == 0: + return True + elif len(DumpUtil.dump_switch_scope) == 1: + return name_prefix.startswith(DumpUtil.dump_switch_scope[0]) + elif len(DumpUtil.dump_switch_scope) == 2: + return check_range_mode(name_prefix) + else: + print_error_log("dump scope is invalid, Please set the scope mode in" + " set_dump_switch with 'all', 'list', 'range', 'stack', 'acl', 'api_list'!") + return False + + class DumpUtil(object): dump_root = None dump_data_dir = None @@ -71,41 +106,6 @@ class DumpUtil(object): DumpUtil.dump_switch_scope = [api_name.replace("backward", "forward") for api_name in scope] DumpUtil.summary_only = summary_only - @staticmethod - def check_list_or_acl_mode(name_prefix): - global dump_count - for item in DumpUtil.dump_switch_scope: - if name_prefix.startswith(item): - dump_count = dump_count + 1 - return True - - @staticmethod - def check_range_mode(name_prefix): - global range_begin_flag - global range_end_flag - if name_prefix.startswith(DumpUtil.dump_switch_scope[0]): - range_begin_flag = True - return True - if name_prefix.startswith(DumpUtil.dump_switch_scope[1]): - range_end_flag = True - return True - if range_begin_flag and not range_end_flag: - return True - return False - - @staticmethod - def check_stack_mode(name_prefix): - if len(DumpUtil.dump_switch_scope) == 0: - return True - elif len(DumpUtil.dump_switch_scope) == 1: - return name_prefix.startswith(DumpUtil.dump_switch_scope[0]) - elif len(DumpUtil.dump_switch_scope) == 2: - return DumpUtil.check_range_mode(name_prefix) - else: - print_error_log("dump scope is invalid, Please set the scope mode in" - " set_dump_switch with 'all', 'list', 'range', 'stack', 'acl', 'api_list'!") - return False - check_mapper = { Const.LIST: check_list_or_acl_mode, Const.ACL: check_list_or_acl_mode, -- Gitee