mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Implement NumPy-like function torch.msort() (#48440)
Summary: - Related with https://github.com/pytorch/pytorch/issues/38349 - Implementing the NumPy-like function `torch.msort()` . Pull Request resolved: https://github.com/pytorch/pytorch/pull/48440 Reviewed By: bdhirsh Differential Revision: D25265753 Pulled By: mruberry fbshipit-source-id: 7709ac5e5667e7541a3dc9048b9c9896b1a6dfa1
This commit is contained in:
parent
cb285080b0
commit
6ab84ca0f3
10 changed files with 90 additions and 1 deletions
|
|
@ -497,6 +497,7 @@ _(aten, mode) \
|
|||
_(aten, mse_loss) \
|
||||
_(aten, mse_loss_backward) \
|
||||
_(aten, mse_loss_forward) \
|
||||
_(aten, msort) \
|
||||
_(aten, multi_margin_loss) \
|
||||
_(aten, multi_margin_loss_backward) \
|
||||
_(aten, multi_margin_loss_forward) \
|
||||
|
|
|
|||
|
|
@ -708,5 +708,15 @@ std::tuple<Tensor, Tensor> sort_cpu(
|
|||
return sort_out_cpu(values, indices, self, dim, descending);
|
||||
}
|
||||
|
||||
Tensor& msort_out(Tensor& values, const Tensor& self) {
|
||||
Tensor indices = at::empty({0}, self.options().dtype(kLong));
|
||||
at::sort_out(values, indices, self, 0, false);
|
||||
return values;
|
||||
}
|
||||
|
||||
Tensor msort(const Tensor& self) {
|
||||
return std::get<0>(at::sort(self, 0, false));
|
||||
}
|
||||
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
|
|
|
|||
|
|
@ -6779,6 +6779,16 @@
|
|||
- func: sort.dimname(Tensor self, Dimname dim, bool descending=False) -> (Tensor values, Tensor indices)
|
||||
variants: method, function
|
||||
|
||||
- func: msort.out(Tensor self, *, Tensor(a!) out) -> Tensor(a!)
|
||||
dispatch:
|
||||
Math: msort_out
|
||||
|
||||
- func: msort(Tensor self) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
dispatch:
|
||||
Math: msort
|
||||
|
||||
- func: argsort(Tensor self, int dim=-1, bool descending=False) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
variants: method, function
|
||||
|
|
|
|||
|
|
@ -463,6 +463,7 @@ view of a storage and defines numeric operations on it.
|
|||
.. automethod:: mode
|
||||
.. automethod:: movedim
|
||||
.. automethod:: moveaxis
|
||||
.. automethod:: msort
|
||||
.. automethod:: mul
|
||||
.. automethod:: mul_
|
||||
.. automethod:: multiply
|
||||
|
|
|
|||
|
|
@ -414,6 +414,7 @@ Comparison Ops
|
|||
not_equal
|
||||
sort
|
||||
topk
|
||||
msort
|
||||
|
||||
|
||||
Spectral Ops
|
||||
|
|
|
|||
|
|
@ -1,11 +1,12 @@
|
|||
import torch
|
||||
import numpy as np
|
||||
|
||||
import random
|
||||
from torch._six import nan
|
||||
from itertools import product
|
||||
|
||||
from torch.testing._internal.common_utils import \
|
||||
(TestCase, run_tests)
|
||||
(TestCase, run_tests, make_tensor)
|
||||
from torch.testing._internal.common_device_type import \
|
||||
(instantiate_device_type_tests, dtypes, onlyOnCPUAndCUDA,
|
||||
skipCUDAIfRocm, onlyCUDA, dtypesIfCUDA)
|
||||
|
|
@ -112,6 +113,33 @@ class TestSortAndSelect(TestCase):
|
|||
self.assertIsOrdered('descending', x, res2val, res2ind,
|
||||
'random with NaNs')
|
||||
|
||||
@dtypes(*(torch.testing.get_all_int_dtypes() + torch.testing.get_all_fp_dtypes(include_bfloat16=False)))
|
||||
def test_msort(self, device, dtype):
|
||||
def test(shape):
|
||||
tensor = make_tensor(shape, device, dtype, low=-9, high=9)
|
||||
if tensor.size() != torch.Size([]):
|
||||
expected = torch.from_numpy(np.msort(tensor.cpu().numpy()))
|
||||
else:
|
||||
expected = tensor # numpy.msort() does not support empty shapes tensor
|
||||
|
||||
result = torch.msort(tensor)
|
||||
self.assertEqual(result, expected)
|
||||
|
||||
out = torch.empty_like(result)
|
||||
torch.msort(tensor, out=out)
|
||||
self.assertEqual(out, expected)
|
||||
|
||||
shapes = (
|
||||
[],
|
||||
[0, ],
|
||||
[20, ],
|
||||
[1, 20],
|
||||
[30, 30],
|
||||
[10, 20, 30]
|
||||
)
|
||||
for shape in shapes:
|
||||
test(shape)
|
||||
|
||||
def test_topk(self, device):
|
||||
def topKViaSort(t, k, dim, dir):
|
||||
sorted, indices = t.sort(dim, dir)
|
||||
|
|
|
|||
|
|
@ -3382,6 +3382,13 @@ sort(dim=-1, descending=False) -> (Tensor, LongTensor)
|
|||
See :func:`torch.sort`
|
||||
""")
|
||||
|
||||
add_docstr_all('msort',
|
||||
r"""
|
||||
msort() -> Tensor
|
||||
|
||||
See :func:`torch.msort`
|
||||
""")
|
||||
|
||||
add_docstr_all('argsort',
|
||||
r"""
|
||||
argsort(dim=-1, descending=False) -> LongTensor
|
||||
|
|
|
|||
|
|
@ -7556,6 +7556,35 @@ Example::
|
|||
[3, 2, 1, 0]])
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.msort,
|
||||
r"""
|
||||
msort(input, *, out=None) -> Tensor
|
||||
|
||||
Sorts the elements of the :attr:`input` tensor along its first dimension
|
||||
in ascending order by value.
|
||||
|
||||
.. note:: `torch.msort(t)` is equivalent to `torch.sort(t, dim=0)[0]`.
|
||||
See also :func:`torch.sort`.
|
||||
|
||||
Args:
|
||||
{input}
|
||||
|
||||
Keyword args:
|
||||
{out}
|
||||
|
||||
Example::
|
||||
|
||||
>>> t = torch.randn(3, 4)
|
||||
>>> t
|
||||
tensor([[-0.1321, 0.4370, -1.2631, -1.1289],
|
||||
[-2.0527, -1.1250, 0.2275, 0.3077],
|
||||
[-0.0881, -0.1259, -0.5495, 1.0284]])
|
||||
>>> torch.msort(t)
|
||||
tensor([[-2.0527, -1.1250, -1.2631, -1.1289],
|
||||
[-0.1321, -0.1259, -0.5495, 0.3077],
|
||||
[-0.0881, 0.4370, 0.2275, 1.0284]])
|
||||
""".format(**common_args))
|
||||
|
||||
add_docstr(torch.sparse_coo_tensor,
|
||||
r"""
|
||||
sparse_coo_tensor(indices, values, size=None, *, dtype=None, device=None, requires_grad=False) -> Tensor
|
||||
|
|
|
|||
|
|
@ -525,6 +525,7 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
|
|||
torch.mode: lambda input, dim=-1, keepdim=False, out=None: -1,
|
||||
torch.movedim: lambda input, source, destination: -1,
|
||||
torch.moveaxis: lambda input, source, destination: -1,
|
||||
torch.msort: lambda input, descending=False, out=None: -1,
|
||||
torch.mul: lambda input, other, out=None: -1,
|
||||
torch.multiply: lambda input, other, out=None: -1,
|
||||
torch.multinomial: lambda input, num_samples, replacement=False, out=None: -1,
|
||||
|
|
|
|||
|
|
@ -1721,6 +1721,7 @@ def method_tests():
|
|||
('sort', (), NO_ARGS, 'scalar'),
|
||||
('sort', (), (0,), 'dim_scalar'),
|
||||
('sort', (), (0, True), 'dim_desc_scalar'),
|
||||
('msort', (S, M, S), NO_ARGS),
|
||||
('topk', (S, M, S), (3,)),
|
||||
('topk', (S, M, S), (3, 1), 'dim', (), [1]),
|
||||
('topk', (S, M, S), (3, 1, True), 'dim_desc', (), [1]),
|
||||
|
|
|
|||
Loading…
Reference in a new issue