diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 627bd769766863683d2191509a54f5e28660a623..f80b97cf1c5edd38540a2edf4ee17ed7c3bf4dc3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -118,6 +118,8 @@ class BaseAPIInfo: return True in data elif operator == 'min': return False not in data + if data.dtype is torch.bfloat16: + data = data.to(torch.float32) if operator == 'max': return torch._C._VariableFunctionsClass.max(data).item() else: