mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[CUDA][cuBLAS] Check if a context is present when grabbing a cuBLAS handle (#120131)
cuBLAS has indicated that certain kernels will transition to using the CUDA driver API instead of the CUDA runtime API, which we've observed to break existing tests (e.g., DataParallel) that use multithreading and may not eagerly grab a context via `cudaSetDevice`. CC @Aidyn-A @ptrblck Co-authored-by: Aidyn-A <31858918+Aidyn-A@users.noreply.github.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/120131 Approved by: https://github.com/atalman
This commit is contained in:
parent
f36e00b8ce
commit
5929d4e830
1 changed file with 13 additions and 0 deletions
|
|
@ -1,4 +1,5 @@
|
|||
#include <ATen/cuda/CUDAContext.h>
|
||||
#include <ATen/cuda/nvrtc_stub/ATenNVRTC.h>
|
||||
#include <ATen/cuda/detail/DeviceThreadHandles.h>
|
||||
|
||||
#include <c10/cuda/CUDACachingAllocator.h>
|
||||
|
|
@ -125,6 +126,18 @@ cublasHandle_t getCurrentCUDABlasHandle() {
|
|||
c10::DeviceIndex device = 0;
|
||||
AT_CUDA_CHECK(c10::cuda::GetDevice(&device));
|
||||
|
||||
#if !defined(USE_ROCM)
|
||||
CUcontext pctx = nullptr;
|
||||
at::globalContext().getNVRTC().cuCtxGetCurrent(&pctx);
|
||||
if (C10_UNLIKELY(!pctx)) {
|
||||
// workaround for corner case where a primary context exists but is not
|
||||
// the current context, seen in multithreaded use-cases
|
||||
TORCH_WARN_ONCE("Attempting to run cuBLAS, but there was no current CUDA context! Attempting to set the primary context...");
|
||||
at::globalContext().getNVRTC().cuDevicePrimaryCtxRetain(&pctx, device);
|
||||
at::globalContext().getNVRTC().cuCtxSetCurrent(pctx);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Thread local PoolWindows are lazily-initialized
|
||||
// to avoid initialization issues that caused hangs on Windows.
|
||||
// See: https://github.com/pytorch/pytorch/pull/22405
|
||||
|
|
|
|||
Loading…
Reference in a new issue