From 175161806c70cbbde287e9c317d510378ba92ea6 Mon Sep 17 00:00:00 2001 From: h00613304 Date: Sat, 5 Aug 2023 17:04:39 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BF=AE=E5=A4=8Ddropout=E7=B1=BB=E8=BF=90?= =?UTF-8?q?=E7=AE=97=E7=9A=84=E6=AF=94=E5=AF=B9=E6=96=B9=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/compare/compare.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 5cb0777e4..ed3c50a0c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -93,7 +93,7 @@ class Comparator: def _compare_dropout(bench_out, npu_out): tensor_num = bench_out.numel() if tensor_num >= 100: - if abs((bench_out == 0).sum() - (npu_out == 0).sum()) / tensor_num < 0.1: + if abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum()) / tensor_num < 0.1: return True, 1 else: return False, 0 -- Gitee