Implement NumPy-like function torch.msort() (#48440)

Summary:
- Related with https://github.com/pytorch/pytorch/issues/38349
- Implementing the NumPy-like function `torch.msort()` .

Pull Request resolved: https://github.com/pytorch/pytorch/pull/48440

Reviewed By: bdhirsh

Differential Revision: D25265753

Pulled By: mruberry

fbshipit-source-id: 7709ac5e5667e7541a3dc9048b9c9896b1a6dfa1
This commit is contained in:
kiyosora 2020-12-04 04:30:23 -08:00 committed by Facebook GitHub Bot
parent cb285080b0
commit 6ab84ca0f3
10 changed files with 90 additions and 1 deletions

View file

@ -497,6 +497,7 @@ _(aten, mode) \
_(aten, mse_loss) \
_(aten, mse_loss_backward) \
_(aten, mse_loss_forward) \
_(aten, msort) \
_(aten, multi_margin_loss) \
_(aten, multi_margin_loss_backward) \
_(aten, multi_margin_loss_forward) \

View file

@ -708,5 +708,15 @@ std::tuple<Tensor, Tensor> sort_cpu(
return sort_out_cpu(values, indices, self, dim, descending);
}
Tensor& msort_out(Tensor& values, const Tensor& self) {
Tensor indices = at::empty({0}, self.options().dtype(kLong));
at::sort_out(values, indices, self, 0, false);
return values;
}
Tensor msort(const Tensor& self) {
return std::get<0>(at::sort(self, 0, false));
}
} // namespace native
} // namespace at

View file

@ -6779,6 +6779,16 @@
- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
variants: method, function
- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
dispatch:
Math: msort_out
- func: msort(Tensor self) -> Tensor
use_c10_dispatcher: full
variants: method, function
dispatch:
Math: msort
- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
use_c10_dispatcher: full
variants: method, function

View file

@ -463,6 +463,7 @@ view of a storage and defines numeric operations on it.
.. automethod:: mode
.. automethod:: movedim
.. automethod:: moveaxis
.. automethod:: msort
.. automethod:: mul
.. automethod:: mul_
.. automethod:: multiply

View file

@ -414,6 +414,7 @@ Comparison Ops
not_equal
sort
topk
msort
Spectral Ops

View file

@ -1,11 +1,12 @@
import torch
import numpy as np
import random
from torch._six import nan
from itertools import product
from torch.testing._internal.common_utils import \
(TestCase, run_tests)
(TestCase, run_tests, make_tensor)
from torch.testing._internal.common_device_type import \
(instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA)
@ -112,6 +113,33 @@ class TestSortAndSelect(TestCase):
self.assertIsOrdered('descending', x, res2val, res2ind,
'random with NaNs')
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
def test_msort(self, device, dtype):
def test(shape):
tensor = make_tensor(shape, device, dtype, low=-9, high=9)
if tensor.size() != torch.Size([]):
expected = torch.from_numpy(np.msort(tensor.cpu().numpy()))
else:
expected = tensor # numpy.msort() does not support empty shapes tensor
result = torch.msort(tensor)
self.assertEqual(result, expected)
out = torch.empty_like(result)
torch.msort(tensor, out=out)
self.assertEqual(out, expected)
shapes = (
[],
[0, ],
[20, ],
[1, 20],
[30, 30],
[10, 20, 30]
)
for shape in shapes:
test(shape)
def test_topk(self, device):
def topKViaSort(t, k, dim, dir):
sorted, indices = t.sort(dim, dir)

View file

@ -3382,6 +3382,13 @@ sort(dim=-1, descending=False) -> (Tensor, LongTensor)
See :func:`torch.sort`
""")
add_docstr_all('msort',
r"""
msort() -> Tensor
See :func:`torch.msort`
""")
add_docstr_all('argsort',
r"""
argsort(dim=-1, descending=False) -> LongTensor

View file

@ -7556,6 +7556,35 @@ Example::
[3, 2, 1, 0]])
""".format(**common_args))
add_docstr(torch.msort,
r"""
msort(input, *, out=None) -> Tensor
Sorts the elements of the :attr:`input` tensor along its first dimension
in ascending order by value.
.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`.
See also :func:`torch.sort`.
Args:
{input}
Keyword args:
{out}
Example::
>>> t = torch.randn(3, 4)
>>> t
tensor([[-0.1321, 0.4370, -1.2631, -1.1289],
[-2.0527, -1.1250, 0.2275, 0.3077],
[-0.0881, -0.1259, -0.5495, 1.0284]])
>>> torch.msort(t)
tensor([[-2.0527, -1.1250, -1.2631, -1.1289],
[-0.1321, -0.1259, -0.5495, 0.3077],
[-0.0881, 0.4370, 0.2275, 1.0284]])
""".format(**common_args))
add_docstr(torch.sparse_coo_tensor,
r"""
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor

View file

@ -525,6 +525,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
torch.movedim: lambda input, source, destination: -1,
torch.moveaxis: lambda input, source, destination: -1,
torch.msort: lambda input, descending=False, out=None: -1,
torch.mul: lambda input, other, out=None: -1,
torch.multiply: lambda input, other, out=None: -1,
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,

View file

@ -1721,6 +1721,7 @@ def method_tests():
('sort', (), NO_ARGS, 'scalar'),
('sort', (), (0,), 'dim_scalar'),
('sort', (), (0, True), 'dim_desc_scalar'),
('msort', (S, M, S), NO_ARGS),
('topk', (S, M, S), (3,)),
('topk', (S, M, S), (3, 1), 'dim', (), [1]),
('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]),