From fe3ed5374ff7bb45fcecb52ebeeeb4653f47dd9c Mon Sep 17 00:00:00 2001 From: h00613304 Date: Tue, 29 Aug 2023 20:15:46 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7=E6=94=AF?= =?UTF-8?q?=E6=8C=81bf16?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 29b10608e..7caf8a7b5 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -168,9 +168,12 @@ def compare_core(bench_out, npu_out, alg): CompareConst.NAN, CompareConst.NAN compare_result, test_success, bench_dtype, npu_dtype = compare_core(list(bench_out.values()), list(npu_out.values()), alg) elif isinstance(bench_out, torch.Tensor): - compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) bench_dtype = str(bench_out.dtype) npu_dtype = str(npu_out.dtype) + if bench_out.dtype == torch.bfloat16: + bench_out = bench_out.to(torch.float32) + npu_out = npu_out.to(torch.float32) + compare_result, test_success, msg = compare_torch_tensor(bench_out.detach().numpy(), npu_out.detach().cpu().numpy(), alg) elif isinstance(bench_out, (bool, int, float, str)): compare_result, test_success, msg = compare_builtin_type(bench_out, npu_out) bench_dtype = str(type(bench_out)) -- Gitee