From 015fbb3dbb8547294171990e1104cdf101be4a7d Mon Sep 17 00:00:00 2001 From: Ivan Stojiljkovic <17503404+ivanst0@users.noreply.github.com> Date: Fri, 27 Nov 2020 00:52:30 +0100 Subject: [PATCH] Add support for Python 3.8+ on Windows when CUDA is enabled (#5956) --- onnxruntime/python/_pybind_state.py | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/onnxruntime/python/_pybind_state.py b/onnxruntime/python/_pybind_state.py index e4f8ec4e97..4f0a243899 100644 --- a/onnxruntime/python/_pybind_state.py +++ b/onnxruntime/python/_pybind_state.py @@ -5,9 +5,23 @@ import os import platform +import sys import warnings import onnxruntime.capi._ld_preload # noqa: F401 +# Python 3.8 (and later) on Windows doesn't search system PATH when loading DLLs, +# so CUDA location needs to be specified explicitly. +if platform.system() == "Windows" and sys.version_info >= (3, 8): + CUDA_VERSION = "10.2" + CUDNN_VERSION = "8" + cuda_env_variable = "CUDA_PATH_V" + CUDA_VERSION.replace(".", "_") + if cuda_env_variable not in os.environ: + raise ImportError(f"CUDA Toolkit {CUDA_VERSION} not installed on the machine.") + cuda_bin_dir = os.path.join(os.environ[cuda_env_variable], "bin") + if not os.path.isfile(os.path.join(cuda_bin_dir, f"cudnn64_{CUDNN_VERSION}.dll")): + raise ImportError(f"cuDNN {CUDNN_VERSION} not installed on the machine.") + os.add_dll_directory(cuda_bin_dir) + try: from onnxruntime.capi.onnxruntime_pybind11_state import * # noqa except ImportError as e: @@ -21,6 +35,6 @@ except ImportError as e: # TODO: Add a guard against False Positive error message # As a proxy for checking if the 2019 VC Runtime is installed, # we look for a specific dll only shipped with the 2019 VC Runtime - if platform.system().lower() == 'windows' and not os.path.isfile('c:\\Windows\\System32\\vcruntime140_1.dll'): + if platform.system() == "Windows" and not os.path.isfile("C:\\Windows\\System32\\vcruntime140_1.dll"): warnings.warn("Unless you have built the wheel using VS 2017, " - "please install the 2019 Visual C++ runtime and then try again") + "please install the 2019 Visual C++ runtime and then try again.")