diff --git a/onnxruntime/python/_pybind_state.py b/onnxruntime/python/_pybind_state.py index e4f8ec4e97..4f0a243899 100644 --- a/onnxruntime/python/_pybind_state.py +++ b/onnxruntime/python/_pybind_state.py @@ -5,9 +5,23 @@ import os import platform +import sys import warnings import onnxruntime.capi._ld_preload # noqa: F401 +# Python 3.8 (and later) on Windows doesn't search system PATH when loading DLLs, +# so CUDA location needs to be specified explicitly. +if platform.system() == "Windows" and sys.version_info >= (3, 8): + CUDA_VERSION = "10.2" + CUDNN_VERSION = "8" + cuda_env_variable = "CUDA_PATH_V" + CUDA_VERSION.replace(".", "_") + if cuda_env_variable not in os.environ: + raise ImportError(f"CUDA Toolkit {CUDA_VERSION} not installed on the machine.") + cuda_bin_dir = os.path.join(os.environ[cuda_env_variable], "bin") + if not os.path.isfile(os.path.join(cuda_bin_dir, f"cudnn64_{CUDNN_VERSION}.dll")): + raise ImportError(f"cuDNN {CUDNN_VERSION} not installed on the machine.") + os.add_dll_directory(cuda_bin_dir) + try: from onnxruntime.capi.onnxruntime_pybind11_state import * # noqa except ImportError as e: @@ -21,6 +35,6 @@ except ImportError as e: # TODO: Add a guard against False Positive error message # As a proxy for checking if the 2019 VC Runtime is installed, # we look for a specific dll only shipped with the 2019 VC Runtime - if platform.system().lower() == 'windows' and not os.path.isfile('c:\\Windows\\System32\\vcruntime140_1.dll'): + if platform.system() == "Windows" and not os.path.isfile("C:\\Windows\\System32\\vcruntime140_1.dll"): warnings.warn("Unless you have built the wheel using VS 2017, " - "please install the 2019 Visual C++ runtime and then try again") + "please install the 2019 Visual C++ runtime and then try again.")