diff --git a/onnxruntime/python/_pybind_state.py b/onnxruntime/python/_pybind_state.py index 2e3d2c8a58..e76d402681 100644 --- a/onnxruntime/python/_pybind_state.py +++ b/onnxruntime/python/_pybind_state.py @@ -33,8 +33,16 @@ if platform.system() == "Windows": raise ImportError(f"CUDA Toolkit {cuda_version_major}.x not installed on the machine.") cuda_bin_dir = os.path.join(os.environ[cuda_env_variable], "bin") - if not os.path.isfile(os.path.join(cuda_bin_dir, f"cudnn64_{version_info.cudnn_version}.dll")): - raise ImportError(f"cuDNN {version_info.cudnn_version} not installed in {cuda_bin_dir}.") + + # prefer CUDNN_HOME if set. fallback to the CUDA install directory (would have required user to manually + # copy the cudnn dll there + cudnn_path = os.environ["CUDNN_HOME"] if "CUDNN_HOME" in os.environ else os.environ[cuda_env_variable] + cudnn_bin_dir = os.path.join(cudnn_path, "bin") + + if not os.path.isfile(os.path.join(cudnn_bin_dir, f"cudnn64_{version_info.cudnn_version}.dll")): + raise ImportError(f"cuDNN {version_info.cudnn_version} not installed in {cudnn_bin_dir}. " + f"Set the CUDNN_HOME environment variable to the path of the 'cuda' directory " + f"in your CUDNN installation if necessary.") if sys.version_info >= (3, 8): # Python 3.8 (and later) doesn't search system PATH when loading DLLs, so the CUDA location needs to be