diff --git a/setup.py b/setup.py index 30909ad3f51..49dbf9d6951 100644 --- a/setup.py +++ b/setup.py @@ -404,8 +404,9 @@ def checkout_nccl(): report(f'-- Checkout nccl: {get_cmake_cache_vars()["USE_CUDA"]}') #if get_cmake_cache_vars()["USE_CUDA"]: cuda_version = os.getenv("DESIRED_CUDA", "") + cuda_version_2 = os.getenv("CUDA_VERSION", "") commit_hash = "80f6bda4378b99d99e82b4d76a633791cc45fef0" - if cuda_version.startswith("11.8"): + if cuda_version.startswith("11.8") or cuda_version_2.startswith("11.8"): commit_hash = "ab2b89c4c339bd7f816fbc114a4b05d386b66290" nccl_basedir = os.path.join(cwd, "nccl") report(f"-- Calling nccl checkout: {commit_hash}")