diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 29b10608e496a7808a13dcb8656521e6bcc8cba8..7caf8a7b5224b0194de7b50c8ecb095a1ffb3fc1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -168,9 +168,12 @@ def compare_core(bench_out, npu_out, alg): CompareConst.NAN, CompareConst.NAN compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): - compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) bench_dtype = str(bench_out.dtype) npu_dtype = str(npu_out.dtype) + if bench_out.dtype == torch.bfloat16: + bench_out = bench_out.to(torch.float32) + npu_out = npu_out.to(torch.float32) + compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) elif isinstance(bench_out, (bool, int, float, str)): compare_result, test_success, msg = compare_builtin_type(bench_out, npu_out) bench_dtype = str(type(bench_out))