diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index c01fa300ae0a66747b4aa833cf8482427f9fd795..71df553503b275de16839327c8867ed4771c689b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -174,12 +174,13 @@ def compare_uint8_data(b_value, n_value): return 0, False -def compare_builtin_type(bench_out, npu_out): +def compare_builtin_type(bench_out, npu_out, compare_column): if not isinstance(bench_out, (bool, int, float, str)): - return CompareConst.NA, CompareConst.PASS, "" + return CompareConst.PASS, compare_column, "" if bench_out != npu_out: - return CompareConst.NA, CompareConst.ERROR, "" - return True, CompareConst.PASS, "" + return CompareConst.ERROR, compare_column, "" + compare_column.error_rate = 0 + return CompareConst.PASS, compare_column, "" def flatten_compare_result(result): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py index ba35b4670b56d8a299767d2ae04974924f30b90b..12f954e6156d6c8224ef90b1f57e9608e3ef7b36 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/compare/test_algorithm.py @@ -66,9 +66,11 @@ class TestAlgorithmMethods(unittest.TestCase): self.assertEqual(alg.compare_uint8_data(b_value, n_value), (1, True)) def test_compare_builtin_type(self): + compare_column = CompareColumn() bench_out = 1 npu_out = 1 - self.assertEqual(alg.compare_builtin_type(bench_out, npu_out), (True, 'pass', '')) + status, compare_result, message = alg.compare_builtin_type(bench_out, npu_out, compare_column) + self.assertEqual((status, compare_result.error_rate, message), ('pass', 0, '')) def test_flatten_compare_result(self): result = [[1, 2], [3, 4]]