Fix path used in check for cudnn library (#7786)

* There are separate paths for CUDA and CUDNN as they are not guaranteed to be in the same location on a Windows machine. Use the CUDNN path when looking for the CUDNN library.

* Refine check
This commit is contained in:
Scott McKay 2021-05-28 09:32:13 +10:00 committed by GitHub
parent ddf4aaaae1
commit 63df683040
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -33,8 +33,16 @@ if platform.system() == "Windows":
raise ImportError(f"CUDA Toolkit {cuda_version_major}.x not installed on the machine.")
cuda_bin_dir = os.path.join(os.environ[cuda_env_variable], "bin")
if not os.path.isfile(os.path.join(cuda_bin_dir, f"cudnn64_{version_info.cudnn_version}.dll")):
raise ImportError(f"cuDNN {version_info.cudnn_version} not installed in {cuda_bin_dir}.")
# prefer CUDNN_HOME if set. fallback to the CUDA install directory (would have required user to manually
# copy the cudnn dll there
cudnn_path = os.environ["CUDNN_HOME"] if "CUDNN_HOME" in os.environ else os.environ[cuda_env_variable]
cudnn_bin_dir = os.path.join(cudnn_path, "bin")
if not os.path.isfile(os.path.join(cudnn_bin_dir, f"cudnn64_{version_info.cudnn_version}.dll")):
raise ImportError(f"cuDNN {version_info.cudnn_version} not installed in {cudnn_bin_dir}. "
f"Set the CUDNN_HOME environment variable to the path of the 'cuda' directory "
f"in your CUDNN installation if necessary.")
if sys.version_info >= (3, 8):
# Python 3.8 (and later) doesn't search system PATH when loading DLLs, so the CUDA location needs to be