diff --git a/aten/src/ATen/cuda/CublasHandlePool.cpp b/aten/src/ATen/cuda/CublasHandlePool.cpp index d8ee09b1486..6913d2cd95e 100644 --- a/aten/src/ATen/cuda/CublasHandlePool.cpp +++ b/aten/src/ATen/cuda/CublasHandlePool.cpp @@ -1,4 +1,5 @@ #include +#include #include #include @@ -125,6 +126,18 @@ cublasHandle_t getCurrentCUDABlasHandle() { c10::DeviceIndex device = 0; AT_CUDA_CHECK(c10::cuda::GetDevice(&device)); +#if !defined(USE_ROCM) + CUcontext pctx = nullptr; + at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx); + if (C10_UNLIKELY(!pctx)) { + // workaround for corner case where a primary context exists but is not + // the current context, seen in multithreaded use-cases + TORCH_WARN_ONCE("Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context..."); + at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, device); + at::globalContext().getNVRTC().cuCtxSetCurrent(pctx); + } +#endif + // Thread local PoolWindows are lazily-initialized // to avoid initialization issues that caused hangs on Windows. // See: https://github.com/pytorch/pytorch/pull/22405