diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 9b6b85307345beaaff678caafc6d4f77f4690fd2..cff0070cb668f498bbdd46509206447db46e9e61 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -51,10 +51,13 @@ def get_msg_and_handle_value(n_value, b_value): return n_value, b_value, msg -def get_max_rel_err(n_value, b_value): - n_value, b_value, msg = get_msg_and_handle_value(n_value, b_value) +def get_max_rel_err(b_value, n_value): + b_value, n_value, msg = get_msg_and_handle_value(b_value, n_value) rel_err = np.abs((n_value - b_value) / b_value).max() - bool_result = rel_err < 0.001 + if n_value.dtype == np.float32: + bool_result = rel_err < 0.0001 + else: + bool_result = rel_err < 0.001 return rel_err, bool_result, msg def get_max_abs_err(n_value, b_value): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index a584405c007ec02ca56a92d8cc008eea35dd29db..59dee51001f68329f2163c76b4afcb93433b38c3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -143,13 +143,17 @@ class Comparator: bench_dtype_total = bench_dtype npu_dtype_total = npu_dtype shape_total = shape - if name != "Max Relative Error" and name != "Max Absolute Error": + if name not in ["Max Relative Error", "Max Absolute Error"]: test_success_total = test_success_total and test_success if detailed_result_total: for i in range(len(detailed_result_total)): detailed_result_total[i] += detailed_result[i] else: detailed_result_total = detailed_result + for name in self.compare_alg.keys(): + alg = self.compare_alg[name][0] + if name == "Max Absolute Error": + test_success_total = test_success_total or test_success # dtype加到所有指标的前面, 是否pass放到所有指标的后面 for i in range(len(detailed_result_total)): detailed_result = list(detailed_result_total[i])