From aa0c7025622ec931e0131c85385a8469ce5b70ed Mon Sep 17 00:00:00 2001 From: s30048155 Date: Fri, 8 Sep 2023 09:39:28 +0800 Subject: [PATCH 1/2] Max Absolute Error --- .../api_accuracy_checker/compare/algorithm.py | 5 +++++ .../accuracy_tools/api_accuracy_checker/compare/compare.py | 6 ++++-- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index b0b1aaf605..7314f51736 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -61,6 +61,11 @@ def get_max_rel_err(n_value, b_value): bool_result = rel_err < 0.001 return rel_err, bool_result, msg +def get_max_abs_err(n_value, b_value): + n_value, b_value, msg = get_msg_and_handle_value(n_value, b_value) + abs_err = np.abs(n_value - b_value).max() + bool_result = abs_err < 0.001 + return abs_err, bool_result, msg def get_rel_err_ratio_thousandth(n_value, b_value): return get_rel_err_ratio(n_value, b_value, 0.001) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 2192359a02..a7221cb31f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,7 +1,7 @@ # 进行比对及结果展示 import os from prettytable import PrettyTable -from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, \ +from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, get_max_abs_err, \ compare_builtin_type, get_rel_err_ratio_thousandth, get_rel_err_ratio_ten_thousandth from api_accuracy_checker.common.utils import get_json_contents, print_info_log, write_csv from api_accuracy_checker.compare.compare_utils import CompareConst @@ -27,6 +27,7 @@ class Comparator: self.compare_alg = {} self.register_compare_algorithm("Cosine Similarity", cosine_sim, cosine_standard) self.register_compare_algorithm("Max Relative Error", get_max_rel_err, None) + self.register_compare_algorithm("Max Absolute Error", get_max_abs_err, None) self.register_compare_algorithm("Thousandth Relative Error Ratio", get_rel_err_ratio_thousandth, None) self.register_compare_algorithm("Ten Thousandth Relative Error Ratio", get_rel_err_ratio_ten_thousandth, None) self.register_compare_algorithm("Default: isEqual", compare_builtin_type, None) @@ -60,6 +61,7 @@ class Comparator: "Npu Name", "Bench Dtype", "NPU Dtype", "Cosine Similarity", "Cosine Similarity Message", "Max Rel Error", "Max Rel Err Message", + "Max Abs Error", "Max Abs Err Message", "Relative Error (dual thousandth)", "Relative Error (dual thousandth) Message", "Relative Error (dual ten thousandth)", "Relative Error (dual ten thousandth) Message", "Compare Builtin Type", "Builtin Type Message", @@ -139,7 +141,7 @@ class Comparator: detailed_result, test_success, bench_dtype, npu_dtype = compare_core(bench_out, npu_out, alg) bench_dtype_total = bench_dtype npu_dtype_total = npu_dtype - if name != "Max Relative Error": + if name != "Max Relative Error" and name != "Max Absolute Error": test_success_total = test_success_total and test_success if detailed_result_total: for i in range(len(detailed_result_total)): -- Gitee From 6db9566b6e61d1d211ba8ed3fe9d55fa5a698ea0 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Fri, 8 Sep 2023 17:26:06 +0800 Subject: [PATCH 2/2] update --- debug/accuracy_tools/api_accuracy_checker/README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index 52b5e68f80..0de1731460 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -22,7 +22,7 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 安装依赖tqdm、prettytable、yaml ```bash - pip3 install tqdm prettytable yaml + pip3 install tqdm prettytable pyyaml ``` 2. 在训练脚本(如main.py)中加入以下代码导入工具dump模块,启动训练即可自动抓取网络所有API信息 -- Gitee