Add support for blackwell codegen (#141724)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141724
Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/eqy
This commit is contained in:
drisspg 2024-12-03 10:56:39 -05:00 committed by PyTorch MergeBot
parent 8b0fcad0fd
commit 42547f8d48

View file

@ -79,6 +79,13 @@ if(NOT CUDA_VERSION VERSION_LESS "12.0")
list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5")
endif()
if(CUDA_VERSION VERSION_GREATER "12.6")
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell")
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0")
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0")
endif()
################################################################################################
# A function for automatic detection of GPUs installed (if autodetection is enabled)
# Usage:
@ -186,7 +193,7 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
set(add_ptx TRUE)
set(arch_name ${CMAKE_MATCH_1})
endif()
if(arch_name MATCHES "^([0-9]\\.[0-9]a?(\\([0-9]\\.[0-9]\\))?)$")
if(arch_name MATCHES "^([0-9]+\\.[0-9]a?(\\([0-9]+\\.[0-9]\\))?)$")
set(arch_bin ${CMAKE_MATCH_1})
set(arch_ptx ${arch_bin})
else()
@ -223,8 +230,11 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
elseif(${arch_name} STREQUAL "Hopper")
set(arch_bin 9.0)
set(arch_ptx 9.0)
elseif(${arch_name} STREQUAL "Blackwell")
set(arch_bin 10.0)
set(arch_ptx 10.0)
else()
message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS")
message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ")
endif()
endif()
if(NOT arch_bin)