diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py index 1fac930d60ea17a38d55e57c2da800c95324c321..7db518793c33c3a8848cd64f703a3e1d6ee0ffe6 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/compare/acc_compare.py @@ -67,15 +67,33 @@ def cosine_similarity(n_value, b_value): def get_rmse(n_value, b_value): - rmse = np.linalg.norm(n_value - b_value) / np.sqrt(len(n_value)) + if len(n_value) == 0 and len(b_value) == 0: + rmse = '0' + elif len(n_value) == 0: + rmse = CompareConst.NAN + elif len(b_value) == 0: + rmse = CompareConst.NAN + else: + rmse = np.linalg.norm(n_value - b_value) / np.sqrt(len(n_value)) if np.isnan(rmse): rmse = CompareConst.NAN return rmse, "" def get_mape(n_value, b_value): - mape_val = np.sum(np.abs((n_value - b_value) / b_value)) / len(b_value) * 100 - mape = CompareConst.NAN if np.isnan(mape_val) else str(round(mape_val, 4)) + '%' + if len(n_value) == 0 and len(b_value) == 0: + mape = '0' + elif len(n_value) == 0: + mape = CompareConst.NAN + elif len(b_value) == 0: + mape = CompareConst.NAN + elif not np.all(n_value) and not np.all(b_value): + mape = '0' + elif not np.all(b_value): + mape = CompareConst.NAN + else: + mape_val = np.sum(np.abs((n_value - b_value) / b_value)) / len(b_value) * 100 + mape = CompareConst.NAN if np.isnan(mape_val) else str(round(mape_val, 4)) + '%' return mape, "" @@ -116,8 +134,7 @@ def check_op(npu_dict, bench_dict, fuzzy_match): except Exception as err: print_warn_log("%s and %s can not fuzzy match." % (a_op_name, b_op_name)) is_match = False - finally: - return is_match and struct_match + return is_match and struct_match def check_struct_match(npu_dict, bench_dict): @@ -305,12 +322,12 @@ def read_dump_path(result_path): bench_dump_name = bench_dump_name_list[index] op_name_mapping_dict[npu_dump_name] = [npu_dump_name, bench_dump_name] return op_name_mapping_dict - except FileNotFoundError as error: - print(error) - raise FileNotFoundError(error) - except IOError as error: - print(error) - raise IOError(error) + except FileNotFoundError as e: + print_error_log('{} file is not found.'.format(result_path)) + raise CompareException(CompareException.OPEN_FILE_ERROR) from e + except IOError as e: + print_error_log('{} read csv failed.'.format(result_path)) + raise CompareException(CompareException.READ_FILE_ERROR) from e def _handle_multi_process(func, input_parma, result_path, lock): @@ -326,13 +343,13 @@ def _handle_multi_process(func, input_parma, result_path, lock): pool = multiprocessing.Pool(process_num) def err_call(args): + print_error_log('multiprocess compare failed! season:{}'.format(args)) try: pool.terminate() if os.path.exists(result_path): os.remove(result_path) - sys.exit(args) - except SystemExit as error: - print('multiprocess compare failed! season:{}'.format(args)) + except OSError as e: + print_error_log("pool terminate failed") for process_idx, fusion_op_names in enumerate(op_names): idx = [process_num, process_idx] @@ -377,12 +394,12 @@ def _save_cmp_result(idx, cos_result, max_err_result, max_relative_err_result, e csv_pd.loc[process_index, CompareConst.ERROR_MESSAGE] = err_msg[i] csv_pd.loc[process_index, CompareConst.ACCURACY] = check_accuracy(cos_result[i], max_err_result[i]) csv_pd.to_csv(result_path, index=False) - except FileNotFoundError as error: - print(error) - raise FileNotFoundError(error) - except IOError as error: - print(error) - raise IOError(error) + except FileNotFoundError as e: + print_error_log('{} file is not found.'.format(result_path)) + raise CompareException(CompareException.OPEN_FILE_ERROR) from e + except IOError as e: + print_error_log('{} read csv failed.'.format(result_path)) + raise CompareException(CompareException.READ_FILE_ERROR) from e finally: lock.release() diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/debugger_config.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/debugger_config.py index 7278dfb26e64ec2869f1fc877c7b56bacaca143f..b08b7d661117a4341f6480d1d92cc5094f7a5d60 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/debugger_config.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/debugger_config.py @@ -3,11 +3,11 @@ from ..common.utils import print_warn_log class DebuggerConfig: - def __init__(self, dump_path, hook_name, rank=None, step=[]): + def __init__(self, dump_path, hook_name, rank=None, step=None): self.dump_path = dump_path self.hook_name = hook_name self.rank = rank - self.step = step + self.step = step or [] self.check() if self.step: self.step.sort() diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py index de1ff96b2633bd4e123e963c00c10538c92d72ba..c85ba0d5212f912e5b9496457bbd87a63147c695 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/debugger/precision_debugger.py @@ -17,11 +17,12 @@ class PrecisionDebugger: hook_func = None config = None - def __init__(self, dump_path=None, hook_name=None, rank=None, step=[], enable_dataloader=False): + def __init__(self, dump_path=None, hook_name=None, rank=None, step=None, enable_dataloader=False): if hook_name is None: err_msg = "You must provide hook_name argument to PrecisionDebugger\ when config is not provided." raise Exception(err_msg) + step = step or [] self.config = DebuggerConfig(dump_path, hook_name, rank, step) self.configure_hook = self.get_configure_hook(self.config.hook_name) self.configure_hook() @@ -40,8 +41,11 @@ class PrecisionDebugger: hook_dict = {"dump": self.configure_full_dump, "overflow_check": self.configure_overflow_dump} return hook_dict.get(hook_name, lambda: ValueError("hook name {} is not in ['dump', 'overflow_check']".format(hook_name))) - def configure_full_dump(self, mode='api_stack', scope=[], api_list=[], filter_switch=Const.ON, - input_output_mode=[Const.ALL], acl_config=None, backward_input=[], summary_only=False): + def configure_full_dump(self, mode='api_stack', scope=None, api_list=None, filter_switch=Const.ON, + input_output_mode=[Const.ALL], acl_config=None, backward_input=None, summary_only=False): + scope = scope or [] + api_list = api_list or [] + backward_input = backward_input or [] set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=input_output_mode, summary_only=summary_only) if mode == 'acl': diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/compare.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/compare.py index b7268a347c071f8547f0c37c3b33589951de098c..175b4f937337f3f74cd4d12e724f5c83eed6bcf4 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/compare.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/compare.py @@ -90,10 +90,10 @@ class Compare: except UnicodeError as e: self.log.error("%s %s" % ("UnicodeError", str(e))) self.log.warning("Please check the npy file") - raise ParseException(ParseException.PARSE_UNICODE_ERROR) + raise ParseException(ParseException.PARSE_UNICODE_ERROR) from e except IOError: self.log.error("Failed to load npy %s or %s." % (left, right)) - raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) + raise ParseException(ParseException.PARSE_LOAD_NPY_ERROR) from e # save to txt if save_txt: @@ -147,6 +147,9 @@ class Compare: if err_cnt < diff_count: err_table.add_row(str(i), str(data_left[i]), str(data_right[i]), str(abs_diff)) err_cnt += 1 - err_percent = float(err_cnt / total_cnt) + if total_cnt == 0: + err_percent = float(0) + else: + err_percent = float(err_cnt / total_cnt) self.util.print(self.util.create_columns([err_table, top_table])) return total_cnt, all_close, cos_sim, err_percent diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py index 04b4fa3af2340f62029ac4e5398e00a2d3374b06..20c5f6c749996796450a7c5f4728696d8120463d 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/parse_tool/lib/utils.py @@ -37,8 +37,7 @@ except ImportError as err: Table = None Columns = None rich_print = None - print("[Warning] Failed to import rich.", err) - print("[Warning] Some features may not be available. Please run 'pip install rich' to fix it.") + print("[Warning] Failed to import rich, Some features may not be available. Please run 'pip install rich' to fix it.") class Util: @@ -108,8 +107,8 @@ class Util: try: os.makedirs(path, mode=0o750) except OSError as e: - self.log.error("Failed to create %s. %s", path, str(e)) - raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) + self.log.error("Failed to create %s.", path) + raise ParseException(ParseException.PARSE_INVALID_PATH_ERROR) from e def gen_npy_info_txt(self, source_data): shape, dtype, max_data, min_data, mean = \