diff --git a/debug/accuracy_tools/api_accuracy_checker/README.md b/debug/accuracy_tools/api_accuracy_checker/README.md index 52b5e68f80999951e16332e5554dbb1315e673c3..0de1731460469b31f82240e1f8e196cac0934a56 100644 --- a/debug/accuracy_tools/api_accuracy_checker/README.md +++ b/debug/accuracy_tools/api_accuracy_checker/README.md @@ -22,7 +22,7 @@ Ascend模型精度预检工具能在昇腾NPU上扫描用户训练模型中所 安装依赖tqdm、prettytable、yaml ```bash - pip3 install tqdm prettytable yaml + pip3 install tqdm prettytable pyyaml ``` 2. 在训练脚本(如main.py)中加入以下代码导入工具dump模块,启动训练即可自动抓取网络所有API信息 diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index b0b1aaf605937f0cbf932d83811e6d2735d5a289..7314f517368772b1db5d6b2b62c4a5c06fdc8392 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -61,6 +61,11 @@ def get_max_rel_err(n_value, b_value): bool_result = rel_err < 0.001 return rel_err, bool_result, msg +def get_max_abs_err(n_value, b_value): + n_value, b_value, msg = get_msg_and_handle_value(n_value, b_value) + abs_err = np.abs(n_value - b_value).max() + bool_result = abs_err < 0.001 + return abs_err, bool_result, msg def get_rel_err_ratio_thousandth(n_value, b_value): return get_rel_err_ratio(n_value, b_value, 0.001) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 2192359a02d0d0b450f33ad8c25fe2869b3fe491..a7221cb31f3de857427235d9c7840b8fdade144c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,7 +1,7 @@ # 进行比对及结果展示 import os from prettytable import PrettyTable -from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, \ +from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, get_max_abs_err, \ compare_builtin_type, get_rel_err_ratio_thousandth, get_rel_err_ratio_ten_thousandth from api_accuracy_checker.common.utils import get_json_contents, print_info_log, write_csv from api_accuracy_checker.compare.compare_utils import CompareConst @@ -27,6 +27,7 @@ class Comparator: self.compare_alg = {} self.register_compare_algorithm("Cosine Similarity", cosine_sim, cosine_standard) self.register_compare_algorithm("Max Relative Error", get_max_rel_err, None) + self.register_compare_algorithm("Max Absolute Error", get_max_abs_err, None) self.register_compare_algorithm("Thousandth Relative Error Ratio", get_rel_err_ratio_thousandth, None) self.register_compare_algorithm("Ten Thousandth Relative Error Ratio", get_rel_err_ratio_ten_thousandth, None) self.register_compare_algorithm("Default: isEqual", compare_builtin_type, None) @@ -60,6 +61,7 @@ class Comparator: "Npu Name", "Bench Dtype", "NPU Dtype", "Cosine Similarity", "Cosine Similarity Message", "Max Rel Error", "Max Rel Err Message", + "Max Abs Error", "Max Abs Err Message", "Relative Error (dual thousandth)", "Relative Error (dual thousandth) Message", "Relative Error (dual ten thousandth)", "Relative Error (dual ten thousandth) Message", "Compare Builtin Type", "Builtin Type Message", @@ -139,7 +141,7 @@ class Comparator: detailed_result, test_success, bench_dtype, npu_dtype = compare_core(bench_out, npu_out, alg) bench_dtype_total = bench_dtype npu_dtype_total = npu_dtype - if name != "Max Relative Error": + if name != "Max Relative Error" and name != "Max Absolute Error": test_success_total = test_success_total and test_success if detailed_result_total: for i in range(len(detailed_result_total)):