Add support for Python 3.8+ on Windows when CUDA is enabled (#5956)

This commit is contained in:
Ivan Stojiljkovic 2020-11-27 00:52:30 +01:00 committed by GitHub
parent e207589631
commit 015fbb3dbb
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -5,9 +5,23 @@
import os
import platform
import sys
import warnings
import onnxruntime.capi._ld_preload # noqa: F401
# Python 3.8 (and later) on Windows doesn't search system PATH when loading DLLs,
# so CUDA location needs to be specified explicitly.
if platform.system() == "Windows" and sys.version_info >= (3, 8):
CUDA_VERSION = "10.2"
CUDNN_VERSION = "8"
cuda_env_variable = "CUDA_PATH_V" + CUDA_VERSION.replace(".", "_")
if cuda_env_variable not in os.environ:
raise ImportError(f"CUDA Toolkit {CUDA_VERSION} not installed on the machine.")
cuda_bin_dir = os.path.join(os.environ[cuda_env_variable], "bin")
if not os.path.isfile(os.path.join(cuda_bin_dir, f"cudnn64_{CUDNN_VERSION}.dll")):
raise ImportError(f"cuDNN {CUDNN_VERSION} not installed on the machine.")
os.add_dll_directory(cuda_bin_dir)
try:
from onnxruntime.capi.onnxruntime_pybind11_state import * # noqa
except ImportError as e:
@ -21,6 +35,6 @@ except ImportError as e:
# TODO: Add a guard against False Positive error message
# As a proxy for checking if the 2019 VC Runtime is installed,
# we look for a specific dll only shipped with the 2019 VC Runtime
if platform.system().lower() == 'windows' and not os.path.isfile('c:\\Windows\\System32\\vcruntime140_1.dll'):
if platform.system() == "Windows" and not os.path.isfile("C:\\Windows\\System32\\vcruntime140_1.dll"):
warnings.warn("Unless you have built the wheel using VS 2017, "
"please install the 2019 Visual C++ runtime and then try again")
"please install the 2019 Visual C++ runtime and then try again.")