2018-04-03 20:29:25 +00:00
|
|
|
import sys
|
|
|
|
|
import torch
|
|
|
|
|
import torch._C as _C
|
|
|
|
|
from collections import OrderedDict
|
|
|
|
|
import torch.utils.hooks as hooks
|
|
|
|
|
import warnings
|
|
|
|
|
import weakref
|
|
|
|
|
from torch._six import imap
|
|
|
|
|
from torch._C import _add_docstr
|
2019-03-11 15:55:01 +00:00
|
|
|
from numbers import Number
|
2018-04-03 20:29:25 +00:00
|
|
|
|
|
|
|
|
|
Correctly share CUDA Parameters. (#10220)
Summary:
```
Correctly share CUDA Parameters, requires_grad and hooks.
Previously, the following was true:
- If you put a Parameter for a CUDA tensor
in multiprocessing queue (or otherwise tried to transfer it),
this failed, saying that we cannot pickle CUDA storage.
This is issue #9996.
- If you put a leaf Tensor that requires_grad=True through the
multiprocessing queue, it would come out the other end as
requires_grad=False (It should have come out the other end
as requires_grad=True). Similarly, backwards hooks were
lost.
- If you put a non-leaf Tensor that requires_grad=True through
the multiprocessing queue, it would come out the other end
as requires_grad=False.
The root cause for the first issue was that implementation of
reductions for Parameter used the superclass implementation
(tensor) in __reduce_ex__, but this always picks up the
non-ForkingPickler reduction, which doesn't work with CUDA tensors.
So, we registered a new ForkingPickler specifically for Parameter,
and adjusted the code to correctly rewrap a Tensor in a Parameter
if it was originally a parameter.
While working on this, we realized that requires_grad and backwards
hooks would not be preserved in the ForkingPickler reduction
implementation. We fixed the reducer to save these parameters.
However, Adam Paszke pointed out that we shouldn't allow sending
requires_grad=True, non-leaf Tensors over a multiprocessing
queue, since we don't actually support autograd over process
boundaries. We now throw an error in this case; this may cause
previously working code to fail, but this is easy enough to fix;
just detach() the tensor before sending it. The error message says
so.
Fixes #9996.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10220
Differential Revision: D9160746
Pulled By: ezyang
fbshipit-source-id: a39c0dbc012ba5afc7a9e646da5c7f325b3cf05c
2018-08-10 20:46:54 +00:00
|
|
|
# NB: If you subclass Tensor, and want to share the subclassed class
|
|
|
|
|
# across processes, you must also update torch/multiprocessing/reductions.py
|
|
|
|
|
# to define a ForkingPickler serialization mode for the class.
|
2019-01-29 19:19:51 +00:00
|
|
|
#
|
|
|
|
|
# NB: If you add a new method to Tensor, you must update
|
|
|
|
|
# torch/__init__.py.in to add a type annotation for your method;
|
|
|
|
|
# otherwise, it will not show up in autocomplete.
|
2018-04-03 20:29:25 +00:00
|
|
|
class Tensor(torch._C._TensorBase):
|
|
|
|
|
def __deepcopy__(self, memo):
    r"""Deep-copy hook: returns an independent copy of a leaf tensor.

    Arguments:
        memo (dict): mapping from ``id(obj)`` to already-made copies; it is
            consulted first and updated with the new tensor.

    Raises:
        RuntimeError: if the tensor is not a graph leaf.
    """
    if not self.is_leaf:
        raise RuntimeError("Only Tensors created explicitly by the user "
                           "(graph leaves) support the deepcopy protocol at the moment")
    key = id(self)
    if key in memo:
        return memo[key]
    with torch.no_grad():
        if self.is_sparse:
            duplicate = self.clone()
        else:
            # Copy the storage through the same memo, then rebuild an empty
            # tensor that views the copied storage with identical geometry.
            storage_copy = self.storage().__deepcopy__(memo)
            duplicate = self.new()
            duplicate.set_(storage_copy, self.storage_offset(), self.size(), self.stride())
        memo[key] = duplicate
        duplicate.requires_grad = self.requires_grad
    return duplicate
|
|
|
|
|
|
|
|
|
|
def __reduce_ex__(self, proto):
    r"""Pickle hook: returns ``(rebuild_function, args)`` for this tensor.

    Backward hooks are deliberately not serialized; an empty ``OrderedDict``
    is stored in their slot and a warning helper is invoked first.
    """
    # See Note [Don't serialize hooks]
    torch.utils.hooks.warn_if_has_hooks(self)
    storage = self.storage()
    offset = self.storage_offset()
    shape = tuple(self.size())
    strides = self.stride()
    needs_grad = self.requires_grad
    no_hooks = OrderedDict()  # previously was self._backward_hooks
    rebuild_args = (storage, offset, shape, strides, needs_grad, no_hooks)
    return (torch._utils._rebuild_tensor_v2, rebuild_args)
|
|
|
|
|
|
|
|
|
|
def __setstate__(self, state):
    r"""Restores tensor state from a pickled ``state`` tuple.

    Three layouts are handled, selected by ``len(state)``:
    4 items (legacy Tensor) are forwarded to ``set_``; 5 items (legacy
    Variable) are reshuffled into the 3-item form; anything else is unpacked
    as ``(requires_grad, _, _backward_hooks)``.

    Raises:
        RuntimeError: if called on a non-leaf tensor.
    """
    # Warning: this method is NOT called when you torch.load() a tensor;
    # that is managed by _rebuild_tensor_v2
    if not self.is_leaf:
        raise RuntimeError('__setstate__ can be only called on leaf Tensors')
    n_fields = len(state)
    if n_fields == 4:
        # legacy serialization of Tensor
        self.set_(*state)
        return
    if n_fields == 5:
        # legacy serialization of Variable
        self.data = state[0]
        state = (state[3], state[4], state[2])
    # The setting of _backward_hooks is expected to be a no-op.
    # See Note [Don't serialize hooks]
    self.requires_grad, _, self._backward_hooks = state
|
|
|
|
|
|
|
|
|
|
def __repr__(self):
    r"""Returns the printable representation produced by ``torch._tensor_str``.

    On Python 3 the text is returned as-is. On Python 2 it is encoded with
    the stdout encoding (falling back to UTF-8), replacing characters that
    cannot be encoded.
    """
    text = torch._tensor_str._str(self)
    if sys.version_info > (3,):
        # All strings are unicode in Python 3.
        return text
    # Python 2: a missing or empty stdout encoding both fall back to UTF-8.
    target_encoding = getattr(sys.stdout, 'encoding', None) or 'UTF-8'
    return text.encode(target_encoding, 'replace')
|
|
|
|
|
|
|
|
|
|
def backward(self, gradient=None, retain_graph=None, create_graph=False):
    r"""Computes the gradient of current tensor w.r.t. graph leaves.

    The graph is differentiated using the chain rule. If the tensor is
    non-scalar (i.e. its data has more than one element) and requires
    gradient, the function additionally requires specifying ``gradient``.
    It should be a tensor of matching type and location, that contains
    the gradient of the differentiated function w.r.t. ``self``.

    This function accumulates gradients in the leaves - you might need to
    zero them before calling it.

    Arguments:
        gradient (Tensor or None): Gradient w.r.t. the
            tensor. If it is a tensor, it will be automatically converted
            to a Tensor that does not require grad unless ``create_graph`` is True.
            None values can be specified for scalar Tensors or ones that
            don't require grad. If a None value would be acceptable then
            this argument is optional.
        retain_graph (bool, optional): If ``False``, the graph used to compute
            the grads will be freed. Note that in nearly all cases setting
            this option to True is not needed and often can be worked around
            in a much more efficient way. Defaults to the value of
            ``create_graph``.
        create_graph (bool, optional): If ``True``, graph of the derivative will
            be constructed, allowing to compute higher order derivative
            products. Defaults to ``False``.
    """
    # Pure delegation: all of the work happens in torch.autograd.backward.
    torch.autograd.backward(
        self,
        grad_tensors=gradient,
        retain_graph=retain_graph,
        create_graph=create_graph,
    )
|
|
|
|
|
|
|
|
|
|
def register_hook(self, hook):
    r"""Registers a backward hook.

    The hook will be called every time a gradient with respect to the
    Tensor is computed. The hook should have the following signature::

        hook(grad) -> Tensor or None

    The hook should not modify its argument, but it can optionally return
    a new gradient which will be used in place of :attr:`grad`.

    This function returns a handle with a method ``handle.remove()``
    that removes the hook from the tensor.

    Example::

        >>> v = torch.tensor([0., 0., 0.], requires_grad=True)
        >>> h = v.register_hook(lambda grad: grad * 2)  # double the gradient
        >>> v.backward(torch.tensor([1., 2., 3.]))
        >>> v.grad

         2
         4
         6
        [torch.FloatTensor of size (3,)]

        >>> h.remove()  # removes the hook
    """
    if not self.requires_grad:
        raise RuntimeError("cannot register a hook on a tensor that "
                           "doesn't require gradient")
    if self._backward_hooks is None:
        # First hook on this tensor: create the table and, for non-leaf
        # tensors, let the producing function know about it.
        self._backward_hooks = OrderedDict()
        producer = self.grad_fn
        if producer is not None:
            producer._register_hook_dict(self)
    hook_table = self._backward_hooks
    removable = hooks.RemovableHandle(hook_table)
    hook_table[removable.id] = hook
    return removable
|
|
|
|
|
|
|
|
|
|
def reinforce(self, reward):
    r"""Removed API: always raises ``RuntimeError`` with migration guidance."""
    def trim(text):
        # Strip every line so the message does not carry source indentation.
        stripped_lines = []
        for line in text.split('\n'):
            stripped_lines.append(line.strip())
        return '\n'.join(stripped_lines)

    message = trim(r"""reinforce() was removed.
        Use torch.distributions instead.
        See https://pytorch.org/docs/master/distributions.html

        Instead of:

        probs = policy_network(state)
        action = probs.multinomial()
        next_state, reward = env.step(action)
        action.reinforce(reward)
        action.backward()

        Use:

        probs = policy_network(state)
        # NOTE: categorical is equivalent to what used to be called multinomial
        m = torch.distributions.Categorical(probs)
        action = m.sample()
        next_state, reward = env.step(action)
        loss = -m.log_prob(action) * reward
        loss.backward()
    """)
    raise RuntimeError(message)
|
|
|
|
|
|
|
|
|
|
detach = _add_docstr(_C._TensorBase.detach, r"""
|
2018-04-04 17:36:56 +00:00
|
|
|
Returns a new Tensor, detached from the current graph.
|
2018-04-03 20:29:25 +00:00
|
|
|
|
|
|
|
|
The result will never require gradient.
|
|
|
|
|
|
|
|
|
|
.. note::
|
|
|
|
|
|
2018-12-27 00:31:47 +00:00
|
|
|
Returned Tensor shares the same storage with the original one.
|
2018-04-03 20:29:25 +00:00
|
|
|
In-place modifications on either of them will be seen, and may trigger
|
|
|
|
|
errors in correctness checks.
|
2018-12-27 00:31:47 +00:00
|
|
|
IMPORTANT NOTE: Previously, in-place size / stride / storage changes
|
|
|
|
|
(such as `resize_` / `resize_as_` / `set_` / `transpose_`) to the returned tensor
|
|
|
|
|
also update the original tensor. Now, these in-place changes will not update the
|
|
|
|
|
original tensor anymore, and will instead trigger an error.
|
|
|
|
|
For sparse tensors:
|
|
|
|
|
In-place indices / values changes (such as `zero_` / `copy_` / `add_`) to the
|
|
|
|
|
returned tensor will not update the original tensor anymore, and will instead
|
|
|
|
|
trigger an error.
|
2018-04-03 20:29:25 +00:00
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
detach_ = _add_docstr(_C._TensorBase.detach_, r"""
|
2018-04-04 17:36:56 +00:00
|
|
|
Detaches the Tensor from the graph that created it, making it a leaf.
|
2018-04-03 20:29:25 +00:00
|
|
|
Views cannot be detached in-place.
|
|
|
|
|
""")
|
|
|
|
|
|
|
|
|
|
def retain_grad(self):
    r"""Enables .grad attribute for non-leaf Tensors.

    Does nothing for tensors without a ``grad_fn`` or when retention was
    already enabled.

    Raises:
        RuntimeError: if the tensor has a ``grad_fn`` but does not require grad.
    """
    if self.grad_fn is None:  # no-op for leaves
        return
    if not self.requires_grad:
        raise RuntimeError("can't retain_grad on Tensor that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
        return

    # The hook keeps only a weak reference so it does not keep the tensor alive.
    self_ref = weakref.ref(self)

    def retain_grad_hook(grad):
        target = self_ref()
        if target is None:
            return
        previous = target._grad
        # The first gradient is cloned; later ones are summed out-of-place.
        target._grad = grad.clone() if previous is None else previous + grad

    self.register_hook(retain_grad_hook)
    self.retains_grad = True
|
|
|
|
|
|
|
|
|
|
def is_pinned(self):
    r"""Returns true if this tensor resides in pinned memory"""
    backing = self.storage()
    if not backing:
        # A falsy storage is reported as not pinned.
        return False
    return backing.is_pinned()
|
|
|
|
|
|
|
|
|
|
def is_shared(self):
    r"""Checks if tensor is in shared memory.

    This is always ``True`` for CUDA tensors.
    """
    # The answer comes straight from the underlying storage.
    backing = self.storage()
    return backing.is_shared()
|
|
|
|
|
|
|
|
|
|
def share_memory_(self):
    r"""Moves the underlying storage to shared memory.

    This is a no-op if the underlying storage is already in shared memory
    and for CUDA tensors. Tensors in shared memory cannot be resized.

    Returns ``self`` so calls can be chained.
    """
    backing = self.storage()
    backing.share_memory_()
    return self
|
|
|
|
|
|
2018-07-13 02:24:10 +00:00
|
|
|
def __reversed__(self):
    r"""Reverses the tensor along dimension 0."""
    # A 0-d tensor has no dimension to flip and is returned unchanged.
    return self if self.dim() == 0 else self.flip(0)
|
|
|
|
|
|
2019-01-17 06:12:13 +00:00
|
|
|
def norm(self, p="fro", dim=None, keepdim=False, dtype=None):
    r"""See :func:`torch.norm`"""
    # Method form of torch.norm; every argument is forwarded unchanged.
    return torch.norm(self, p=p, dim=dim, keepdim=keepdim, dtype=dtype)
|
2018-09-20 21:40:17 +00:00
|
|
|
|
2019-04-22 15:14:49 +00:00
|
|
|
def pstrf(self, upper=True):
    r"""See :func:`torch.pstrf`"""
    # stacklevel=2 attributes the deprecation warning to the caller.
    deprecation_note = ("torch.pstrf is deprecated in favour of torch.cholesky and will be removed "
                        "in the next release.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.pstrf(upper=upper)
|
|
|
|
|
|
2018-11-01 22:07:24 +00:00
|
|
|
def potrf(self, upper=True):
    r"""See :func:`torch.cholesky`"""
    # Deprecated alias: warn (attributed to the caller), then forward to
    # the base-class cholesky with the caller's ``upper`` value.
    deprecation_note = ("torch.potrf is deprecated in favour of torch.cholesky and will be removed "
                        "in the next release. Please use torch.cholesky instead and note that the "
                        ":attr:`upper` argument in torch.cholesky defaults to ``False``.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.cholesky(upper=upper)
|
|
|
|
|
|
2019-04-22 15:14:49 +00:00
|
|
|
def potri(self, upper=True):
    r"""See :func:`torch.cholesky_inverse`"""
    # Deprecated alias: warn (attributed to the caller), then forward to
    # the base-class cholesky_inverse.
    deprecation_note = ("torch.potri is deprecated in favour of torch.cholesky_inverse and will be "
                        "removed in the next release. Please use torch.cholesky_inverse instead and "
                        "note that the :attr:`upper` argument in torch.cholesky_inverse defaults to "
                        "``False``.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.cholesky_inverse(upper=upper)
|
2019-03-11 19:15:41 +00:00
|
|
|
|
2018-12-19 20:11:49 +00:00
|
|
|
def potrs(self, u, upper=True):
    r"""See :func:`torch.cholesky_solve`"""
    # Deprecated alias: warn (attributed to the caller), then forward to
    # the base-class cholesky_solve.
    deprecation_note = ("torch.potrs is deprecated in favour of torch.cholesky_solve and "
                        "will be removed in the next release. Please use torch.cholesky_solve instead "
                        "and note that the :attr:`upper` argument in torch.cholesky_solve defaults "
                        "to ``False``.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.cholesky_solve(u, upper=upper)
|
|
|
|
|
|
2019-03-18 23:01:02 +00:00
|
|
|
def gesv(self, A):
    r"""See :func:`torch.solve`"""
    # Deprecated alias: warn (attributed to the caller), then forward to
    # the base-class solve.
    deprecation_note = ("torch.gesv is deprecated in favour of torch.solve and will be removed in the "
                        "next release. Please use torch.solve instead.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.solve(A)
|
|
|
|
|
|
2019-03-21 21:18:38 +00:00
|
|
|
def trtrs(self, A, upper=True, transpose=False, unitriangular=False):
    r"""See :func:`torch.triangular_solve`"""
    # Deprecated alias: warn (attributed to the caller), then forward every
    # option to the base-class triangular_solve.
    deprecation_note = ("torch.trtrs is deprecated in favour of torch.triangular_solve and will be "
                        "removed in the next release. Please use torch.triangular_solve instead.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.triangular_solve(
        A, upper=upper, transpose=transpose, unitriangular=unitriangular)
|
|
|
|
|
|
2019-03-29 07:27:48 +00:00
|
|
|
def btrifact(self, pivot=True):
    r"""See :func:`torch.lu`

    Deprecated alias for ``lu(pivot=pivot)``. Emits a deprecation warning
    attributed to the caller and returns the pair ``(LU, pivots)``.

    Arguments:
        pivot (bool, optional): forwarded to the factorization. Default: ``True``

    Returns:
        tuple: ``(LU, pivots)``. Use :meth:`btrifact_with_info` or
        ``lu(get_infos=True)`` to also obtain the info tensor.
    """
    warnings.warn("torch.btrifact is deprecated in favour of torch.lu and will be removed in "
                  "the next release. Please use torch.lu instead.", stacklevel=2)
    # _lu_with_info always yields (LU, pivots, infos). With error checking
    # enabled the infos carry nothing the caller needs, so only the pair is
    # returned, matching what `lu` does when get_infos is False.
    LU, pivots, _ = torch._lu_with_info(self, pivot=pivot, check_errors=True)
    return LU, pivots
|
|
|
|
|
|
|
|
|
|
def btrifact_with_info(self, pivot=True):
    r"""See :func:`torch.lu`"""
    # Deprecated alias: warn (attributed to the caller), then return the raw
    # result of the factorization with error checking turned off.
    deprecation_note = ("torch.btrifact_with_info is deprecated in favour of torch.lu with the "
                        "get_infos argument and will be removed in the next release. Please use "
                        "torch.lu with the get_infos argument set to True instead.")
    warnings.warn(deprecation_note, stacklevel=2)
    factorization = torch._lu_with_info(self, pivot=pivot, check_errors=False)
    return factorization
|
|
|
|
|
|
2019-04-09 22:15:06 +00:00
|
|
|
def btrisolve(self, LU_data, LU_pivots):
    r"""See :func:`torch.lu_solve`"""
    # Deprecated alias: warn (attributed to the caller), then forward to
    # the base-class lu_solve.
    deprecation_note = ("torch.btrisolve is deprecated in favour of torch.lu_solve and will be "
                        "removed in the next release. Please use torch.lu_solve instead.")
    warnings.warn(deprecation_note, stacklevel=2)
    base = super(Tensor, self)
    return base.lu_solve(LU_data=LU_data, LU_pivots=LU_pivots)
|
|
|
|
|
|
2019-03-29 07:27:48 +00:00
|
|
|
def lu(self, pivot=True, get_infos=False):
    r"""See :func:`torch.lu`

    Returns ``(LU, pivots, infos)`` when ``get_infos`` is true and
    ``(LU, pivots)`` otherwise.
    """
    # If get_infos is True, then we don't need to check for errors and vice versa
    check_errors = not get_infos
    factorization = torch._lu_with_info(self, pivot=pivot, check_errors=check_errors)
    LU, pivots, infos = factorization
    if not get_infos:
        return LU, pivots
    return LU, pivots, infos
|
|
|
|
|
|
2018-07-17 17:54:03 +00:00
|
|
|
def stft(self, n_fft, hop_length=None, win_length=None, window=None,
         center=True, pad_mode='reflect', normalized=False, onesided=True):
    r"""See :func:`torch.stft`

    .. warning::
      This function changed signature at version 0.4.1. Calling with
      the previous signature may cause error or return incorrect result.
    """
    # Method form of torch.stft; arguments are forwarded in declaration order.
    forwarded = (n_fft, hop_length, win_length, window, center,
                 pad_mode, normalized, onesided)
    return torch.stft(self, *forwarded)
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
def resize(self, *sizes):
    r"""Deprecated non-inplace resize; warns and delegates to ``Resize.apply``."""
    warnings.warn("non-inplace resize is deprecated")
    # Imported here rather than at module import time.
    from torch.autograd._functions import Resize as _ResizeFn
    target_shape = sizes
    return _ResizeFn.apply(self, target_shape)
|
|
|
|
|
|
2018-04-04 17:36:56 +00:00
|
|
|
def resize_as(self, tensor):
    r"""Deprecated non-inplace resize to the size of ``tensor``; warns and
    delegates to ``Resize.apply``."""
    warnings.warn("non-inplace resize_as is deprecated")
    # Imported here rather than at module import time.
    from torch.autograd._functions import Resize as _ResizeFn
    target_shape = tensor.size()
    return _ResizeFn.apply(self, target_shape)
|
2018-04-03 20:29:25 +00:00
|
|
|
|
|
|
|
|
def split(self, split_size, dim=0):
    r"""See :func:`torch.split`
    """
    base = super(Tensor, self)
    # An int goes to the base-class split; anything else is handed to
    # split_with_sizes.
    if isinstance(split_size, int):
        return base.split(split_size, dim)
    return base.split_with_sizes(split_size, dim)
|
|
|
|
|
|
2019-04-16 20:55:37 +00:00
|
|
|
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
    r"""Returns the unique elements of the input tensor.

    See :func:`torch.unique`
    """
    # Method form of torch.unique; every option is forwarded by keyword.
    return torch.unique(
        self,
        sorted=sorted,
        return_inverse=return_inverse,
        return_counts=return_counts,
        dim=dim,
    )
|
2018-04-03 20:29:25 +00:00
|
|
|
|
Add torch.unique_consecutive (#19060)
Summary:
Fixes: https://github.com/pytorch/pytorch/issues/19045
Please review: VitalyFedyunin ngimel
This is independent of the #18649 series. It will cause merge conflicts in the #18649 series, but please merge this first, and I will resolve the merge conflicts there.
The new feature is exposed in `_unique2_temporary_will_remove_soon` and `_unique_dim2_temporary_will_remove_soon`. But not at `torch.unique` yet. I will take care of the API after #18649 series get merged completely.
Benchmark on a tensor of shape `torch.Size([15320, 2])`:
```python
print(torch.__version__)
print()
a = tensor.sort().values.to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True)
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique2_temporary_will_remove_soon(a); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique2_temporary_will_remove_soon(a, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```
```
1.1.0a0+2addccc
cpu, sorted_input=False:
340 µs ± 5.88 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
717 µs ± 14.9 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
52.3 ms ± 2.75 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
52.3 ms ± 1.79 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
cpu, sorted_input=True:
32.8 µs ± 285 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
49.9 µs ± 557 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
51.6 µs ± 1.08 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
78 µs ± 782 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
cuda, sorted_input=False:
213 µs ± 1.52 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
291 µs ± 3.81 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
250 µs ± 1.05 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
321 µs ± 1.59 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cuda, sorted_input=True:
45.6 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
110 µs ± 2.47 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
82 µs ± 857 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
143 µs ± 409 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
```python
print(torch.__version__)
print()
a1, a2 = tensor.unbind(1)
indices = (a1 * tensor.max() + a2).sort().indices
a = tensor.index_select(0, indices).to('cpu')
print('cpu, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True)
print()
print('cpu, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True)
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True)
print()
a = a.to('cuda')
print('cuda, sorted_input=False:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, return_inverse=True, return_counts=True); torch.cuda.synchronize()
print()
print('cuda, sorted_input=True:')
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_counts=True); torch.cuda.synchronize()
%timeit torch._unique_dim2_temporary_will_remove_soon(a, dim=0, sorted_input=True, return_inverse=True, return_counts=True); torch.cuda.synchronize()
```
```
cpu, sorted_input=False:
55.4 ms ± 1.12 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.8 ms ± 616 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 402 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.1 ms ± 725 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cpu, sorted_input=True:
54.7 ms ± 585 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
55.2 ms ± 1.23 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.5 ms ± 865 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
54.9 ms ± 577 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
cuda, sorted_input=False:
171 µs ± 783 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
220 µs ± 1.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
203 µs ± 2.95 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
251 µs ± 2.83 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
cuda, sorted_input=True:
59.6 µs ± 757 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
113 µs ± 431 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
93.2 µs ± 2.13 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
147 µs ± 2.81 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
```
The CPU implementation of `unique_dim` is super slow, see https://github.com/pytorch/pytorch/issues/18987, but this PR will not worry about this issue.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19060
Differential Revision: D14866909
Pulled By: ezyang
fbshipit-source-id: d20012cec68c37b05cf770a6f4d6524f910b950f
2019-04-10 14:33:15 +00:00
|
|
|
def unique_consecutive(self, return_inverse=False, return_counts=False, dim=None):
    r"""Eliminates all but the first element from every consecutive group of equivalent elements.

    See :func:`torch.unique_consecutive`
    """
    # Method form of torch.unique_consecutive; options are forwarded by keyword.
    return torch.unique_consecutive(
        self,
        return_inverse=return_inverse,
        return_counts=return_counts,
        dim=dim,
    )
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
def __rsub__(self, other):
    r"""Implements ``other - self`` by delegating to the C-level ``rsub``."""
    reversed_subtract = _C._VariableFunctions.rsub
    return reversed_subtract(self, other)
|
2018-04-03 20:29:25 +00:00
|
|
|
|
|
|
|
|
def __rdiv__(self, other):
    r"""Implements ``other / self`` as ``reciprocal(self) * other``.

    Non-floating-point tensors are promoted to double for the reciprocal and
    the result is cast back to the type of ``self``.
    """
    if not self.dtype.is_floating_point:
        promoted = self.double().reciprocal() * other
        return promoted.type_as(self)
    return self.reciprocal() * other
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
# Reflected true division reuses the __rdiv__ method defined just above.
__rtruediv__ = __rdiv__
# In-place true division is bound to the base class's __idiv__.
__itruediv__ = _C._TensorBase.__idiv__

# The ** operator is bound to the base class's pow.
__pow__ = _C._TensorBase.pow
|
|
|
|
|
|
|
|
|
|
def __format__(self, format_spec):
    r"""Formats 0-d tensors like the Python scalar they hold.

    Tensors with one or more dimensions fall back to ``object.__format__``.
    """
    is_scalar = self.dim() == 0
    if not is_scalar:
        return object.__format__(self, format_spec)
    python_value = self.item()
    return python_value.__format__(format_spec)
|
|
|
|
|
|
|
|
|
|
def __ipow__(self, other):
    r"""In-place power (``**=``) is unsupported and always raises
    ``NotImplementedError``."""
    reason = "in-place pow not implemented"
    raise NotImplementedError(reason)
|
|
|
|
|
|
|
|
|
|
def __rpow__(self, other):
    r"""Implements ``other ** self``.

    ``other`` is first converted with ``self.new_tensor`` and then raised to
    the power ``self``.
    """
    base = self.new_tensor(other)
    return base ** self
|
2018-04-03 20:29:25 +00:00
|
|
|
|
2018-05-03 21:34:59 +00:00
|
|
|
def __floordiv__(self, other):
    r"""Implements ``self // other``.

    The quotient ``self / other`` is truncated (``trunc``) when it is of a
    floating point dtype and returned unchanged otherwise.
    """
    quotient = self / other
    if not quotient.dtype.is_floating_point:
        return quotient
    return quotient.trunc()
|
|
|
|
|
|
|
|
|
|
def __rfloordiv__(self, other):
    r"""Implements ``other // self``.

    The quotient ``other / self`` is truncated (``trunc``) when it is of a
    floating point dtype and returned unchanged otherwise.
    """
    quotient = other / self
    if not quotient.dtype.is_floating_point:
        return quotient
    return quotient.trunc()
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
# Unary minus is bound to the base class's neg.
__neg__ = _C._TensorBase.neg

# Rich comparison operators are bound to the base-class methods of the same
# (un-dundered) name.
__eq__ = _C._TensorBase.eq
__ne__ = _C._TensorBase.ne
__lt__ = _C._TensorBase.lt
__le__ = _C._TensorBase.le
__gt__ = _C._TensorBase.gt
__ge__ = _C._TensorBase.ge
# abs(tensor) is bound to the base class's abs.
__abs__ = _C._TensorBase.abs
|
|
|
|
|
|
2019-05-15 23:39:26 +00:00
|
|
|
def __std_mean__(self, dim=None, unbiased=True, keepdim=False):
    r"""Delegates to the C-level ``std_mean``.

    Without ``dim`` only ``unbiased`` is forwarded (``keepdim`` is ignored);
    with ``dim`` all three options are forwarded positionally.
    """
    if dim is None:
        extra_args = (unbiased,)
    else:
        extra_args = (dim, unbiased, keepdim)
    return _C._VariableFunctions.std_mean(self, *extra_args)
|
|
|
|
|
|
|
|
|
|
def __var_mean__(self, dim=None, unbiased=True, keepdim=False):
    r"""Delegates to the C-level ``var_mean``.

    Without ``dim`` only ``unbiased`` is forwarded (``keepdim`` is ignored);
    with ``dim`` all three options are forwarded positionally.
    """
    if dim is None:
        extra_args = (unbiased,)
    else:
        extra_args = (dim, unbiased, keepdim)
    return _C._VariableFunctions.var_mean(self, *extra_args)
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
def __len__(self):
    r"""Returns the size of the first dimension.

    Raises:
        TypeError: for 0-d tensors, which have no length.
    """
    if self.dim() != 0:
        return self.shape[0]
    raise TypeError("len() of a 0-d tensor")
|
|
|
|
|
|
|
|
|
|
def __iter__(self):
    r"""Iterates over the slices of the tensor along dimension 0.

    Raises:
        TypeError: for 0-d tensors.
    """
    # NB: we use 'imap' and not 'map' here, so that in Python 2 we get a
    # generator and don't eagerly perform all the indexes.  This could
    # save us work, and also helps keep trace ordering deterministic
    # (e.g., if you zip(*hiddens), the eager map will force all the
    # indexes of hiddens[0] before hiddens[1], while the generator
    # map will interleave them.)
    if self.dim() == 0:
        raise TypeError('iteration over a 0-d tensor')
    if torch._C._get_tracing_state():
        trace_note = ('Iterating over a tensor might cause the trace to be incorrect. '
                      'Passing a tensor of different shape won\'t change the number of '
                      'iterations executed (and might lead to errors or silently give '
                      'incorrect results).')
        warnings.warn(trace_note, category=RuntimeWarning)

    def select_row(index):
        return self[index]

    row_indices = range(self.size(0))
    return iter(imap(select_row, row_indices))
|
|
|
|
|
|
|
|
|
|
def __hash__(self):
    r"""Hashes by object identity (``id``), not by tensor contents."""
    identity = id(self)
    return identity
|
|
|
|
|
|
|
|
|
|
def __dir__(self):
    r"""Returns the sorted attribute names advertised for this tensor.

    Class attributes and instance ``__dict__`` keys are combined; the
    deprecated ``volatile`` entry is always dropped, and
    ``__cuda_array_interface__`` is dropped unless the tensor is a dense CUDA
    tensor.
    """
    names = dir(self.__class__)
    names.remove('volatile')  # deprecated
    names = names + list(self.__dict__.keys())

    # property only available dense, cuda tensors
    dense_cuda = self.is_cuda and not self.is_sparse
    if not dense_cuda:
        names.remove("__cuda_array_interface__")

    return sorted(names)
|
|
|
|
|
|
|
|
|
|
# Numpy array interface, to support `numpy.asarray(tensor) -> ndarray`
|
2018-07-30 21:37:21 +00:00
|
|
|
# A high priority value so that mixed Tensor/ndarray expressions pick the
# Tensor implementation.
__array_priority__ = 1000    # prefer Tensor ops over numpy ones
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
def __array__(self, dtype=None):
    r"""NumPy array-interface hook used by ``numpy.asarray(tensor)``.

    Arguments:
        dtype (optional): if given, the result of ``self.numpy()`` is converted
            with ``astype(dtype, copy=False)``.
    """
    array = self.numpy()
    if dtype is not None:
        array = array.astype(dtype, copy=False)
    return array
|
2018-04-03 20:29:25 +00:00
|
|
|
|
|
|
|
|
# Wrap Numpy array again in a suitable tensor when done, to support e.g.
|
|
|
|
|
# `numpy.sin(tensor) -> tensor` or `numpy.greater(tensor, 0) -> ByteTensor`
|
|
|
|
|
def __array_wrap__(self, array):
    r"""Wraps a NumPy result array back into a tensor.

    Boolean arrays are converted to ``uint8`` before wrapping.
    """
    # Workaround, torch has no built-in bool tensor
    wrapped_source = array.astype('uint8') if array.dtype == bool else array
    return torch.from_numpy(wrapped_source)
|
|
|
|
|
|
2019-03-11 15:55:01 +00:00
|
|
|
def __contains__(self, element):
    r"""Check if `element` is present in tensor

    Arguments:
        element (Tensor or scalar): element to be checked
            for presence in current tensor

    Returns:
        bool: ``True`` if any entry of ``element == self`` is true.

    Raises:
        RuntimeError: if ``element`` is neither a Tensor nor a number.
    """
    if isinstance(element, (torch.Tensor, Number)):
        return (element == self).any().item()
    # Returning NotImplemented here would be coerced to True by the ``in``
    # operator, silently reporting unsupported operands as "contained".
    raise RuntimeError(
        "Tensor.__contains__ only supports Tensor or scalar, but you passed in a %s." %
        type(element)
    )
|
|
|
|
|
|
2018-10-12 20:33:43 +00:00
|
|
|
@property
def __cuda_array_interface__(self):
    """Array view description for cuda tensors.

    See:
    https://numba.pydata.org/numba-doc/latest/cuda/cuda_array_interface.html

    Returns a dict with the keys ``typestr``, ``shape``, ``strides`` (in
    bytes), ``data`` (pointer, read-only flag) and ``version``.

    Raises:
        AttributeError: for non-CUDA or sparse tensors.
        RuntimeError: for tensors that require grad.
    """
    # raise AttributeError for unsupported tensors, so that
    # hasattr(cpu_tensor, "__cuda_array_interface__") is False.
    if not self.is_cuda:
        raise AttributeError(
            "Can't get __cuda_array_interface__ on non-CUDA tensor type: %s "
            "If CUDA data is required use tensor.cuda() to copy tensor to device memory." %
            self.type()
        )

    if self.is_sparse:
        raise AttributeError(
            "Can't get __cuda_array_interface__ on sparse type: %s "
            "Use Tensor.to_dense() to convert to a dense tensor first." %
            self.type()
        )

    # RuntimeError, matching tensor.__array__() behavior.
    if self.requires_grad:
        raise RuntimeError(
            "Can't get __cuda_array_interface__ on Variable that requires grad. "
            "If gradients aren't required, use var.detach() to get Variable that doesn't require grad."
        )

    # CUDA devices are little-endian and tensors are stored in native byte
    # order. 1-byte entries are endian-agnostic.
    typestr_by_dtype = {
        torch.float16: "<f2",
        torch.float32: "<f4",
        torch.float64: "<f8",
        torch.uint8: "|u1",
        torch.int8: "|i1",
        torch.int16: "<i2",
        torch.int32: "<i4",
        torch.int64: "<i8",
    }
    typestr = typestr_by_dtype[self.dtype]

    element_bytes = self.storage().element_size()

    shape = self.shape
    # Tensor strides count elements; the interface wants bytes.
    byte_strides = tuple(step * element_bytes for step in self.stride())
    data = (self.data_ptr(), False)  # read-only is false

    return {
        "typestr": typestr,
        "shape": shape,
        "strides": byte_strides,
        "data": data,
        "version": 0,
    }
|
|
|
|
|
|
2018-04-03 20:29:25 +00:00
|
|
|
# Makes the class report 'torch' as its defining module.
__module__ = 'torch'
|