Revert "Remove split functional wrapper (#74727)"

This reverts commit a58876ace7.

Reverted https://github.com/pytorch/pytorch/pull/74727 on behalf of https://github.com/seemethere due to Fails internal use cases, might extend out to external use cases as well. Need to assess overall impact of this change more widely
This commit is contained in:
PyTorch MergeBot 2022-08-10 19:45:23 +00:00
parent 651c13166c
commit f534b2c627
11 changed files with 83 additions and 92 deletions

View file

@ -17,7 +17,7 @@ graph {
}
}
node {
input: "onnx::Split_0"
input: "tensor"
input: "onnx::Split_1"
output: "2"
output: "3"
@ -32,7 +32,7 @@ graph {
}
name: "torch_jit"
input {
name: "onnx::Split_0"
name: "tensor"
type {
tensor_type {
elem_type: 1

View file

@ -17,7 +17,7 @@ graph {
}
}
node {
input: "onnx::Split_0"
input: "tensor"
input: "onnx::Split_1"
output: "2"
output: "3"
@ -32,7 +32,7 @@ graph {
}
name: "torch_jit"
input {
name: "onnx::Split_0"
name: "tensor"
type {
tensor_type {
elem_type: 1

View file

@ -1134,7 +1134,7 @@ class TestTorchFunctionMode(TestCase):
with A():
self.assertEqual(torch.randn(3), -1)
self.assertEqual(torch.add(x, x), -1)
self.assertEqual(torch.nn.functional.dropout(None, 0.5), -1) # python side
self.assertEqual(torch.split(None, [2]), -1) # python side
self.assertEqual(bar(x), -1)
def test_factory_override(self):

View file

@ -393,13 +393,11 @@ $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)'''
def test_list_ret(self) -> None:
# test all sequence types are permissible returns
for list_type in (list, tuple):
class A(torch.Tensor):
class A(torch._C._TensorBase):
@staticmethod
def __new__(cls, elem):
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
if func.overloadpacket == torch.ops.aten.split:

View file

@ -3797,28 +3797,6 @@ else:
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
def test_split_deprecated_overloads(self, device):
    """torch.split must give identical results whether the chunk-size
    argument is passed as ``split_size_or_sections`` or via the legacy
    ``split_size`` keyword, for both the int and list-of-int forms."""
    x = torch.randn((10, 1, 3, 2), device=device)
    # (sizes, extra-kwargs) pairs covering: int with/without dim,
    # list-of-int with/without dim.
    cases = [
        (2, {"dim": 0}),
        (5, {}),
        ([6, 4], {"dim": 0}),
        ([7, 3], {}),
    ]
    for sizes, extra in cases:
        self.assertEqual(
            torch.split(x, split_size_or_sections=sizes, **extra),
            torch.split(x, split_size=sizes, **extra),
        )
# functions that operate over a dimension but don't reduce.
def test_dim_function_empty(self, device):
shape = (0, 1, 2, 0)

View file

@ -132,9 +132,3 @@
- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
aten: sub_out(out, self, other, alpha)
- name: split(Tensor(a -> *) self, *, int split_size_or_sections, int dim=0) -> Tensor(a)[]
aten: split(self, split_size_or_sections, dim)
- name: split(Tensor(a -> *) self, *, int[] split_size_or_sections, int dim=0) -> Tensor(a)[]
aten: split(self, split_size_or_sections, dim)

View file

@ -112,6 +112,7 @@ blocklist = [
"chain_matmul",
"stft",
"tensordot",
"split",
"unique_consecutive",
"atleast_1d",
"atleast_2d",
@ -698,6 +699,10 @@ def gen_pyi(
"def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
"def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...",
],
"split": [
"def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
"def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...",
],
"div": [
"def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
],

View file

@ -757,6 +757,23 @@ class Tensor(torch._C._TensorBase):
return Resize.apply(self, tensor.size())
def split(self, split_size, dim=0):
    r"""See :func:`torch.split`"""
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.split, (self,), self, split_size, dim=dim
        )
    if isinstance(split_size, int):
        # Plain int: equally sized chunks.
        return super(Tensor, self).split(split_size, dim)
    elif isinstance(split_size, Tensor):
        # A 0-dim tensor acts as a single scalar chunk size; a tensor with
        # more than one element raises ValueError on int() and is treated
        # as a list of section sizes. Keep the try body minimal so that
        # only the scalar conversion -- not the split call itself -- can
        # trigger the split_with_sizes fallback.
        try:
            split_size = int(split_size)
        except ValueError:
            return super(Tensor, self).split_with_sizes(split_size, dim)
        return super(Tensor, self).split(split_size, dim)
    else:
        # Sequence of section sizes (e.g. list/tuple of ints).
        return super(Tensor, self).split_with_sizes(split_size, dim)
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
r"""Returns the unique elements of the input tensor.

View file

@ -5988,15 +5988,6 @@ See :func:`torch.tensor_split`
""",
)
add_docstr_all(
"split",
r"""
split(split_size, dim=0) -> List of Tensors
See :func:`torch.split`
""",
)
add_docstr_all(
"hsplit",
r"""

View file

@ -1954,7 +1954,7 @@ on inplace modification of the outputs.
add_docstr(
torch.unsafe_split,
r"""
unsafe_split(tensor, split_size, dim=0) -> List of Tensors
unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors
Works like :func:`torch.split` but without enforcing the autograd restrictions
on inplace modification of the outputs.
@ -11465,52 +11465,6 @@ Example::
),
)
add_docstr(
torch.split,
r"""
split(input, split_size, dim=0) -> List[Tensor]
Splits the tensor into chunks. Each chunk is a view of the original tensor.
If :attr:`split_size` is an integer type, then :attr:`tensor` will
be split into equally sized chunks (if possible). Last chunk will be smaller if
the tensor size along the given dimension :attr:`dim` is not divisible by
:attr:`split_size`.
If :attr:`split_size` is a list, then :attr:`tensor` will be split
into ``len(split_size)`` chunks with sizes in :attr:`dim` according
to :attr:`split_size`.
Args:
tensor (Tensor): tensor to split.
split_size (int) or (list(int)): size of a single chunk or
list of sizes for each chunk
dim (int): dimension along which to split the tensor.
Example::
>>> a = torch.arange(10).reshape(5,2)
>>> a
tensor([[0, 1],
[2, 3],
[4, 5],
[6, 7],
[8, 9]])
>>> torch.split(a, 2)
(tensor([[0, 1],
[2, 3]]),
tensor([[4, 5],
[6, 7]]),
tensor([[8, 9]]))
>>> torch.split(a, [1,4])
(tensor([[0, 1]]),
tensor([[2, 3],
[4, 5],
[6, 7],
[8, 9]]))
""",
)
add_docstr(
torch.take,
r"""

View file

@ -32,6 +32,7 @@ __all__ = [
'norm',
'meshgrid',
'pca_lowrank',
'split',
'stft',
'svd_lowrank',
'tensordot',
@ -135,6 +136,59 @@ def broadcast_shapes(*shapes):
return tensors[0].shape
def split(
    tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
) -> List[Tensor]:
    r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.

    If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
    be split into equally sized chunks (if possible). Last chunk will be smaller if
    the tensor size along the given dimension :attr:`dim` is not divisible by
    :attr:`split_size`.

    If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
    into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
    to :attr:`split_size_or_sections`.

    Args:
        tensor (Tensor): tensor to split.
        split_size_or_sections (int) or (list(int)): size of a single chunk or
            list of sizes for each chunk
        dim (int): dimension along which to split the tensor.

    Example::

        >>> a = torch.arange(10).reshape(5,2)
        >>> a
        tensor([[0, 1],
                [2, 3],
                [4, 5],
                [6, 7],
                [8, 9]])
        >>> torch.split(a, 2)
        (tensor([[0, 1],
                 [2, 3]]),
         tensor([[4, 5],
                 [6, 7]]),
         tensor([[8, 9]]))
        >>> torch.split(a, [1,4])
        (tensor([[0, 1]]),
         tensor([[2, 3],
                 [4, 5],
                 [6, 7],
                 [8, 9]]))
    """
    if has_torch_function_unary(tensor):
        return handle_torch_function(
            split, (tensor,), tensor, split_size_or_sections, dim=dim
        )
    # Overwriting reason:
    # This dispatches to two ATen functions depending on the type of
    # split_size_or_sections; the branching lives in the Tensor.split
    # method (_tensor.py), which we defer to here.
    return tensor.split(split_size_or_sections, dim)
def einsum(*args: Any) -> Tensor:
r"""einsum(equation, *operands) -> Tensor