From a7420d2ccb62d005f2e1853cfef8d25eb7748a90 Mon Sep 17 00:00:00 2001 From: Eddie Yan Date: Wed, 9 Nov 2022 01:49:50 +0000 Subject: [PATCH] Hopper (`sm90`) support (#87736) Essentially a followup of #87436 CC @xwang233 @ptrblck Pull Request resolved: https://github.com/pytorch/pytorch/pull/87736 Approved by: https://github.com/xwang233, https://github.com/malfet --- .../upstream/FindCUDA/select_compute_arch.cmake | 13 ++++++++++++- torch/utils/cpp_extension.py | 3 ++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake index 822c041ee52..65e7a6ac899 100644 --- a/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake +++ b/cmake/Modules_CUDA_fix/upstream/FindCUDA/select_compute_arch.cmake @@ -94,23 +94,28 @@ if(CUDA_VERSION VERSION_GREATER "10.5") endif() if(NOT CUDA_VERSION VERSION_LESS "11.1") - list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6" "8.6+PTX") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.6") set(CUDA_LIMIT_GPU_ARCHITECUTRE "8.6") if(CUDA_VERSION VERSION_LESS "11.8") set(CUDA_LIMIT_GPU_ARCHITECTURE "8.9") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.6+PTX") endif() endif() if(NOT CUDA_VERSION VERSION_LESS "11.8") list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Ada") + list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Hopper") list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0") list(APPEND CUDA_ALL_GPU_ARCHITECTURES "8.9") + list(APPEND CUDA_ALL_GPU_ARCHITECTURES "9.0") if(CUDA_VERSION VERSION_LESS "12.0") set(CUDA_LIMIT_GPU_ARCHITECTURE "9.0") list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "8.9+PTX") + list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "9.0+PTX") endif() endif() @@ -248,6 +253,12 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable) elseif(${arch_name} STREQUAL "Ampere") set(arch_bin 8.0) set(arch_ptx 8.0) + elseif(${arch_name} STREQUAL "Ada") + set(arch_bin 8.9) + set(arch_ptx 8.9) + elseif(${arch_name} STREQUAL "Hopper") + set(arch_bin 9.0) + set(arch_ptx 9.0) else() message(SEND_ERROR "Unknown CUDA Architecture Name ${arch_name} in CUDA_SELECT_NVCC_ARCH_FLAGS") endif() diff --git a/torch/utils/cpp_extension.py b/torch/utils/cpp_extension.py index 612ae9fdf07..aa03da23b38 100644 --- a/torch/utils/cpp_extension.py +++ b/torch/utils/cpp_extension.py @@ -1730,10 +1730,11 @@ def _get_cuda_arch_flags(cflags: Optional[List[str]] = None) -> List[str]: ('Turing', '7.5+PTX'), ('Ampere', '8.0;8.6+PTX'), ('Ada', '8.9+PTX'), + ('Hopper', '9.0+PTX'), ]) supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2', - '7.0', '7.2', '7.5', '8.0', '8.6', '8.9'] + '7.0', '7.2', '7.5', '8.0', '8.6', '8.9', '9.0'] valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches] # The default is sm_30 for CUDA 9.x and 10.x