mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[NVIDIA] RTX50 Blackwell Support codegen (#145270)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145270
Approved by: https://github.com/ezyang
This commit is contained in:
parent
895659cb41
commit
35f5668f7e
2 changed files with 6 additions and 4 deletions
|
|
@@ -82,7 +82,9 @@ endif()
|
|||
if(CUDA_VERSION VERSION_GREATER "12.6")
|
||||
list(APPEND CUDA_KNOWN_GPU_ARCHITECTURES "Blackwell")
|
||||
list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "10.0")
|
||||
+list(APPEND CUDA_COMMON_GPU_ARCHITECTURES "12.0")
|
||||
list(APPEND CUDA_ALL_GPU_ARCHITECTURES "10.0")
|
||||
+list(APPEND CUDA_ALL_GPU_ARCHITECTURES "12.0")
|
||||
endif()
|
||||
|
||||
|
||||
|
|
@@ -231,8 +233,8 @@ function(CUDA_SELECT_NVCC_ARCH_FLAGS out_variable)
|
|||
set(arch_bin 9.0)
|
||||
set(arch_ptx 9.0)
|
||||
elseif(${arch_name} STREQUAL "Blackwell")
|
||||
-set(arch_bin 10.0)
|
||||
-set(arch_ptx 10.0)
|
||||
+set(arch_bin 10.0 12.0)
|
||||
+set(arch_ptx 10.0 12.0)
|
||||
else()
|
||||
message(SEND_ERROR "Found Unknown CUDA Architecture Name in CUDA_SELECT_NVCC_ARCH_FLAGS: ${arch_name} ")
|
||||
endif()
|
||||
|
|
|
|||
|
|
@@ -2050,12 +2050,12 @@ def _get_cuda_arch_flags(cflags: Optional[list[str]] = None) -> list[str]:
|
|||
('Ampere', '8.0;8.6+PTX'),
|
||||
('Ada', '8.9+PTX'),
|
||||
('Hopper', '9.0+PTX'),
|
||||
-('Blackwell', '10.0+PTX'),
|
||||
+('Blackwell', '10.0;12.0+PTX'),
|
||||
])
|
||||
|
||||
supported_arches = ['3.5', '3.7', '5.0', '5.2', '5.3', '6.0', '6.1', '6.2',
|
||||
'7.0', '7.2', '7.5', '8.0', '8.6', '8.7', '8.9', '9.0', '9.0a',
|
||||
-'10.0']
|
||||
+'10.0', '12.0']
|
||||
valid_arch_strings = supported_arches + [s + "+PTX" for s in supported_arches]
|
||||
|
||||
# The default is sm_30 for CUDA 9.x and 10.x
|
||||
|
|
|
|||
Loading…
Reference in a new issue