From 32db636d15fb58a8443f94e52981afdb0b620799 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 18 Sep 2023 10:17:39 +0800 Subject: [PATCH 1/5] max_rel_err --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 5 ++++- debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 2 +- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 9b6b853073..53181614da 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -54,7 +54,10 @@ def get_msg_and_handle_value(n_value, b_value): def get_max_rel_err(n_value, b_value): n_value, b_value, msg = get_msg_and_handle_value(n_value, b_value) rel_err = np.abs((n_value - b_value) / b_value).max() - bool_result = rel_err < 0.001 + if b_value.dtype == np.float16: + bool_result = rel_err < 0.001 + elif b_value.dtype == np.float32: + bool_result = rel_err < 0.0001 return rel_err, bool_result, msg def get_max_abs_err(n_value, b_value): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index a584405c00..c671716c4b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -143,7 +143,7 @@ class Comparator: bench_dtype_total = bench_dtype npu_dtype_total = npu_dtype shape_total = shape - if name != "Max Relative Error" and name != "Max Absolute Error": + if name != "Max Relative Error": test_success_total = test_success_total and test_success if detailed_result_total: for i in range(len(detailed_result_total)): -- Gitee From bf62f73a110acfc9ab5ba594d270c87f59ae3b43 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 18 Sep 2023 11:48:46 +0800 Subject: [PATCH 2/5] update --- .../accuracy_tools/api_accuracy_checker/compare/compare.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index c671716c4b..59dee51001 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -143,13 +143,17 @@ class Comparator: bench_dtype_total = bench_dtype npu_dtype_total = npu_dtype shape_total = shape - if name != "Max Relative Error": + if name not in ["Max Relative Error", "Max Absolute Error"]: test_success_total = test_success_total and test_success if detailed_result_total: for i in range(len(detailed_result_total)): detailed_result_total[i] += detailed_result[i] else: detailed_result_total = detailed_result + for name in self.compare_alg.keys(): + alg = self.compare_alg[name][0] + if name == "Max Absolute Error": + test_success_total = test_success_total or test_success # dtype加到所有指标的前面, 是否pass放到所有指标的后面 for i in range(len(detailed_result_total)): detailed_result = list(detailed_result_total[i]) -- Gitee From 95a57aaefae2c131dfc665beb7285c11e2102bd5 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 18 Sep 2023 14:24:50 +0800 Subject: [PATCH 3/5] update --- .../api_accuracy_checker/compare/algorithm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 53181614da..8a7c3cabf3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -54,10 +54,10 @@ def get_msg_and_handle_value(n_value, b_value): def get_max_rel_err(n_value, b_value): n_value, b_value, msg = get_msg_and_handle_value(n_value, b_value) rel_err = np.abs((n_value - b_value) / b_value).max() - if b_value.dtype == np.float16: - bool_result = rel_err < 0.001 - elif b_value.dtype == np.float32: + if b_value.dtype == np.float32: bool_result = rel_err < 0.0001 + else: + bool_result = rel_err < 0.001 return rel_err, bool_result, msg def get_max_abs_err(n_value, b_value): -- Gitee From 734cf6bf9b1a702c793b5f45b80bad17dd5e496f Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 18 Sep 2023 14:32:49 +0800 Subject: [PATCH 4/5] fix name --- .../api_accuracy_checker/compare/algorithm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 8a7c3cabf3..f6a1e79249 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -51,10 +51,10 @@ def get_msg_and_handle_value(n_value, b_value): return n_value, b_value, msg -def get_max_rel_err(n_value, b_value): - n_value, b_value, msg = get_msg_and_handle_value(n_value, b_value) - rel_err = np.abs((n_value - b_value) / b_value).max() - if b_value.dtype == np.float32: +def get_max_rel_err(b_value, n_value): + b_value, n_value, msg = get_msg_and_handle_value(b_value, n_value) + rel_err = np.abs((b_value - n_value) / n_value).max() + if n_value.dtype == np.float32: bool_result = rel_err < 0.0001 else: bool_result = rel_err < 0.001 -- Gitee From c8da53d4afb9a12587740a10acffc4d6b1da86ce Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 18 Sep 2023 14:46:21 +0800 Subject: [PATCH 5/5] fix --- debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index f6a1e79249..cff0070cb6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -53,7 +53,7 @@ def get_msg_and_handle_value(n_value, b_value): def get_max_rel_err(b_value, n_value): b_value, n_value, msg = get_msg_and_handle_value(b_value, n_value) - rel_err = np.abs((b_value - n_value) / n_value).max() + rel_err = np.abs((n_value - b_value) / b_value).max() if n_value.dtype == np.float32: bool_result = rel_err < 0.0001 else: -- Gitee