Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
pin_memory should not copy on already pinned tensors (#23484)
Summary: fixes https://github.com/pytorch/pytorch/issues/21076
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23484
Differential Revision: D16546264
Pulled By: ezyang
fbshipit-source-id: 8058e0bbc6336751f36b884d71234feef498a982
This commit is contained in:
parent 3fe00f0c90
commit af638ad5d7
16 changed files with 119 additions and 47 deletions
|
|
@ -47,6 +47,9 @@ class CAFFE2_API Context {
|
|||
AT_ERROR(DeviceTypeName(device_type), " device type not enabled.");
|
||||
}
|
||||
}
|
||||
bool isPinnedPtr(void* data) {
|
||||
return detail::getCUDAHooks().isPinnedPtr(data);
|
||||
}
|
||||
bool hasOpenMP() const;
|
||||
bool hasMKL() const;
|
||||
bool hasLAPACK() const;
|
||||
|
|
|
|||
|
|
@ -485,6 +485,7 @@ class CAFFE2_API Tensor {
|
|||
Tensor narrow(int64_t dim, int64_t start, int64_t length) const;
|
||||
Tensor permute(IntArrayRef dims) const;
|
||||
Tensor numpy_T() const;
|
||||
bool is_pinned() const;
|
||||
Tensor pin_memory() const;
|
||||
Tensor pinverse(double rcond=1e-15) const;
|
||||
Tensor reciprocal() const;
|
||||
|
|
|
|||
|
|
@ -591,6 +591,10 @@ inline Tensor Tensor::numpy_T() const {
|
|||
static auto table = globalATenDispatch().getOpTable("aten::numpy_T(Tensor(a) self) -> Tensor(a)");
|
||||
return table->getOp<Tensor (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this);
|
||||
}
|
||||
inline bool Tensor::is_pinned() const {
|
||||
static auto table = globalATenDispatch().getOpTable("aten::is_pinned(Tensor self) -> bool");
|
||||
return table->getOp<bool (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this);
|
||||
}
|
||||
inline Tensor Tensor::pin_memory() const {
|
||||
static auto table = globalATenDispatch().getOpTable("aten::pin_memory(Tensor self) -> Tensor");
|
||||
return table->getOp<Tensor (const Tensor &)>(tensorTypeIdToBackend(type_id()), is_variable())(*this);
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
|
||||
#include <ATen/CUDAGenerator.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/DeviceGuard.h>
|
||||
#include <ATen/DynamicLibrary.h>
|
||||
#include <ATen/cuda/CUDAConfig.h>
|
||||
#include <ATen/cuda/CUDADevice.h>
|
||||
|
|
@ -64,6 +65,41 @@ Device CUDAHooks::getDeviceFromPtr(void* data) const {
|
|||
return at::cuda::getDeviceFromPtr(data);
|
||||
}
|
||||
|
||||
bool CUDAHooks::isPinnedPtr(void* data) const {
|
||||
// First check if driver is broken/missing, in which case PyTorch CPU
|
||||
// functionalities should still work, we should report `false` here.
|
||||
if (!CUDAHooks::hasCUDA()) {
|
||||
return false;
|
||||
}
|
||||
// cudaPointerGetAttributes grabs context on the current device, so we set
|
||||
// device to one that already has context, if exists.
|
||||
at::OptionalDeviceGuard device_guard;
|
||||
auto primary_ctx_device_index = CUDAHooks::getDevceIndexWithPrimaryContext();
|
||||
if (primary_ctx_device_index.has_value()) {
|
||||
device_guard.reset_device(at::Device(at::DeviceType::CUDA, *primary_ctx_device_index));
|
||||
}
|
||||
cudaPointerAttributes attr;
|
||||
cudaError_t err = cudaPointerGetAttributes(&attr, data);
|
||||
#ifndef __HIP_PLATFORM_HCC__
|
||||
if (err == cudaErrorInvalidValue) {
|
||||
cudaGetLastError();
|
||||
return false;
|
||||
}
|
||||
AT_CUDA_CHECK(err);
|
||||
#else
|
||||
// HIP throws hipErrorUnknown here
|
||||
if (err != cudaSuccess) {
|
||||
cudaGetLastError();
|
||||
return false;
|
||||
}
|
||||
#endif
|
||||
#if CUDA_VERSION >= 10000
|
||||
return attr.type == cudaMemoryTypeHost;
|
||||
#else
|
||||
return attr.memoryType == cudaMemoryTypeHost;
|
||||
#endif
|
||||
}
|
||||
|
||||
bool CUDAHooks::hasCUDA() const {
|
||||
return at::cuda::is_available();
|
||||
}
|
||||
|
|
@ -117,13 +153,30 @@ int64_t CUDAHooks::current_device() const {
|
|||
|
||||
bool CUDAHooks::hasPrimaryContext(int64_t device_index) const {
|
||||
TORCH_CHECK(device_index >= 0 && device_index < at::cuda::device_count(),
|
||||
"hasPrimaryContext expects valid device index, but got device_index=", device_index);
|
||||
"hasPrimaryContext expects a valid device index, but got device_index=", device_index);
|
||||
unsigned int ctx_flags;
|
||||
int ctx_is_active;
|
||||
AT_CUDA_DRIVER_CHECK(CUDAHooks::nvrtc().cuDevicePrimaryCtxGetState(device_index, &ctx_flags, &ctx_is_active));
|
||||
return ctx_is_active == 1;
|
||||
}
|
||||
|
||||
c10::optional<int64_t> CUDAHooks::getDevceIndexWithPrimaryContext() const {
|
||||
// check current device first
|
||||
int64_t current_device_index = CUDAHooks::current_device();
|
||||
if (current_device_index >= 0) {
|
||||
if (CUDAHooks::hasPrimaryContext(current_device_index)) {
|
||||
return current_device_index;
|
||||
}
|
||||
}
|
||||
for (int64_t device_index = 0; device_index < CUDAHooks::getNumGPUs(); device_index++) {
|
||||
if (device_index == current_device_index) continue;
|
||||
if (CUDAHooks::hasPrimaryContext(device_index)) {
|
||||
return device_index;
|
||||
}
|
||||
}
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
Allocator* CUDAHooks::getPinnedMemoryAllocator() const {
|
||||
return at::cuda::getPinnedMemoryAllocator();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#include <ATen/detail/CUDAHooksInterface.h>
|
||||
|
||||
#include <ATen/Generator.h>
|
||||
#include <c10/util/Optional.h>
|
||||
|
||||
// TODO: No need to have this whole header, we can just put it all in
|
||||
// the cpp file
|
||||
|
|
@ -12,6 +13,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
|||
CUDAHooks(at::CUDAHooksArgs) {}
|
||||
std::unique_ptr<THCState, void(*)(THCState*)> initCUDA() const override;
|
||||
Device getDeviceFromPtr(void* data) const override;
|
||||
bool isPinnedPtr(void* data) const override;
|
||||
Generator* getDefaultCUDAGenerator(DeviceIndex device_index = -1) const override;
|
||||
bool hasCUDA() const override;
|
||||
bool hasMAGMA() const override;
|
||||
|
|
@ -19,6 +21,7 @@ struct CUDAHooks : public at::CUDAHooksInterface {
|
|||
const at::cuda::NVRTC& nvrtc() const override;
|
||||
int64_t current_device() const override;
|
||||
bool hasPrimaryContext(int64_t device_index) const override;
|
||||
c10::optional<int64_t> getDevceIndexWithPrimaryContext() const override;
|
||||
Allocator* getPinnedMemoryAllocator() const override;
|
||||
bool compiledWithCuDNN() const override;
|
||||
bool compiledWithMIOpen() const override;
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@
|
|||
#include <c10/core/Allocator.h>
|
||||
#include <ATen/core/Generator.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#include <c10/util/Optional.h>
|
||||
#include <c10/util/Registry.h>
|
||||
|
||||
#include <cstddef>
|
||||
|
|
@ -71,6 +71,10 @@ struct CAFFE2_API CUDAHooksInterface {
|
|||
TORCH_CHECK(false, "Cannot get device of pointer on CUDA without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual bool isPinnedPtr(void* data) const {
|
||||
return false;
|
||||
}
|
||||
|
||||
virtual bool hasCUDA() const {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -95,6 +99,10 @@ struct CAFFE2_API CUDAHooksInterface {
|
|||
TORCH_CHECK(false, "Cannot call hasPrimaryContext(", device_index, ") without ATen_cuda library. ", CUDA_HELP);
|
||||
}
|
||||
|
||||
virtual c10::optional<int64_t> getDevceIndexWithPrimaryContext() const {
|
||||
return c10::nullopt;
|
||||
}
|
||||
|
||||
virtual Allocator* getPinnedMemoryAllocator() const {
|
||||
TORCH_CHECK(false, "Pinned memory requires CUDA. ", CUDA_HELP);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -9,10 +9,17 @@
|
|||
namespace at {
|
||||
namespace native {
|
||||
|
||||
bool is_pinned(const Tensor& self) {
|
||||
return detail::getCUDAHooks().isPinnedPtr(self.storage().data());
|
||||
}
|
||||
|
||||
Tensor pin_memory(const Tensor& self) {
|
||||
if (self.type().backend() != Backend::CPU) {
|
||||
AT_ERROR("cannot pin '", self.type().toString(), "' only dense CPU tensors can be pinned");
|
||||
}
|
||||
if (self.is_pinned()) {
|
||||
return self;
|
||||
}
|
||||
auto* allocator = detail::getCUDAHooks().getPinnedMemoryAllocator();
|
||||
auto storage = Storage(
|
||||
self.dtype(),
|
||||
|
|
|
|||
|
|
@ -1545,8 +1545,11 @@
|
|||
|
||||
- func: pixel_shuffle(Tensor self, int upscale_factor) -> Tensor
|
||||
|
||||
- func: is_pinned(Tensor self) -> bool
|
||||
variants: method
|
||||
|
||||
- func: pin_memory(Tensor self) -> Tensor
|
||||
variants: function, method
|
||||
variants: method
|
||||
|
||||
- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
|
||||
variants: function, method
|
||||
|
|
|
|||
|
|
@ -42,24 +42,6 @@ static bool BlockComparator(const BlockSize& a, const BlockSize& b)
|
|||
return (uintptr_t)a.ptr < (uintptr_t)b.ptr;
|
||||
}
|
||||
|
||||
static int64_t inline get_device_index_with_primary_context() {
|
||||
const auto& cuda_hooks = at::detail::getCUDAHooks();
|
||||
// check current device first
|
||||
int64_t current_device_index = cuda_hooks.current_device();
|
||||
if (current_device_index >= 0) {
|
||||
if (cuda_hooks.hasPrimaryContext(current_device_index)) {
|
||||
return current_device_index;
|
||||
}
|
||||
}
|
||||
for (int64_t device_index = 0; device_index < cuda_hooks.getNumGPUs(); device_index++) {
|
||||
if (device_index == current_device_index) continue;
|
||||
if (cuda_hooks.hasPrimaryContext(device_index)) {
|
||||
return device_index;
|
||||
}
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
struct HostAllocator
|
||||
{
|
||||
typedef bool (*Comparison)(const BlockSize&, const BlockSize&);
|
||||
|
|
@ -106,9 +88,9 @@ struct HostAllocator
|
|||
// So we grab any existing primary context, if available.
|
||||
// See pytorch/pytorch#21081.
|
||||
at::OptionalDeviceGuard device_guard;
|
||||
auto primary_ctx_device_index = get_device_index_with_primary_context();
|
||||
if (primary_ctx_device_index >= 0) {
|
||||
device_guard.reset_device(at::Device(at::DeviceType::CUDA, primary_ctx_device_index));
|
||||
auto primary_ctx_device_index = at::detail::getCUDAHooks().getDevceIndexWithPrimaryContext();
|
||||
if (primary_ctx_device_index.has_value()) {
|
||||
device_guard.reset_device(at::Device(at::DeviceType::CUDA, *primary_ctx_device_index));
|
||||
}
|
||||
|
||||
// note that cudaHostAlloc may not touch pointer if size is 0
|
||||
|
|
|
|||
|
|
@ -1344,7 +1344,7 @@ class TestCuda(TestCase):
|
|||
|
||||
# Bool test case
|
||||
t = torch.tensor([[False, True], [True, True]], device='cuda')
|
||||
self.assertEqual(torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device='cuda')),
|
||||
self.assertEqual(torch.gather(t, 1, torch.tensor([[0, 0], [1, 0]], device='cuda')),
|
||||
torch.tensor([[False, False], [True, True]], device='cuda'))
|
||||
|
||||
def test_gather(self):
|
||||
|
|
|
|||
|
|
@ -66,12 +66,24 @@ class TestCudaPrimaryCtx(TestCase):
|
|||
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
|
||||
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
|
||||
|
||||
self.assertFalse(x.is_pinned())
|
||||
|
||||
# We should still have only created context on 'cuda:1'
|
||||
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
|
||||
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
|
||||
|
||||
x = torch.randn(3, device='cpu').pin_memory()
|
||||
|
||||
# We should still have only created context on 'cuda:1'
|
||||
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
|
||||
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
|
||||
|
||||
self.assertTrue(x.is_pinned())
|
||||
|
||||
# We should still have only created context on 'cuda:1'
|
||||
self.assertFalse(torch._C._cuda_hasPrimaryContext(0))
|
||||
self.assertTrue(torch._C._cuda_hasPrimaryContext(1))
|
||||
|
||||
x = torch.randn(3, device='cpu', pin_memory=True)
|
||||
|
||||
# We should still have only created context on 'cuda:1'
|
||||
|
|
|
|||
|
|
@ -10928,14 +10928,19 @@ tensor([[[1., 1., 1., ..., 1., 1., 1.],
|
|||
self.assertEqual(empty_strided.shape, as_strided.shape)
|
||||
self.assertEqual(empty_strided.stride(), as_strided.stride())
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
|
||||
def test_pin_memory(self):
|
||||
x = torch.randn(3, 5)
|
||||
self.assertFalse(x.is_pinned())
|
||||
pinned = x.pin_memory()
|
||||
self.assertTrue(pinned.is_pinned())
|
||||
self.assertEqual(pinned, x)
|
||||
self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
|
||||
if not torch.cuda.is_available():
|
||||
self.assertRaises(RuntimeError, lambda: x.pin_memory())
|
||||
else:
|
||||
pinned = x.pin_memory()
|
||||
self.assertTrue(pinned.is_pinned())
|
||||
self.assertEqual(pinned, x)
|
||||
self.assertNotEqual(pinned.data_ptr(), x.data_ptr())
|
||||
# test that pin_memory on already pinned tensor has no effect
|
||||
self.assertIs(pinned, pinned.pin_memory())
|
||||
self.assertEqual(pinned.data_ptr(), pinned.pin_memory().data_ptr())
|
||||
|
||||
@unittest.skipIf(not torch.cuda.is_available(), 'no CUDA')
|
||||
def test_pin_memory_from_constructor(self):
|
||||
|
|
|
|||
|
|
@ -93,7 +93,6 @@ class Tensor:
|
|||
# Manually defined methods from torch/tensor.py
|
||||
def register_hook(self, hook: Callable) -> Any: ...
|
||||
def retain_grad(self) -> None: ...
|
||||
def is_pinned(self) -> bool: ...
|
||||
def is_shared(self) -> bool: ...
|
||||
def share_memory_(self) -> None: ...
|
||||
# TODO: fill in the types for these, or otherwise figure out some
|
||||
|
|
|
|||
|
|
@ -1310,6 +1310,11 @@ is_contiguous() -> bool
|
|||
Returns True if :attr:`self` tensor is contiguous in memory in C order.
|
||||
""")
|
||||
|
||||
add_docstr_all('is_pinned',
|
||||
r"""
|
||||
Returns true if this tensor resides in pinned memory.
|
||||
""")
|
||||
|
||||
add_docstr_all('is_floating_point',
|
||||
r"""
|
||||
is_floating_point() -> bool
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
#include <ATen/ATen.h>
|
||||
|
||||
#ifdef USE_CUDA
|
||||
#include <cuda_runtime.h>
|
||||
#endif
|
||||
|
|
@ -27,17 +29,7 @@ static PyObject * THPStorage_(isPinned)(THPStorage *self)
|
|||
{
|
||||
HANDLE_TH_ERRORS
|
||||
#if defined(USE_CUDA)
|
||||
cudaPointerAttributes attr;
|
||||
cudaError_t err = cudaPointerGetAttributes(&attr, THWStorage_(data)(LIBRARY_STATE self->cdata));
|
||||
if (err != cudaSuccess) {
|
||||
cudaGetLastError();
|
||||
Py_RETURN_FALSE;
|
||||
}
|
||||
#if CUDA_VERSION >= 10000
|
||||
return PyBool_FromLong(attr.type == cudaMemoryTypeHost);
|
||||
#else
|
||||
return PyBool_FromLong(attr.memoryType == cudaMemoryTypeHost);
|
||||
#endif
|
||||
return PyBool_FromLong(at::globalContext().isPinnedPtr(THWStorage_(data)(LIBRARY_STATE self->cdata)));
|
||||
#else
|
||||
Py_RETURN_FALSE;
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -231,11 +231,6 @@ class Tensor(torch._C._TensorBase):
|
|||
self.register_hook(retain_grad_hook)
|
||||
self.retains_grad = True
|
||||
|
||||
def is_pinned(self):
|
||||
r"""Returns true if this tensor resides in pinned memory"""
|
||||
storage = self.storage()
|
||||
return storage.is_pinned() if storage else False
|
||||
|
||||
def is_shared(self):
|
||||
r"""Checks if tensor is in shared memory.
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue