diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 5cb0777e43f0d51d2c3864b55ee887e927750f13..ed3c50a0cd352590cea358c3c629cd3c3fedb29e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -93,7 +93,7 @@ class Comparator: def _compare_dropout(bench_out, npu_out): tensor_num = bench_out.numel() if tensor_num >= 100: - if abs((bench_out == 0).sum() - (npu_out == 0).sum()) / tensor_num < 0.1: + if abs((bench_out == 0).sum() - (npu_out == 0).cpu().sum()) / tensor_num < 0.1: return True, 1 else: return False, 0