Add an optional Device parameter to pin_memory/is_pinned that does nothing (#60201)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60201

This is to flush out BC/FC problems with adding this parameter.  Later
PR will actually add the desired functionality.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D29331880

Pulled By: ezyang

fbshipit-source-id: 6036716d6ae55e6ea7ef2348b6c34a39613c8dd5
This commit is contained in:
Edward Yang 2021-06-28 10:36:36 -07:00 committed by Facebook GitHub Bot
parent 85af24f52b
commit 3ad3f20bff
4 changed files with 9 additions and 5 deletions

View file

@ -9,14 +9,16 @@
namespace at {
namespace native {
bool is_pinned(const Tensor& self) {
bool is_pinned(const Tensor& self, c10::optional<Device> device) {
TORCH_CHECK(!device.has_value() || device->is_cuda(), "non-cuda device doesn't have a concept of is_pinned");
return detail::getCUDAHooks().isPinnedPtr(self.storage().data());
}
Tensor pin_memory(const Tensor& self) {
Tensor pin_memory(const Tensor& self, c10::optional<Device> device) {
if (!self.device().is_cpu()) {
AT_ERROR("cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
}
TORCH_CHECK(!device.has_value() || device->is_cuda(), "non-cuda device doesn't have a concept of pinned memory");
if (self.is_pinned()) {
return self;
}

View file

@ -3227,10 +3227,10 @@
CPU: channel_shuffle
QuantizedCPU: channel_shuffle_quantized_cpu
- func: is_pinned(Tensor self) -> bool
- func: is_pinned(Tensor self, Device? device=None) -> bool
variants: method
- func: pin_memory(Tensor(a) self) -> Tensor(a)
- func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
variants: method
- func: pinverse(Tensor self, float rcond=1e-15) -> Tensor

View file

@ -1034,6 +1034,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
return 'generator'
elif t.elem.name == BaseTy.Layout:
return 'layoutWithDefault' if has_default else 'layoutOptional'
elif t.elem.name == BaseTy.Device:
return 'deviceWithDefault' if has_default else 'deviceOptional'
elif isinstance(t.elem, ListType):
if str(t.elem.elem) == 'int':

View file

@ -853,7 +853,7 @@ class ShapePropagator {
"aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
"aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
"aten::permute(Tensor self, int[] dims) -> Tensor",
"aten::pin_memory(Tensor(a) self) -> Tensor(a)",
"aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)",
"aten::pinverse(Tensor self, float rcond) -> Tensor",
"aten::reciprocal(Tensor self) -> Tensor",
"aten::relu(Tensor self) -> Tensor",