[NVIDIA] RTX50 Blackwell Support codegen (#145270)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145270
Approved by: https://github.com/ezyang
This commit is contained in:
johnnynunez 2025-01-21 21:10:02 +00:00 committed by PyTorch MergeBot
parent 895659cb41
commit 35f5668f7e
2 changed files with 6 additions and 4 deletions

View file

@ -82,7 +82,9 @@ endif()
if(CUDA_VERSION VERSION_GREATER "12.6")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0")
endif()
@ -231,8 +233,8 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
set(arch_bin 9.0)
set(arch_ptx 9.0)
elseif(${arch_name} STREQUAL "Blackwell")
set(arch_bin 10.0)
set(arch_ptx 10.0)
set(arch_bin 10.0 12.0)
set(arch_ptx 10.0 12.0)
else()
message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ")
endif()

View file

@ -2050,12 +2050,12 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]:
('Ampere', '8.0;8.6+PTX'),
('Ada', '8.9+PTX'),
('Hopper', '9.0+PTX'),
('Blackwell', '10.0+PTX'),
('Blackwell', '10.0;12.0+PTX'),
])
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a',
'10.0']
'10.0', '12.0']
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
# The default is sm_30 for CUDA 9.x and 10.x