From 42547f8d48b0ebef8719da7cde9fce4e4fbe2c3e Mon Sep 17 00:00:00 2001 From: drisspg Date: Tue, 3 Dec 2024 10:56:39 -0500 Subject: [PATCH] Add support for blackwell codegen (#141724) Pull Request resolved: https://github.com/pytorch/pytorch/pull/141724 Approved by: https://github.com/Skylion007, https://github.com/malfet, https://github.com/eqy --- .../upstream/FindCUDA/select_compute_arch.cmake | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index 90de8fb0d84..14ca7ee302d 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -79,6 +79,13 @@ if(NOT CUDA_VERSION VERSION_LESS "12.0") list(REMOVE_ITEM CUDA_ALL_GPU_ARCHITECTURES "3.5") endif() +if(CUDA_VERSION VERSION_GREATER "12.6") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0") +endif() + + ################################################################################################ # A function for automatic detection of GPUs installed (if autodetection is enabled) # Usage: @@ -186,7 +193,7 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) set(add_ptx TRUE) set(arch_name ${CMAKE_MATCH_1}) endif() - if(arch_name MATCHES "^([0-9]\\.[0-9]a?(\\([0-9]\\.[0-9]\\))?)$") + if(arch_name MATCHES "^([0-9]+\\.[0-9]a?(\\([0-9]+\\.[0-9]\\))?)$") set(arch_bin ${CMAKE_MATCH_1}) set(arch_ptx ${arch_bin}) else() @@ -223,8 +230,11 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) elseif(${arch_name} STREQUAL "Hopper") set(arch_bin 9.0) set(arch_ptx 9.0) + elseif(${arch_name} STREQUAL "Blackwell") + set(arch_bin 10.0) + set(arch_ptx 10.0) else() - message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") + message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ") endif() endif() if(NOT arch_bin)