From 82c8c4260fb5d989ccfd63e17b80edc447eacaf2 Mon Sep 17 00:00:00 2001 From: pengxiaopeng Date: Wed, 6 Dec 2023 09:37:36 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90Bugfix=E3=80=91=E8=A7=A3=E5=86=B3?= =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7=E8=A7=A3=E6=9E=90=20bool?= =?UTF-8?q?=20=E7=B1=BB=E5=9E=8B=E6=95=B0=E6=8D=AE=E6=97=B6=E6=8A=A5?= =?UTF-8?q?=E9=94=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/compare/algorithm.py | 9 +++++---- .../test/ut/compare/test_algorithm.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index c01fa300a..71df55350 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -174,12 +174,13 @@ def compare_uint8_data(b_value, n_value): return 0, False -def compare_builtin_type(bench_out, npu_out): +def compare_builtin_type(bench_out, npu_out, compare_column): if not isinstance(bench_out, (bool, int, float, str)): - return CompareConst.NA, CompareConst.PASS, "" + return CompareConst.PASS, compare_column, "" if bench_out != npu_out: - return CompareConst.NA, CompareConst.ERROR, "" - return True, CompareConst.PASS, "" + return CompareConst.ERROR, compare_column, "" + compare_column.error_rate = 0 + return CompareConst.PASS, compare_column, "" def flatten_compare_result(result): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py index ba35b4670..12f954e61 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py @@ -66,9 +66,11 @@ class TestAlgorithmMethods(unittest.TestCase): self.assertEqual(alg.compare_uint8_data(b_value, n_value), (1, True)) def test_compare_builtin_type(self): + compare_column = CompareColumn() bench_out = 1 npu_out = 1 - self.assertEqual(alg.compare_builtin_type(bench_out, npu_out), (True, 'pass', '')) + status, compare_result, message = alg.compare_builtin_type(bench_out, npu_out, compare_column) + self.assertEqual((status, compare_result.error_rate, message), ('pass', 0, '')) def test_flatten_compare_result(self): result = [[1, 2], [3, 4]] -- Gitee