[CUDA][cuBLAS] Check if a context is present when grabbing a cuBLAS handle (#120131)

cuBLAS has indicated that certain kernels will transition from the CUDA runtime API to the driver API. We've observed this to break existing tests (e.g., DataParallel) that use multithreading and may not eagerly grab a context via `cudaSetDevice`.

CC @Aidyn-A @ptrblck

Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120131
Approved by: https://github.com/atalman
Eddie Yan 2024-02-27 22:45:12 +00:00 committed by PyTorch MergeBot
parent f36e00b8ce
commit 5929d4e830


@@ -1,4 +1,5 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>
#include <c10/cuda/CUDACachingAllocator.h>
@@ -125,6 +126,18 @@ cublasHandle_t getCurrentCUDABlasHandle() {
  c10::DeviceIndex device = 0;
  AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
#if !defined(USE_ROCM)
  CUcontext pctx = nullptr;
  at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
  if (C10_UNLIKELY(!pctx)) {
    // workaround for corner case where a primary context exists but is not
    // the current context, seen in multithreaded use-cases
    TORCH_WARN_ONCE("Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context...");
    at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, device);
    at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
  }
#endif
  // Thread local PoolWindows are lazily-initialized
  // to avoid initialization issues that caused hangs on Windows.
  // See: https://github.com/pytorch/pytorch/pull/22405
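For reference, the guard above boils down to the following driver-API pattern. This is a minimal standalone sketch calling the CUDA driver API directly; the actual patch routes these calls through PyTorch's lazily-loaded NVRTC/driver stub, and `ensureCurrentContext` is a hypothetical helper name, not a function in the codebase:

```cpp
#include <cuda.h>  // CUDA driver API; link with -lcuda

// Ensure the calling thread has a current driver context before using a
// library (such as cuBLAS) that now relies on the driver API. A freshly
// spawned thread has no current context even when the device's primary
// context already exists, which is the corner case seen with DataParallel.
void ensureCurrentContext(CUdevice device) {
  CUcontext pctx = nullptr;
  cuCtxGetCurrent(&pctx);
  if (pctx == nullptr) {
    // Retain the device's primary context (creating it if necessary)
    // and make it current on this thread.
    cuDevicePrimaryCtxRetain(&pctx, device);
    cuCtxSetCurrent(pctx);
  }
}
```

The runtime API would normally establish this context implicitly on the first runtime call (e.g., `cudaSetDevice`), which is why the problem only surfaces on threads that go straight to the handle without one.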