Documents sub properly, adds subtract alias (#43850)

Summary:
`torch.sub` was undocumented, so this PR adds its documentation, analogous to `torch.add`'s documentation, and adds the alias `torch.subtract` for `torch.sub`, too. This alias comes from NumPy (see https://numpy.org/doc/stable/reference/generated/numpy.subtract.html?highlight=subtract#numpy.subtract)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/43850

Reviewed By: ngimel

Differential Revision: D23416908

Pulled By: mruberry

fbshipit-source-id: 6c4d2ebaf6ecae91f3a6efe484ce6c4dad96f016
This commit is contained in:
Mike Ruberry 2020-08-30 15:42:19 -07:00 committed by Facebook GitHub Bot
parent 3dc9645430
commit 3aeb70db0b
11 changed files with 122 additions and 27 deletions

View file

@ -654,8 +654,6 @@ _(aten, stft) \
_(aten, storage_offset) \
_(aten, stride) \
_(aten, strides) \
_(aten, sub) \
_(aten, sub_) \
_(aten, rsub) \
_(aten, sum) \
_(aten, sum_to_size) \

View file

@ -214,6 +214,10 @@ namespace c10 {
_(aten, list) \
_(aten, wait) \
_(aten, save) \
_(aten, sub) \
_(aten, sub_) \
_(aten, subtract) \
_(aten, subtract_) \
_(aten, keys) \
_(aten, ord) \
_(aten, chr) \

View file

@ -48,6 +48,12 @@ DEFINE_DISPATCH(lcm_stub);
DEFINE_DISPATCH(hypot_stub);
DEFINE_DISPATCH(nextafter_stub);
static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return tensor;
}
Tensor& add_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
auto iter = TensorIterator::binary_op(result, self, other);
alpha_check(iter.dtype(), alpha);
@ -275,6 +281,35 @@ Tensor& sub_(Tensor& self, const Tensor& other, Scalar alpha) {
return native::sub_out(self, self, other, alpha);
}
Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
return native::sub(self, wrapped_scalar_tensor(other), alpha);
}
Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
return native::sub_(self, wrapped_scalar_tensor(other), alpha);
}
// subtract, alias for sub
Tensor& subtract_out(Tensor& result, const Tensor& self, const Tensor& other, Scalar alpha) {
return at::sub_out(result, self, other, alpha);
}
Tensor subtract(const Tensor& self, const Tensor& other, Scalar alpha) {
return self.sub(other, alpha);
}
Tensor& subtract_(Tensor& self, const Tensor& other, Scalar alpha) {
return self.sub_(other, alpha);
}
Tensor subtract(const Tensor& self, Scalar other, Scalar alpha) {
return self.sub(other, alpha);
}
Tensor& subtract_(Tensor& self, Scalar other, Scalar alpha) {
return self.sub_(other, alpha);
}
Tensor& sigmoid_backward_out(Tensor& result, const Tensor& grad_output, const Tensor& output) {
auto iter = TensorIterator::binary_op(result, grad_output, output);
sigmoid_backward_stub(iter.device_type(), iter);
@ -346,12 +381,6 @@ Tensor& atan2_(Tensor& self, const Tensor& other) {
// types (int, float, etc.) to Tensor (only to Scalar). They're not exposed
// to Python.
static Tensor wrapped_scalar_tensor(Scalar scalar) {
auto tensor = scalar_to_tensor(scalar);
tensor.unsafeGetTensorImpl()->set_wrapped_number(true);
return tensor;
}
static void check_convert(Scalar scalar, ScalarType scalarType) {
// Validate that is possible to convert scalar to tensor dtype without overflow
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half, scalarType, "check_convert", [&]{
@ -418,14 +447,6 @@ Tensor& mul_(Tensor& self, Scalar other) {
return native::mul_(self, wrapped_scalar_tensor(other));
}
Tensor sub(const Tensor& self, Scalar other, Scalar alpha) {
return native::sub(self, wrapped_scalar_tensor(other), alpha);
}
Tensor& sub_(Tensor& self, Scalar other, Scalar alpha) {
return native::sub_(self, wrapped_scalar_tensor(other), alpha);
}
Tensor rsub(const Tensor& self, Scalar other, Scalar alpha) {
return native::rsub(self, wrapped_scalar_tensor(other), alpha);
}

View file

@ -3609,6 +3609,26 @@
use_c10_dispatcher: full
variants: method
# subtract, alias for sub
- func: subtract.out(Tensor self, Tensor other, *, Scalar alpha=1, Tensor(a!) out) -> Tensor(a!)
- func: subtract.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full
variants: function, method
- func: subtract_.Tensor(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
# For C++ only, until we have conversion from C++ numbers to Tensor
- func: subtract.Scalar(Tensor self, Scalar other, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full
variants: function, method
- func: subtract_.Scalar(Tensor(a!) self, Scalar other, Scalar alpha=1) -> Tensor(a!)
use_c10_dispatcher: full
variants: method
- func: rsub.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor
use_c10_dispatcher: full
variants: function
@ -5756,7 +5776,7 @@
CUDA: foreach_tensor_add_scalar_kernel_cuda
- func: _foreach_add_.Scalar(Tensor[](a!) self, Scalar scalar) -> ()
device_guard: False
device_guard: False
variants: function
dispatch:
CPU: foreach_tensor_add_scalar_kernel_slow_

View file

@ -541,6 +541,8 @@ view of a storage and defines numeric operations on it.
.. automethod:: stride
.. automethod:: sub
.. automethod:: sub_
.. automethod:: subtract
.. automethod:: subtract_
.. automethod:: sum
.. automethod:: sum_to_size
.. automethod:: svd

View file

@ -325,6 +325,8 @@ Pointwise Ops
sinh
sqrt
square
sub
subtract
tan
tanh
true_divide

View file

@ -80,6 +80,14 @@ alias_infos = (
lambda d: torch.clamp(torch.randn(20, device=d), -1, 1)),
AliasInfo('arctanh_', torch.Tensor.arctanh_, 'atanh_', torch.Tensor.atanh_,
lambda d: torch.clamp(torch.randn(20, device=d), -1, 1)),
AliasInfo('subtract', torch.subtract, 'sub', torch.sub,
lambda d: torch.randn(20, device=d),
get_args=lambda d: (torch.randn(20, device=d),),
decorators=(onlyCPU,)),
AliasInfo('subtract_', torch.Tensor.subtract_, 'sub_', torch.Tensor.sub_,
lambda d: torch.randn(20, device=d),
get_args=lambda d: (torch.randn(20, device=d),),
decorators=(onlyCPU,)),
)
# Placeholder test class for validating that aliases are correctly

View file

@ -3167,18 +3167,10 @@ Example::
""")
add_docstr_all('sub',
r"""
add_docstr_all('sub', r"""
sub(other, *, alpha=1) -> Tensor
Subtracts a scalar or tensor from :attr:`self` tensor. If both :attr:`alpha`
and :attr:`other` are specified, each element of :attr:`other` is scaled by
:attr:`alpha` before being used.
When :attr:`other` is a tensor, the shape of :attr:`other` must be
:ref:`broadcastable <broadcasting-semantics>` with the shape of the underlying
tensor.
See :func:`torch.sub`.
""")
add_docstr_all('sub_',
@ -3188,6 +3180,18 @@ sub_(other, *, alpha=1) -> Tensor
In-place version of :meth:`~Tensor.sub`
""")
add_docstr_all('subtract', r"""
subtract(other, *, alpha=1) -> Tensor
See :func:`torch.subtract`.
""")
add_docstr_all('subtract_', r"""
subtract_(other, *, alpha=1) -> Tensor
In-place version of :meth:`~Tensor.subtract`.
""")
add_docstr_all('sum',
r"""
sum(dim=None, keepdim=False, dtype=None) -> Tensor

View file

@ -6707,6 +6707,40 @@ Example::
(tensor([0.9110, 0.8197, 1.2552, 1.0608]), tensor([-0.6871, 0.6229, 0.2169, -0.9058]))
""".format(**multi_dim_common))
add_docstr(torch.sub, r"""
sub(input, other, *, alpha=1, out=None) -> Tensor
Subtracts :attr:`other`, scaled by :attr:`alpha`, from :attr:`input`.
.. math::
\text{{out}}_i = \text{{input}}_i - \text{{alpha}} \times \text{{other}}_i
""" + r"""
Supports :ref:`broadcasting to a common shape <broadcasting-semantics>`,
:ref:`type promotion <type-promotion-doc>`, and integer, float, and complex inputs.
Args:
{input}
other (Tensor or Scalar): the tensor or scalar to subtract from :attr:`input`
Keyword args:
alpha (Scalar): the scalar multiplier for :attr:`other`
{out}
Example::
>>> a = torch.tensor((1, 2))
>>> b = torch.tensor((0, 1))
>>> torch.sub(a, b, alpha=2)
tensor([1, 0])
""".format(**common_args))
add_docstr(torch.subtract, r"""
subtract(input, other, *, alpha=1, out=None) -> Tensor
Alias for :func:`torch.sub`.
""")
add_docstr(torch.sum,
r"""
sum(input, dtype=None) -> Tensor

View file

@ -19,6 +19,7 @@ static const std::unordered_map<Symbol, Symbol> alias_map = {
{aten::arctanh, aten::atanh}, {aten::arctanh_, aten::atanh_},
{aten::fix, aten::trunc}, {aten::fix_, aten::trunc_},
{aten::negative, aten::neg}, {aten::negative_, aten::neg_},
{aten::subtract, aten::sub}, {aten::subtract_, aten::sub_},
};
void replaceNodeWithNewSymbol(Node* node, Symbol new_symbol) {

View file

@ -707,6 +707,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.stft: (lambda input, n_fft, hop_length=None, win_length=None, window=None, center=True,
pad_mode='reflect', normalized=False, onesided=True: -1),
torch.sub: lambda input, other, out=None: -1,
torch.subtract: lambda input, other, out=None: -1,
torch.sum: lambda input, dim=None: -1,
torch.nansum: lambda input, dim=None: -1,
torch.svd: lambda input, some=True, compute_uv=True, out=None: -1,