Add an optional Device parameter to pin_memory/is_pinned that does nothing (#60201)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60201

This is to flush out BC/FC (backward/forward compatibility) problems with adding this parameter. A later PR will actually add the desired functionality.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D29331880

Pulled By: ezyang

fbshipit-source-id: 6036716d6ae55e6ea7ef2348b6c34a39613c8dd5
parent 85af24f52b
commit 3ad3f20bff

4 changed files with 9 additions and 5 deletions
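At the Python surface the change is purely additive: both methods grow a "device" argument that defaults to None. A minimal sketch of the resulting user-visible behavior (illustrative, assuming a CUDA-enabled build; not part of the commit):

import torch

t = torch.randn(4)  # a dense CPU tensor

# The new argument defaults to None, so existing call sites are unchanged (BC).
pinned = t.pin_memory()
assert pinned.is_pinned()

# A CUDA device is accepted but, per this commit, does nothing extra yet.
also_pinned = t.pin_memory(device=torch.device("cuda"))

# Any non-CUDA device trips the new TORCH_CHECK:
#   RuntimeError: non-cuda device doesn't have a concept of pinned memory
try:
    t.pin_memory(device=torch.device("cpu"))
except RuntimeError as err:
    print(err)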
@@ -9,14 +9,16 @@
 namespace at {
 namespace native {

-bool is_pinned(const Tensor& self) {
+bool is_pinned(const Tensor& self, c10::optional<Device> device) {
+  TORCH_CHECK(!device.has_value() || device->is_cuda(), "non-cuda device doesn't have a concept of is_pinned");
   return detail::getCUDAHooks().isPinnedPtr(self.storage().data());
 }

-Tensor pin_memory(const Tensor& self) {
+Tensor pin_memory(const Tensor& self, c10::optional<Device> device) {
   if (!self.device().is_cpu()) {
     AT_ERROR("cannot pin '", self.toString(), "' only dense CPU tensors can be pinned");
   }
+  TORCH_CHECK(!device.has_value() || device->is_cuda(), "non-cuda device doesn't have a concept of pinned memory");
   if (self.is_pinned()) {
     return self;
   }
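The is_pinned side carries the same guard; a short illustrative sketch of the query path (again assuming CUDA is available):

import torch

x = torch.randn(4)
print(x.is_pinned())               # False: plain pageable CPU memory
print(x.pin_memory().is_pinned())  # True: the copy's storage is page-locked

# Only a CUDA device (or None) passes the new check:
#   RuntimeError: non-cuda device doesn't have a concept of is_pinned
try:
    x.is_pinned(device=torch.device("cpu"))
except RuntimeError as err:
    print(err)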
@@ -3227,10 +3227,10 @@
     CPU: channel_shuffle
     QuantizedCPU: channel_shuffle_quantized_cpu

-- func: is_pinned(Tensor self) -> bool
+- func: is_pinned(Tensor self, Device? device=None) -> bool
   variants: method

-- func: pin_memory(Tensor(a) self) -> Tensor(a)
+- func: pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)
   variants: method

 - func: pinverse(Tensor self, float rcond=1e-15) -> Tensor
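Because each schema only gains a trailing optional argument with a default, the new signatures stay backward compatible with the old ones. This can be checked with the schema helpers that PyTorch's own schema tests use (a sketch, assuming torch._C.parse_schema and FunctionSchema.is_backward_compatible_with behave as in test_function_schema.py):

import torch

old = torch._C.parse_schema("pin_memory(Tensor(a) self) -> Tensor(a)")
new = torch._C.parse_schema(
    "pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)")

# A trailing, defaulted optional argument keeps every old call site valid.
print(new.is_backward_compatible_with(old))  # expected: True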
@@ -1034,6 +1034,8 @@ def arg_parser_unpack_method(t: Type, has_default: bool) -> str:
             return 'generator'
         elif t.elem.name == BaseTy.Layout:
             return 'layoutWithDefault' if has_default else 'layoutOptional'
+        elif t.elem.name == BaseTy.Device:
+            return 'deviceWithDefault' if has_default else 'deviceOptional'

     elif isinstance(t.elem, ListType):
         if str(t.elem.elem) == 'int':
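This codegen hunk teaches the Python argument parser which unpack accessor to emit for an optional Device argument. A standalone toy version of the dispatch (a hypothetical stand-in for the codegen's Type/BaseTy model, not the real function):

def arg_parser_unpack_method_sketch(base_ty: str, has_default: bool) -> str:
    # Mirrors the elif chain above: pick the PythonArgParser accessor
    # by type and by whether the argument carries a default.
    if base_ty == 'Generator':
        return 'generator'
    elif base_ty == 'Layout':
        return 'layoutWithDefault' if has_default else 'layoutOptional'
    elif base_ty == 'Device':  # the branch added by this commit
        return 'deviceWithDefault' if has_default else 'deviceOptional'
    raise RuntimeError(f'unhandled optional type: {base_ty}')

# 'Device? device=None' has a default, so the generated binding unpacks it
# with the parser's deviceWithDefault accessor.
print(arg_parser_unpack_method_sketch('Device', has_default=True))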
@@ -853,7 +853,7 @@ class ShapePropagator {
         "aten::normal(float mean, Tensor std, *, Generator? generator) -> Tensor",
         "aten::normal(Tensor mean, float std, *, Generator? generator) -> Tensor",
         "aten::permute(Tensor self, int[] dims) -> Tensor",
-        "aten::pin_memory(Tensor(a) self) -> Tensor(a)",
+        "aten::pin_memory(Tensor(a) self, Device? device=None) -> Tensor(a)",
         "aten::pinverse(Tensor self, float rcond) -> Tensor",
         "aten::reciprocal(Tensor self) -> Tensor",
         "aten::relu(Tensor self) -> Tensor",
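The JIT's ShapePropagator matches operators by their exact schema string, so this table has to be kept in sync with native_functions.yaml or pin_memory would silently stop matching. A quick illustrative check that the scripted op still resolves under the new signature:

import torch

@torch.jit.script
def f(x: torch.Tensor) -> torch.Tensor:
    return x.pin_memory()

# The graph should contain an aten::pin_memory node with the updated
# schema, so shape propagation continues to handle it.
print(f.graph)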