From a2ccda08920ac152be26ec2c74cdbbaf20497fe4 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Tue, 24 Oct 2023 02:21:52 +0000 Subject: [PATCH] update debug/accuracy_tools/api_accuracy_checker/common/base_api.py. fix bfloat16 item() error Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/common/base_api.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 627bd7697..f80b97cf1 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -118,6 +118,8 @@ class BaseAPIInfo: return True in data elif operator == 'min': return False not in data + if data.dtype is torch.bfloat16: + data = data.to(torch.float32) if operator == 'max': return torch._C._VariableFunctionsClass.max(data).item() else: -- Gitee