From 14eacfccfdf45cf4f967e6e8ed320a5523cf46a7 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Wed, 20 Sep 2023 15:11:05 +0800 Subject: [PATCH] bug fix --- .../api_accuracy_checker/compare/compare.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 59dee51001f..3461a64e5b8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -137,6 +137,7 @@ class Comparator: npu_dtype_total = [] shape_total = [] test_success_total = True + max_abs_error_success = False for name in self.compare_alg.keys(): alg = self.compare_alg[name][0] detailed_result, test_success, bench_dtype, npu_dtype, shape = compare_core(bench_out, npu_out, alg) @@ -145,15 +146,14 @@ class Comparator: shape_total = shape if name not in ["Max Relative Error", "Max Absolute Error"]: test_success_total = test_success_total and test_success + if name == "Max Absolute Error": + max_abs_error_success = test_success if detailed_result_total: for i in range(len(detailed_result_total)): detailed_result_total[i] += detailed_result[i] else: detailed_result_total = detailed_result - for name in self.compare_alg.keys(): - alg = self.compare_alg[name][0] - if name == "Max Absolute Error": - test_success_total = test_success_total or test_success + test_success_total = test_success_total or max_abs_error_success # dtype加到所有指标的前面, 是否pass放到所有指标的后面 for i in range(len(detailed_result_total)): detailed_result = list(detailed_result_total[i]) -- Gitee