From a5dae7c65a683f00a816657138948e3f2546168e Mon Sep 17 00:00:00 2001 From: liguanchi Date: Fri, 14 Jun 2024 16:54:00 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BA=86warning=5Fflag?= =?UTF-8?q?=E7=9A=84=E5=88=A4=E6=96=AD=E5=85=AC=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/atat/pytorch/compare/acc_compare.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index b6a06ede3a..f74df65c48 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -335,14 +335,13 @@ def get_accuracy(result, n_dict, b_dict, summary_compare=False, md5_compare=Fals if isinstance(npu_val, (float, int)) and isinstance(bench_val, (float, int)): diff = npu_val - bench_val if bench_val != 0: - relative = str(abs((diff/bench_val) * 100)) + '%' + if abs(diff/bench_val) > 0.5: + warning_flag = True + relative = str(abs(diff/bench_val) * 100) + '%' else: relative = "N/A" result_item[start_idx + i] = diff result_item[start_idx + i + 4] = relative - magnitude_diff = abs(diff) / (max(abs(npu_val), abs(bench_val)) + 1e-10) - if magnitude_diff > 0.5: - warning_flag = True else: result_item[start_idx + i] = CompareConst.NONE accuracy_check = CompareConst.WARNING if warning_flag else "" -- Gitee