mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Try creating a bf16 tensor as a last resort of is_bf16_supported(). (#115924)
Fix: #115900 https://github.com/pytorch/xla/issues/6085 This PR adds a last resort for testing for BF16 support on CUDA. This is necessary on GPUs such as RTX 2060, where `torch.cuda.is_bf16_supported()` returns False, but we can successfully create a BF16 tensor on CUDA. Before this PR: ```python >>> torch.cuda.is_bf16_supported() False >>> torch.tensor([1.], dtype=torch.bfloat16, device="cuda") tensor([...], device='cuda:0', dtype=torch.bfloat16) ``` After this PR: ```python >>> torch.cuda.is_bf16_supported() True >>> torch.tensor([1.], dtype=torch.bfloat16, device="cuda") tensor([...], device='cuda:0', dtype=torch.bfloat16) ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/115924 Approved by: https://github.com/jansel
This commit is contained in:
parent
127812efee
commit
fc5fda14bc
1 changed files with 23 additions and 9 deletions
|
|
@ -148,15 +148,29 @@ def is_bf16_supported():
|
|||
if torch.version.hip:
|
||||
return True
|
||||
|
||||
cu_vers = torch.version.cuda
|
||||
if cu_vers is not None:
|
||||
cuda_maj_decide = int(cu_vers.split(".")[0]) >= 11
|
||||
else:
|
||||
cuda_maj_decide = False
|
||||
return (
|
||||
torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 8
|
||||
and cuda_maj_decide
|
||||
)
|
||||
device = torch.cuda.current_device()
|
||||
|
||||
# Check for CUDA version and device compute capability.
|
||||
# This is a fast way to check for it.
|
||||
cuda_version = torch.version.cuda
|
||||
if (
|
||||
cuda_version is not None
|
||||
and int(cuda_version.split(".")[0]) >= 11
|
||||
and torch.cuda.get_device_properties(device).major >= 8
|
||||
):
|
||||
return True
|
||||
|
||||
# Finally, try to create a bfloat16 tensor on the device.
|
||||
return _check_bf16_tensor_supported(device)
|
||||
|
||||
|
||||
@lru_cache(maxsize=16)
|
||||
def _check_bf16_tensor_supported(device: _device_t):
|
||||
try:
|
||||
torch.tensor([1.0], dtype=torch.bfloat16, device=device)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _sleep(cycles):
|
||||
|
|
|
|||
Loading…
Reference in a new issue