From 279ca5f9db306775a1e9e5cb183d5219ffd7c2ba Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Wed, 5 Apr 2023 14:53:00 +0000 Subject: [PATCH] Revert "[CUDA12] set_device change (#94864)" This reverts commit c18be2b2ec00133abe28efcdd0462e50ddd45a1a. Reverted https://github.com/pytorch/pytorch/pull/94864 on behalf of https://github.com/ezyang due to avoid affecting cuda 11 --- .lintrunner.toml | 32 ----- aten/src/ATen/cuda/CuSparseHandlePool.cpp | 2 +- aten/src/ATen/cuda/CublasHandlePool.cpp | 2 +- aten/src/ATen/cuda/detail/CUDAHooks.cpp | 3 +- aten/src/ATen/cudnn/Handle.cpp | 2 +- aten/src/ATen/native/cuda/RNN.cu | 2 +- aten/src/ATen/native/cuda/UniqueCub.cu | 2 +- .../cuda/linalg/CusolverDnHandlePool.cpp | 2 +- aten/src/ATen/native/cudnn/Conv_v8.cpp | 2 +- .../sparse/cuda/SparseCUDATensorMath.cu | 6 +- c10/cuda/CUDACachingAllocator.cpp | 24 ++-- c10/cuda/CUDAFunctions.cpp | 131 +----------------- c10/cuda/CUDAFunctions.h | 15 -- c10/cuda/CUDAMallocAsyncAllocator.cpp | 6 +- c10/cuda/impl/CUDAGuardImpl.h | 27 ++-- test/jit/test_cuda.py | 15 -- torch/_C/__init__.pyi.in | 1 - torch/csrc/autograd/engine.cpp | 2 +- torch/csrc/cuda/CUDAPluggableAllocator.cpp | 6 +- torch/csrc/cuda/Module.cpp | 24 +--- .../csrc/jit/python/python_sugared_value.cpp | 1 - torch/csrc/jit/runtime/register_cuda_ops.cpp | 13 -- torch/csrc/profiler/stubs/cuda.cpp | 2 +- torch/cuda/__init__.py | 12 +- .../_internal/distributed/distributed_test.py | 2 - 25 files changed, 60 insertions(+), 276 deletions(-) diff --git a/.lintrunner.toml b/.lintrunner.toml index fdc393de76c..fb466ef3241 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -637,38 +637,6 @@ command = [ '@{{PATHSFILE}}' ] -[[linter]] -code = 'RAWCUDADEVICE' -include_patterns = [ - 'aten/**', - 'c10/**', - 'torch/csrc/**', -] -exclude_patterns = [ - 'aten/src/ATen/cuda/CUDAContext.cpp', - 'aten/src/ATen/cuda/CUDAGeneratorImpl.cpp', - 'aten/src/ATen/test/**', - 'c10/core/impl/InlineDeviceGuard.h', - 'c10/cuda/CUDADeviceAssertionHost.cpp', - 'c10/cuda/CUDAFunctions.cpp', - 'c10/cuda/CUDAGuard.h', - 'c10/cuda/impl/CUDATest.cpp', - 'torch/csrc/cuda/nccl.cpp', -] -command = [ - 'python3', - 'tools/linter/adapters/grep_linter.py', - '--pattern=cudaSetDevice', - '--pattern=cudaGetDevice', - '--linter-name=RAWCUDADEVICE', - '--error-name=raw CUDA API usage', - """--error-description=\ - This line calls raw CUDA APIs directly; please use c10::cuda wrappers instead. 
- """, - '--', - '@{{PATHSFILE}}' -] - [[linter]] code = 'ROOT_LOGGING' include_patterns = [ diff --git a/aten/src/ATen/cuda/CuSparseHandlePool.cpp b/aten/src/ATen/cuda/CuSparseHandlePool.cpp index 6f20f7898e7..7101137112c 100644 --- a/aten/src/ATen/cuda/CuSparseHandlePool.cpp +++ b/aten/src/ATen/cuda/CuSparseHandlePool.cpp @@ -27,7 +27,7 @@ using CuSparsePoolType = DeviceThreadHandlePool #include #include -#include #include #if AT_CUDNN_ENABLED() @@ -224,7 +223,7 @@ const at::cuda::NVRTC& CUDAHooks::nvrtc() const { int64_t current_device() { int device; - cudaError_t err = c10::cuda::GetDevice(&device); + cudaError_t err = cudaGetDevice(&device); if (err == cudaSuccess) { return device; } diff --git a/aten/src/ATen/cudnn/Handle.cpp b/aten/src/ATen/cudnn/Handle.cpp index 1a189c1f092..a6eb8fd7815 100644 --- a/aten/src/ATen/cudnn/Handle.cpp +++ b/aten/src/ATen/cudnn/Handle.cpp @@ -33,7 +33,7 @@ using CudnnPoolType = at::cuda::DeviceThreadHandlePool compute_unique( dim3(std::min(static_cast(cuda::getApplyBlock().x), num_inp)); dim3 grid; int curDevice = -1; - c10::cuda::GetDevice(&curDevice); + cudaGetDevice(&curDevice); cuda::getApplyGrid(num_inp, grid, curDevice); adjacent_difference_kernel<<>>( num_inp, data, inv_loc_ptr); diff --git a/aten/src/ATen/native/cuda/linalg/CusolverDnHandlePool.cpp b/aten/src/ATen/native/cuda/linalg/CusolverDnHandlePool.cpp index 29a17648de7..f6e37e41fc0 100644 --- a/aten/src/ATen/native/cuda/linalg/CusolverDnHandlePool.cpp +++ b/aten/src/ATen/native/cuda/linalg/CusolverDnHandlePool.cpp @@ -30,7 +30,7 @@ using CuSolverDnPoolType = DeviceThreadHandlePool(max_block_size); diff --git a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu index 4bf28c1120d..488da3234c4 100644 --- a/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu +++ b/aten/src/ATen/native/sparse/cuda/SparseCUDATensorMath.cu @@ -314,7 +314,7 @@ Tensor& add_out_dense_sparse_cuda(Tensor& r_, const Tensor& dense, const SparseT const dim3 block = cuda::getApplyBlock(); dim3 grid; int curDevice = -1; - c10::cuda::GetDevice(&curDevice); + cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); if (sparse.dense_dim() == 0) { TORCH_CHECK(cuda::getApplyGrid(nnz, grid, curDevice), "add: Argument #0: tensor too large or too many dimensions"); @@ -606,7 +606,7 @@ Tensor _sparse_sum_backward_cuda(const Tensor& grad_, const SparseTensor& input_ } else { int curDevice = -1; - c10::cuda::GetDevice(&curDevice); + cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); at::cuda::ThrustAllocator allocator; auto policy = thrust::cuda::par(allocator).on(stream); @@ -711,7 +711,7 @@ __global__ void search_end_matrix_indices_cuda_kernel( // indices to find the end index for each matrix void search_end_matrix_indices(int64_t* mat_el_end_indices, int64_t num_matrices, const Tensor& indices_1D) { int curDevice = -1; - c10::cuda::GetDevice(&curDevice); + cudaGetDevice(&curDevice); cudaStream_t stream = at::cuda::getCurrentCUDAStream(curDevice); auto indices_1D_ti = getTensorInfo(indices_1D); diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp index 0d9fba785a0..d841f32983a 100644 --- a/c10/cuda/CUDACachingAllocator.cpp +++ b/c10/cuda/CUDACachingAllocator.cpp @@ -2225,12 +2225,12 @@ class DeviceCachingAllocator { void insert_events(Block* block) { int prev_device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&prev_device)); + 
C10_CUDA_CHECK(cudaGetDevice(&prev_device)); stream_set streams(std::move(block->stream_uses)); AT_ASSERT(block->stream_uses.empty()); for (auto& stream : streams) { - C10_CUDA_CHECK(c10::cuda::SetDevice(stream.device_index())); + C10_CUDA_CHECK(cudaSetDevice(stream.device_index())); EventPool::Event event = create_event_internal(static_cast(stream.device_index())); @@ -2240,7 +2240,7 @@ class DeviceCachingAllocator { cuda_events[stream].emplace_back(std::move(event), block); } - C10_CUDA_CHECK(c10::cuda::MaybeSetDevice(prev_device)); + C10_CUDA_CHECK(cudaSetDevice(prev_device)); } void insert_events_deferred_until_no_capture() { @@ -2434,7 +2434,11 @@ class NativeCachingAllocator : public CUDAAllocator { "invalid fraction:", fraction, ". Please set within (0, 1)."); - C10_CUDA_CHECK(c10::cuda::SetDevice(device)); + int activated_device; + C10_CUDA_CHECK(cudaGetDevice(&activated_device)); + if (activated_device != device) { + C10_CUDA_CHECK(cudaSetDevice(device)); + } device_allocator[device]->setMemoryFraction(fraction); } @@ -2444,7 +2448,7 @@ class NativeCachingAllocator : public CUDAAllocator { size_t alloc_trace_max_entries, bool alloc_trace_record_context) override { int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); device_allocator[device]->recordHistory( enabled, std::move(context_recorder), @@ -2454,7 +2458,7 @@ class NativeCachingAllocator : public CUDAAllocator { void attachOutOfMemoryObserver(OutOfMemoryObserver observer) override { int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); device_allocator[device]->attachOutOfMemoryObserver(std::move(observer)); } @@ -2553,7 +2557,7 @@ class NativeCachingAllocator : public CUDAAllocator { size < one_exa_bytes, "CUDA out of memory. Tried to allocate more than 1EB memory."); int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; if (forceUncachedAllocator()) { // Deliberately don't use cudaMallocMaybeCapturing here, to force an error @@ -2630,7 +2634,7 @@ class NativeCachingAllocator : public CUDAAllocator { return nullptr; } int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device)); return r; @@ -2641,7 +2645,7 @@ class NativeCachingAllocator : public CUDAAllocator { return nullptr; } int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; malloc(&r, device, nbytes, stream); return r; @@ -2688,7 +2692,7 @@ class NativeCachingAllocator : public CUDAAllocator { &dev, *ipc_handle, cudaIpcMemLazyEnablePeerAccess)); // devPtr has to be deleted in same device when created. 
   int curr_device;
-  C10_CUDA_CHECK(c10::cuda::GetDevice(&curr_device));
+  C10_CUDA_CHECK(cudaGetDevice(&curr_device));
   auto sp =
       std::shared_ptr<void>(dev, [handle, curr_device, this](void* ptr) {
         cuda::CUDAGuard device_guard(curr_device);
diff --git a/c10/cuda/CUDAFunctions.cpp b/c10/cuda/CUDAFunctions.cpp
index b143c4aa69c..813fa765309 100644
--- a/c10/cuda/CUDAFunctions.cpp
+++ b/c10/cuda/CUDAFunctions.cpp
@@ -15,7 +15,7 @@ int32_t driver_version() {
 
 int device_count_impl(bool fail_if_no_driver) {
   int count;
-  auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDeviceCount(&count));
+  auto err = C10_CUDA_ERROR_HANDLED(cudaGetDeviceCount(&count));
   if (err == cudaSuccess) {
     return count;
   }
@@ -122,12 +122,12 @@ DeviceIndex device_count_ensure_non_zero() {
 
 DeviceIndex current_device() {
   int cur_device;
-  C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
+  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
   return static_cast<DeviceIndex>(cur_device);
 }
 
 void set_device(DeviceIndex device) {
-  C10_CUDA_CHECK(c10::cuda::SetDevice(static_cast<int>(device)));
+  C10_CUDA_CHECK(cudaSetDevice(static_cast<int>(device)));
 }
 
 void device_synchronize() {
@@ -182,129 +182,4 @@ bool hasPrimaryContext(int64_t device_index) {
   return _internal::hasPrimaryContext(device_index);
 }
 
-// Wrappers for raw CUDA device management functions
-cudaError_t GetDeviceCount(int* dev_count) {
-  return cudaGetDeviceCount(dev_count);
-}
-
-// This is a codepath for CUDA 12, which comes with a critical change in the
-// behavior of `cudaSetDevice`. Unlike previous CUDA versions, which allocate
-// contexts lazily, CUDA 12.x eagerly allocates a primary context the moment
-// `cudaSetDevice` is called. This can have dramatic consequences and pollute
-// device memory in distributed runs. To avoid unnecessary context creation, a
-// new function called `MaybeSetDevice` was introduced. It is to be called in
-// the device guard destructor and at the exit of the torch.cuda.device
-// context manager. The behavior of `MaybeSetDevice` is simple: it calls
-// `cudaSetDevice` if a context already exists on the target device; otherwise
-// it just saves the device index.
-// This way we can keep PyTorch backward compatible for applications like
-// this:
-//
-// ```
-// import torch
-// x = torch.empty(1, device="cuda:1")  # no CUDA context on cuda:0 after this call
-// y = torch.empty(1, device="cuda")    # CUDA context is created on cuda:0
-// ```
-#if CUDA_VERSION >= 12000
-thread_local int targetDeviceIndex = -1;
-
-cudaError_t GetDevice(int* device) {
-  if (targetDeviceIndex >= 0) {
-    *device = targetDeviceIndex;
-    return cudaSuccess;
-  }
-  return cudaGetDevice(device);
-}
-
-cudaError_t SetDevice(int device) {
-  TORCH_CHECK(device >= 0, "device id must be non-negative!");
-  targetDeviceIndex = -1;
-  int cur_device = -1;
-  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
-  if (device == cur_device) {
-    return cudaSuccess;
-  }
-  cudaError_t err = cudaSetDevice(device);
-  C10_CUDA_CHECK(cudaFree(0));
-  return err;
-}
-
-cudaError_t MaybeSetDevice(int device) {
-  if (hasPrimaryContext(device)) {
-    return c10::cuda::SetDevice(device);
-  }
-  targetDeviceIndex = device;
-  return cudaSuccess;
-}
-
-int ExchangeDevice(int to_device) {
-  int cur_device = -1;
-  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
-  if (to_device == cur_device) {
-    targetDeviceIndex = -1;
-    return cur_device;
-  }
-  targetDeviceIndex = -1;
-  C10_CUDA_CHECK(cudaSetDevice(to_device));
-  return cur_device;
-}
-
-int MaybeExchangeDevice(int to_device) {
-  int cur_device = -1;
-  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
-  if (to_device == cur_device) {
-    targetDeviceIndex = -1;
-    return cur_device;
-  }
-  if (hasPrimaryContext(to_device)) {
-    targetDeviceIndex = -1;
-    C10_CUDA_CHECK(cudaSetDevice(to_device));
-  } else {
-    targetDeviceIndex = to_device;
-  }
-  return cur_device;
-}
-
-void SetTargetDevice() {
-  if (targetDeviceIndex >= 0) {
-    C10_CUDA_CHECK(c10::cuda::SetDevice(targetDeviceIndex));
-  }
-}
-#else
-cudaError_t GetDevice(int* device) {
-  return cudaGetDevice(device);
-}
-
-cudaError_t SetDevice(int device) {
-  TORCH_CHECK(device >= 0, "device id must be non-negative!");
-  int cur_device = -1;
-  C10_CUDA_CHECK(cudaGetDevice(&cur_device));
-  if (device == cur_device) {
-    return cudaSuccess;
-  }
-  return cudaSetDevice(device);
-}
-
-cudaError_t MaybeSetDevice(int device) {
-  return c10::cuda::SetDevice(device);
-}
-
-int ExchangeDevice(int to_device) {
-  int cur_device = -1;
-  C10_CUDA_CHECK(c10::cuda::GetDevice(&cur_device));
-  if (to_device == cur_device) {
-    return cur_device;
-  }
-  C10_CUDA_CHECK(cudaSetDevice(to_device));
-  return cur_device;
-}
-
-int MaybeExchangeDevice(int to_device) {
-  return c10::cuda::ExchangeDevice(to_device);
-}
-
-void SetTargetDevice() {
-  // no-op on CUDA version < 12.x
-}
-#endif
-
 } // namespace c10::cuda
diff --git a/c10/cuda/CUDAFunctions.h b/c10/cuda/CUDAFunctions.h
index 388afabcf4d..4c4f2bd5b75 100644
--- a/c10/cuda/CUDAFunctions.h
+++ b/c10/cuda/CUDAFunctions.h
@@ -34,21 +34,6 @@ C10_CUDA_API void device_synchronize();
 
 C10_CUDA_API void warn_or_error_on_sync();
 
-// Raw CUDA device management functions
-C10_CUDA_API cudaError_t GetDeviceCount(int* dev_count);
-
-C10_CUDA_API cudaError_t GetDevice(int* device);
-
-C10_CUDA_API cudaError_t SetDevice(int device);
-
-C10_CUDA_API cudaError_t MaybeSetDevice(int device);
-
-C10_CUDA_API int ExchangeDevice(int device);
-
-C10_CUDA_API int MaybeExchangeDevice(int device);
-
-C10_CUDA_API void SetTargetDevice();
-
 enum class SyncDebugMode { L_DISABLED = 0, L_WARN, L_ERROR };
 
 // this is a holder for c10 global state (similar to at GlobalContext)
diff --git a/c10/cuda/CUDAMallocAsyncAllocator.cpp
b/c10/cuda/CUDAMallocAsyncAllocator.cpp index d7dad770bef..24b6ac626e0 100644 --- a/c10/cuda/CUDAMallocAsyncAllocator.cpp +++ b/c10/cuda/CUDAMallocAsyncAllocator.cpp @@ -411,7 +411,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { size < one_exa_bytes, "CUDA out of memory. Tried to allocate more than 1EB memory."); int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; if (size != 0) { mallocAsync(&r, device, size, cuda::getCurrentCUDAStream(device)); @@ -818,7 +818,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { return nullptr; } int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; mallocAsync(&r, device, nbytes, cuda::getCurrentCUDAStream(device)); return r; @@ -829,7 +829,7 @@ struct CudaMallocAsyncAllocator : public CUDAAllocator { return nullptr; } int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); void* r = nullptr; mallocAsync(&r, device, nbytes, stream); return r; diff --git a/c10/cuda/impl/CUDAGuardImpl.h b/c10/cuda/impl/CUDAGuardImpl.h index 1d580ef10a8..0a48ba060aa 100644 --- a/c10/cuda/impl/CUDAGuardImpl.h +++ b/c10/cuda/impl/CUDAGuardImpl.h @@ -29,17 +29,20 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { } Device exchangeDevice(Device d) const override { TORCH_INTERNAL_ASSERT(d.is_cuda()); - int old_device_index = c10::cuda::ExchangeDevice(d.index()); - return Device(DeviceType::CUDA, old_device_index); + Device old_device = getDevice(); + if (old_device.index() != d.index()) { + C10_CUDA_CHECK(cudaSetDevice(d.index())); + } + return old_device; } Device getDevice() const override { int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); return Device(DeviceType::CUDA, device); } c10::optional uncheckedGetDevice() const noexcept { int device; - const auto err = C10_CUDA_ERROR_HANDLED(c10::cuda::GetDevice(&device)); + const auto err = C10_CUDA_ERROR_HANDLED(cudaGetDevice(&device)); C10_CUDA_CHECK_WARN(err); if (err != cudaSuccess) { return c10::nullopt; @@ -48,10 +51,16 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { } void setDevice(Device d) const override { TORCH_INTERNAL_ASSERT(d.is_cuda()); - C10_CUDA_CHECK(c10::cuda::SetDevice(d.index())); + Device current_device = getDevice(); + if (current_device != d) { + C10_CUDA_CHECK(cudaSetDevice(d.index())); + } } void uncheckedSetDevice(Device d) const noexcept override { - C10_CUDA_CHECK_WARN(c10::cuda::MaybeSetDevice(d.index())); + auto current_device = uncheckedGetDevice(); + if (!current_device.has_value() || current_device.value() != d) { + C10_CUDA_CHECK_WARN(cudaSetDevice(d.index())); + } } Stream getStream(Device d) const noexcept override { return getCurrentCUDAStream(d.index()).unwrap(); @@ -105,15 +114,15 @@ struct CUDAGuardImpl final : public c10::impl::DeviceGuardImplInterface { return; auto cuda_event = static_cast(event); int orig_device; - C10_CUDA_CHECK_WARN(c10::cuda::GetDevice(&orig_device)); - C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(device_index)); + C10_CUDA_CHECK_WARN(cudaGetDevice(&orig_device)); + C10_CUDA_CHECK_WARN(cudaSetDevice(device_index)); const c10::impl::PyInterpreter* interp = c10::impl::GPUTrace::get_trace(); if (C10_UNLIKELY(interp)) { (*interp)->trace_gpu_event_deletion( reinterpret_cast(cuda_event)); } C10_CUDA_CHECK_WARN(cudaEventDestroy(cuda_event)); - 
C10_CUDA_CHECK_WARN(c10::cuda::SetDevice(orig_device)); + C10_CUDA_CHECK_WARN(cudaSetDevice(orig_device)); } void record( diff --git a/test/jit/test_cuda.py b/test/jit/test_cuda.py index b7249a62c40..c4b037d2218 100644 --- a/test/jit/test_cuda.py +++ b/test/jit/test_cuda.py @@ -604,18 +604,3 @@ class TestCUDA(JitTestCase): FileCheck().check("cuda::_exchange_device(").run(g) torch._C._jit_pass_inline(g) FileCheck().check("cuda::_exchange_device(").run(g) - - # Make sure that cuda._maybe_exchange_device doesn't get DCE'ed - @unittest.skipIf(not TEST_CUDA, "Cuda not available") - def test__maybe_exchange_device_op(self): - def fn(device: int, tensor): - torch.cuda._maybe_exchange_device(device) - return tensor.cos().relu() - - fn_s = torch.jit.script(fn) - # Just check the graph, don't run it. Otherwise, we'd need to - # run this test on a multi-gpu CI runner, which is overkill. - g = fn_s.graph - FileCheck().check("cuda::_maybe_exchange_device(").run(g) - torch._C._jit_pass_inline(g) - FileCheck().check("cuda::_maybe_exchange_device(").run(g) diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index 9a7ea52c672..f52c59ac6f4 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -1448,7 +1448,6 @@ def _cuda_getCurrentBlasHandle() -> _int: ... def _cuda_clearCublasWorkspaces() -> None: ... def _cuda_setDevice(device: _int) -> None: ... def _cuda_exchangeDevice(device: _int) -> _int: ... -def _cuda_maybeExchangeDevice(device: _int) -> _int: ... def _cuda_getDevice() -> _int: ... def _cuda_getDeviceCount() -> _int: ... def _cuda_set_sync_debug_mode(warn_level: Union[_int, str]) -> None: ... diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp index 6e5c5ce5ad2..b3c2069b6c9 100644 --- a/torch/csrc/autograd/engine.cpp +++ b/torch/csrc/autograd/engine.cpp @@ -774,7 +774,7 @@ void set_device(int device) { // as in some settings we compile with cuda, but // have lazy stubs for CUDA functionality (so actually // attempting to setup a guard(CPU_DEVICE) will cause an - // error, because it will still query GetDevice). + // error, because it will still query cudaGetDevice). // // Don't use DeviceGuard here because its destructor may be called before the // device is reset. This is fine because the device is thread local. 
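Note on the JIT test removal above (annotation, not part of the applied diff): the deleted `test__maybe_exchange_device_op` mirrored the surviving `_exchange_device` test. A minimal sketch, assuming a CUDA-enabled build, of the pattern the remaining test still covers; the restore step here is illustrative and not part of the original test:

```python
import torch

# Sketch: cuda::_exchange_device must not be dead-code-eliminated even when
# its return value is unused, because the call itself switches the active
# CUDA device — a side effect the alias analysis has to respect.
def fn(device: int, tensor: torch.Tensor) -> torch.Tensor:
    prev = torch.cuda._exchange_device(device)  # switch to `device`
    out = tensor.cos().relu()                   # runs on `device`
    torch.cuda._exchange_device(prev)           # restore (illustrative)
    return out

fn_s = torch.jit.script(fn)
# Both cuda::_exchange_device nodes should still appear in the graph,
# before and after inlining:
print(fn_s.graph)
```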
diff --git a/torch/csrc/cuda/CUDAPluggableAllocator.cpp b/torch/csrc/cuda/CUDAPluggableAllocator.cpp index af0728091f0..6fa5495c31f 100644 --- a/torch/csrc/cuda/CUDAPluggableAllocator.cpp +++ b/torch/csrc/cuda/CUDAPluggableAllocator.cpp @@ -97,7 +97,7 @@ void* CUDAPluggableAllocator::malloc( c10::DataPtr CUDAPluggableAllocator::allocate(size_t size) const { int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); void* r = const_cast(this)->malloc(size, device, stream); @@ -112,7 +112,7 @@ c10::DeleterFnPtr CUDAPluggableAllocator::raw_deleter() const { void* CUDAPluggableAllocator::raw_alloc(size_t nbytes) { int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); cudaStream_t stream = c10::cuda::getCurrentCUDAStream(device); return malloc(nbytes, device, stream); } @@ -121,7 +121,7 @@ void* CUDAPluggableAllocator::raw_alloc_with_stream( size_t nbytes, cudaStream_t stream) { int device; - C10_CUDA_CHECK(c10::cuda::GetDevice(&device)); + C10_CUDA_CHECK(cudaGetDevice(&device)); return malloc(nbytes, device, stream); } diff --git a/torch/csrc/cuda/Module.cpp b/torch/csrc/cuda/Module.cpp index 0999ef13c8a..de52ce77cc9 100644 --- a/torch/csrc/cuda/Module.cpp +++ b/torch/csrc/cuda/Module.cpp @@ -97,24 +97,12 @@ PyObject* THCPModule_exchangeDevice(PyObject* self, PyObject* arg) { } torch::utils::cuda_lazy_init(); - int current_device = c10::cuda::ExchangeDevice(device); - - return THPUtils_packInt32(current_device); - END_HANDLE_TH_ERRORS -} - -PyObject* THCPModule_maybeExchangeDevice(PyObject* self, PyObject* arg) { - HANDLE_TH_ERRORS - TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); - int64_t device = THPUtils_unpackLong(arg); - if (device < 0) { - return THPUtils_packInt32(-1); + auto current_device = c10::cuda::current_device(); + if (current_device != device) { + THCPModule_setDevice(device); } - torch::utils::cuda_lazy_init(); - int current_device = c10::cuda::MaybeExchangeDevice(device); - - return THPUtils_packInt32(current_device); + return THPUtils_packInt32(static_cast(current_device)); END_HANDLE_TH_ERRORS } @@ -1281,10 +1269,6 @@ static struct PyMethodDef _THCPModule_methods[] = { {"_cuda_init", THCPModule_initExtension, METH_NOARGS, nullptr}, {"_cuda_setDevice", THCPModule_setDevice_wrap, METH_O, nullptr}, {"_cuda_exchangeDevice", THCPModule_exchangeDevice, METH_O, nullptr}, - {"_cuda_maybeExchangeDevice", - THCPModule_maybeExchangeDevice, - METH_O, - nullptr}, {"_cuda_getDevice", THCPModule_getDevice_wrap, METH_NOARGS, nullptr}, {"_cuda_getDeviceCount", THCPModule_getDeviceCount_wrap, diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 28d971e3e78..7128adf24c5 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -229,7 +229,6 @@ std::shared_ptr CUDAPythonModuleValue::attr( "default_stream", "current_device", "_exchange_device", - "_maybe_exchange_device", "set_device", "device_index", "device_count", diff --git a/torch/csrc/jit/runtime/register_cuda_ops.cpp b/torch/csrc/jit/runtime/register_cuda_ops.cpp index e58c211a05e..14c689f2b71 100644 --- a/torch/csrc/jit/runtime/register_cuda_ops.cpp +++ b/torch/csrc/jit/runtime/register_cuda_ops.cpp @@ -103,19 +103,6 @@ RegisterOperators const reg({ }, // cuda::set_device has side effects. 
c10::AliasAnalysisKind::CONSERVATIVE), - Operator( - "cuda::_maybe_exchange_device(int64_t index) -> int", - [](Stack& stack) { - int64_t idx = -1; - pop(stack, idx); - if (idx < 0) { - push(stack, -1); - return; - } - int prev_idx = c10::cuda::MaybeExchangeDevice(static_cast(idx)); - push(stack, prev_idx); - }, - c10::AliasAnalysisKind::CONSERVATIVE), Operator( "cuda::_set_device(int64_t val) -> ()", [](Stack& stack) { diff --git a/torch/csrc/profiler/stubs/cuda.cpp b/torch/csrc/profiler/stubs/cuda.cpp index 664fe107249..6731d0f4d3b 100644 --- a/torch/csrc/profiler/stubs/cuda.cpp +++ b/torch/csrc/profiler/stubs/cuda.cpp @@ -39,7 +39,7 @@ struct CUDAMethods : public ProfilerStubs { void record(int* device, ProfilerEventStub* event, int64_t* cpu_ns) const override { if (device) { - TORCH_CUDA_CHECK(c10::cuda::GetDevice(device)); + TORCH_CUDA_CHECK(cudaGetDevice(device)); } // NOLINTNEXTLINE(cppcoreguidelines-init-variables) CUevent_st* cuda_event_ptr; diff --git a/torch/cuda/__init__.py b/torch/cuda/__init__.py index 4be71fe6adf..9020d32739b 100644 --- a/torch/cuda/__init__.py +++ b/torch/cuda/__init__.py @@ -87,14 +87,6 @@ else: return -1 raise RuntimeError("PyTorch was compiled without CUDA support") -if hasattr(torch._C, '_cuda_maybeExchangeDevice'): - _maybe_exchange_device = torch._C._cuda_maybeExchangeDevice -else: - def _maybe_exchange_device(device: int) -> int: - if device < 0: - return -1 - raise RuntimeError("PyTorch was compiled without CUDA support") - # Global variables dynamically populated by native code has_magma: bool = False @@ -313,7 +305,7 @@ class _DeviceGuard: self.prev_idx = torch.cuda._exchange_device(self.idx) def __exit__(self, type: Any, value: Any, traceback: Any): - self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + self.idx = torch.cuda._exchange_device(self.prev_idx) return False @@ -333,7 +325,7 @@ class device: self.prev_idx = torch.cuda._exchange_device(self.idx) def __exit__(self, type: Any, value: Any, traceback: Any): - self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) + self.idx = torch.cuda._exchange_device(self.prev_idx) return False diff --git a/torch/testing/_internal/distributed/distributed_test.py b/torch/testing/_internal/distributed/distributed_test.py index 546b7ad7042..269a6198317 100644 --- a/torch/testing/_internal/distributed/distributed_test.py +++ b/torch/testing/_internal/distributed/distributed_test.py @@ -577,8 +577,6 @@ class TestDistBackend(MultiProcessTestCase): rank=self.rank, timeout=timeout, ) - if torch.cuda.is_available(): - torch.cuda.set_device(self.rank) except RuntimeError as e: if "recompile" in e.args[0]: sys.exit(TEST_SKIPS["backend_unavailable"].exit_code)
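Note on user-visible impact (annotation, not part of the applied diff): with `_maybe_exchange_device` removed, both `torch.cuda._DeviceGuard` and `torch.cuda.device` unconditionally call `_exchange_device` on exit. A minimal sketch of the consequence on a CUDA 12 build with at least two GPUs; `torch._C._cuda_hasPrimaryContext` is an internal helper assumed available here:

```python
import torch

# Sketch: on CUDA 12, cudaSetDevice eagerly creates a primary context, so
# exiting a device context manager — which switches back via
# _exchange_device — can create a context on a device that never held data.
assert torch.cuda.device_count() >= 2  # assumes a 2-GPU machine

with torch.cuda.device(1):
    x = torch.empty(1, device="cuda")  # primary context created on cuda:1

# __exit__ ran torch.cuda._exchange_device(0) -> cudaSetDevice(0). Without
# the reverted MaybeSetDevice deferral, cuda:0 now also holds a primary
# context, even though no tensor was ever placed there.
print(torch._C._cuda_hasPrimaryContext(0))  # expected: True (internal API)
```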