Try creating a bf16 tensor as a last resort in is_bf16_supported() (#115924)

Fixes: #115900, https://github.com/pytorch/xla/issues/6085

This PR adds a last resort for testing BF16 support on CUDA. This is necessary on GPUs
such as the RTX 2060, where `torch.cuda.is_bf16_supported()` returns `False` even though a
BF16 tensor can be created on CUDA successfully.
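For context, the existing fast path keys entirely off the CUDA version and the device's compute capability, and an RTX 2060 (Turing, compute capability 7.5) fails its `major >= 8` test even though BF16 tensor creation works there. A minimal sketch of that fast check in isolation, assuming a CUDA build of PyTorch and device index 0:

```python
import torch

# Inspect the two signals the fast check relies on
# (assumes a CUDA build of PyTorch with at least one visible device).
props = torch.cuda.get_device_properties(0)
print(torch.version.cuda)        # e.g. "11.8"
print(props.major, props.minor)  # e.g. 7 5 on an RTX 2060 (Turing)

# Fast check: CUDA major version >= 11 AND compute capability >= 8.0.
# A 7.5 device fails it even when BF16 tensor creation succeeds.
fast_ok = int(torch.version.cuda.split(".")[0]) >= 11 and props.major >= 8
print(fast_ok)                   # False on an RTX 2060
```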

Before this PR:

```python
>>> torch.cuda.is_bf16_supported()
False
>>> torch.tensor([1.], dtype=torch.bfloat16, device="cuda")
tensor([...], device='cuda:0', dtype=torch.bfloat16)
```

After this PR:

```python
>>> torch.cuda.is_bf16_supported()
True
>>> torch.tensor([1.], dtype=torch.bfloat16, device="cuda")
tensor([...], device='cuda:0', dtype=torch.bfloat16)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115924
Approved by: https://github.com/jansel
Authored by Yukio Siraichi on 2023-12-15 10:22:02 -03:00; committed by PyTorch MergeBot
commit fc5fda14bc (parent 127812efee)


The relevant change (in `torch/cuda/__init__.py`):

```diff
@@ -148,15 +148,29 @@ def is_bf16_supported():
     if torch.version.hip:
         return True
 
-    cu_vers = torch.version.cuda
-    if cu_vers is not None:
-        cuda_maj_decide = int(cu_vers.split(".")[0]) >= 11
-    else:
-        cuda_maj_decide = False
-    return (
-        torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8
-        and cuda_maj_decide
-    )
+    device = torch.cuda.current_device()
+
+    # Check for CUDA version and device compute capability.
+    # This is a fast way to check for it.
+    cuda_version = torch.version.cuda
+    if (
+        cuda_version is not None
+        and int(cuda_version.split(".")[0]) >= 11
+        and torch.cuda.get_device_properties(device).major >= 8
+    ):
+        return True
+
+    # Finally try to create a bfloat16 device.
+    return _check_bf16_tensor_supported(device)
+
+
+@lru_cache(maxsize=16)
+def _check_bf16_tensor_supported(device: _device_t):
+    try:
+        torch.tensor([1.0], dtype=torch.bfloat16, device=device)
+        return True
+    except Exception:
+        return False
 
 
 def _sleep(cycles):
```
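Two details of the fallback are worth noting: the probe is wrapped in `lru_cache`, so the trial allocation runs at most once per device, and it catches the broad `Exception`, so any failure to allocate, whatever its cause, is simply reported as lack of support. A self-contained sketch of the same probe-and-memoize pattern (the `bf16_probe` name is illustrative, not PyTorch API):

```python
from functools import lru_cache

import torch


@lru_cache(maxsize=16)
def bf16_probe(device_index: int) -> bool:
    """Return True if a tiny bfloat16 tensor can be allocated on the device.

    The result is memoized, so the allocation happens at most once per index.
    """
    try:
        torch.tensor([1.0], dtype=torch.bfloat16, device=f"cuda:{device_index}")
        return True
    except Exception:
        return False


if torch.cuda.is_available():
    idx = torch.cuda.current_device()
    print(bf16_probe(idx))  # first call runs the allocation probe
    print(bf16_probe(idx))  # second call is answered from the cache
```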