mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Remove split functional wrapper (#74727)"
This reverts commit a58876ace7.
Reverted https://github.com/pytorch/pytorch/pull/74727 on behalf of https://github.com/seemethere due to Fails internal use cases, might extend out to external use cases as well. Need to assess overall impact of this change more widely
This commit is contained in:
parent
651c13166c
commit
f534b2c627
11 changed files with 83 additions and 92 deletions
|
|
@ -17,7 +17,7 @@ graph {
|
|||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Split_0"
|
||||
input: "tensor"
|
||||
input: "onnx::Split_1"
|
||||
output: "2"
|
||||
output: "3"
|
||||
|
|
@ -32,7 +32,7 @@ graph {
|
|||
}
|
||||
name: "torch_jit"
|
||||
input {
|
||||
name: "onnx::Split_0"
|
||||
name: "tensor"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ graph {
|
|||
}
|
||||
}
|
||||
node {
|
||||
input: "onnx::Split_0"
|
||||
input: "tensor"
|
||||
input: "onnx::Split_1"
|
||||
output: "2"
|
||||
output: "3"
|
||||
|
|
@ -32,7 +32,7 @@ graph {
|
|||
}
|
||||
name: "torch_jit"
|
||||
input {
|
||||
name: "onnx::Split_0"
|
||||
name: "tensor"
|
||||
type {
|
||||
tensor_type {
|
||||
elem_type: 1
|
||||
|
|
|
|||
|
|
@ -1134,7 +1134,7 @@ class TestTorchFunctionMode(TestCase):
|
|||
with A():
|
||||
self.assertEqual(torch.randn(3), -1)
|
||||
self.assertEqual(torch.add(x, x), -1)
|
||||
self.assertEqual(torch.nn.functional.dropout(None, 0.5), -1) # python side
|
||||
self.assertEqual(torch.split(None, [2]), -1) # python side
|
||||
self.assertEqual(bar(x), -1)
|
||||
|
||||
def test_factory_override(self):
|
||||
|
|
|
|||
|
|
@ -393,13 +393,11 @@ $5 = torch._ops.aten.clone.default($4, memory_format=torch.contiguous_format)'''
|
|||
def test_list_ret(self) -> None:
|
||||
# test all sequence types are permissible returns
|
||||
for list_type in (list, tuple):
|
||||
class A(torch.Tensor):
|
||||
class A(torch._C._TensorBase):
|
||||
@staticmethod
|
||||
def __new__(cls, elem):
|
||||
return torch.Tensor._make_subclass(cls, elem, elem.requires_grad)
|
||||
|
||||
__torch_function__ = torch._C._disabled_torch_function_impl
|
||||
|
||||
@classmethod
|
||||
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
|
||||
if func.overloadpacket == torch.ops.aten.split:
|
||||
|
|
|
|||
|
|
@ -3797,28 +3797,6 @@ else:
|
|||
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 1, dim=0)])
|
||||
self.assertEqual([(0, 1, 3, 0)], [z.shape for z in torch.split(x, 0, dim=0)])
|
||||
|
||||
def test_split_deprecated_overloads(self, device):
|
||||
x = torch.randn((10, 1, 3, 2), device=device)
|
||||
|
||||
# split_sizes_or_sections is deprecated in favour of split_sizes
|
||||
self.assertEqual(
|
||||
torch.split(x, split_size_or_sections=2, dim=0),
|
||||
torch.split(x, split_size=2, dim=0),
|
||||
)
|
||||
self.assertEqual(
|
||||
torch.split(x, split_size_or_sections=5),
|
||||
torch.split(x, split_size=5),
|
||||
)
|
||||
|
||||
self.assertEqual(
|
||||
torch.split(x, split_size_or_sections=[6, 4], dim=0),
|
||||
torch.split(x, split_size=[6, 4], dim=0),
|
||||
)
|
||||
self.assertEqual(
|
||||
torch.split(x, split_size_or_sections=[7, 3]),
|
||||
torch.split(x, split_size=[7, 3]),
|
||||
)
|
||||
|
||||
# functions that operate over a dimension but don't reduce.
|
||||
def test_dim_function_empty(self, device):
|
||||
shape = (0, 1, 2, 0)
|
||||
|
|
|
|||
|
|
@ -132,9 +132,3 @@
|
|||
|
||||
- name: sub(Tensor self, Scalar alpha, Tensor other, *, Tensor(a!) out) -> Tensor(a!)
|
||||
aten: sub_out(out, self, other, alpha)
|
||||
|
||||
- name: split(Tensor(a -> *) self, *, int split_size_or_sections, int dim=0) -> Tensor(a)[]
|
||||
aten: split(self, split_size_or_sections, dim)
|
||||
|
||||
- name: split(Tensor(a -> *) self, *, int[] split_size_or_sections, int dim=0) -> Tensor(a)[]
|
||||
aten: split(self, split_size_or_sections, dim)
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ blocklist = [
|
|||
"chain_matmul",
|
||||
"stft",
|
||||
"tensordot",
|
||||
"split",
|
||||
"unique_consecutive",
|
||||
"atleast_1d",
|
||||
"atleast_2d",
|
||||
|
|
@ -698,6 +699,10 @@ def gen_pyi(
|
|||
"def set_(self, storage: Union[Storage, TypedStorage], offset: _int, size: _size, stride: _size) -> Tensor: ...",
|
||||
"def set_(self, storage: Union[Storage, TypedStorage]) -> Tensor: ...",
|
||||
],
|
||||
"split": [
|
||||
"def split(self, split_size: _int, dim: _int=0) -> Sequence[Tensor]: ...",
|
||||
"def split(self, split_size: Tuple[_int, ...], dim: _int=0) -> Sequence[Tensor]: ...",
|
||||
],
|
||||
"div": [
|
||||
"def div(self, other: Union[Tensor, Number], *, rounding_mode: Optional[str] = None) -> Tensor: ..."
|
||||
],
|
||||
|
|
|
|||
|
|
@ -757,6 +757,23 @@ class Tensor(torch._C._TensorBase):
|
|||
|
||||
return Resize.apply(self, tensor.size())
|
||||
|
||||
def split(self, split_size, dim=0):
|
||||
r"""See :func:`torch.split`"""
|
||||
if has_torch_function_unary(self):
|
||||
return handle_torch_function(
|
||||
Tensor.split, (self,), self, split_size, dim=dim
|
||||
)
|
||||
if isinstance(split_size, int):
|
||||
return super(Tensor, self).split(split_size, dim)
|
||||
elif isinstance(split_size, Tensor):
|
||||
try:
|
||||
split_size = int(split_size)
|
||||
return super(Tensor, self).split(split_size, dim)
|
||||
except ValueError:
|
||||
return super(Tensor, self).split_with_sizes(split_size, dim)
|
||||
else:
|
||||
return super(Tensor, self).split_with_sizes(split_size, dim)
|
||||
|
||||
def unique(self, sorted=True, return_inverse=False, return_counts=False, dim=None):
|
||||
r"""Returns the unique elements of the input tensor.
|
||||
|
||||
|
|
|
|||
|
|
@ -5988,15 +5988,6 @@ See :func:`torch.tensor_split`
|
|||
""",
|
||||
)
|
||||
|
||||
add_docstr_all(
|
||||
"split",
|
||||
r"""
|
||||
split(split_size, dim=0) -> List of Tensors
|
||||
|
||||
See :func:`torch.split`
|
||||
""",
|
||||
)
|
||||
|
||||
add_docstr_all(
|
||||
"hsplit",
|
||||
r"""
|
||||
|
|
|
|||
|
|
@ -1954,7 +1954,7 @@ on inplace modification of the outputs.
|
|||
add_docstr(
|
||||
torch.unsafe_split,
|
||||
r"""
|
||||
unsafe_split(tensor, split_size, dim=0) -> List of Tensors
|
||||
unsafe_split(tensor, split_size_or_sections, dim=0) -> List of Tensors
|
||||
|
||||
Works like :func:`torch.split` but without enforcing the autograd restrictions
|
||||
on inplace modification of the outputs.
|
||||
|
|
@ -11465,52 +11465,6 @@ Example::
|
|||
),
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.split,
|
||||
r"""
|
||||
split(input, split_size, dim=0) -> List[Tensor]
|
||||
|
||||
Splits the tensor into chunks. Each chunk is a view of the original tensor.
|
||||
|
||||
If :attr:`split_size` is an integer type, then :attr:`tensor` will
|
||||
be split into equally sized chunks (if possible). Last chunk will be smaller if
|
||||
the tensor size along the given dimension :attr:`dim` is not divisible by
|
||||
:attr:`split_size`.
|
||||
|
||||
If :attr:`split_size` is a list, then :attr:`tensor` will be split
|
||||
into ``len(split_size)`` chunks with sizes in :attr:`dim` according
|
||||
to :attr:`split_size`.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): tensor to split.
|
||||
split_size (int) or (list(int)): size of a single chunk or
|
||||
list of sizes for each chunk
|
||||
dim (int): dimension along which to split the tensor.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.arange(10).reshape(5,2)
|
||||
>>> a
|
||||
tensor([[0, 1],
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]])
|
||||
>>> torch.split(a, 2)
|
||||
(tensor([[0, 1],
|
||||
[2, 3]]),
|
||||
tensor([[4, 5],
|
||||
[6, 7]]),
|
||||
tensor([[8, 9]]))
|
||||
>>> torch.split(a, [1,4])
|
||||
(tensor([[0, 1]]),
|
||||
tensor([[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]]))
|
||||
""",
|
||||
)
|
||||
|
||||
add_docstr(
|
||||
torch.take,
|
||||
r"""
|
||||
|
|
|
|||
|
|
@ -32,6 +32,7 @@ __all__ = [
|
|||
'norm',
|
||||
'meshgrid',
|
||||
'pca_lowrank',
|
||||
'split',
|
||||
'stft',
|
||||
'svd_lowrank',
|
||||
'tensordot',
|
||||
|
|
@ -135,6 +136,59 @@ def broadcast_shapes(*shapes):
|
|||
return tensors[0].shape
|
||||
|
||||
|
||||
|
||||
def split(
|
||||
tensor: Tensor, split_size_or_sections: Union[int, List[int]], dim: int = 0
|
||||
) -> List[Tensor]:
|
||||
r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
|
||||
|
||||
If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
|
||||
be split into equally sized chunks (if possible). Last chunk will be smaller if
|
||||
the tensor size along the given dimension :attr:`dim` is not divisible by
|
||||
:attr:`split_size`.
|
||||
|
||||
If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
|
||||
into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
|
||||
to :attr:`split_size_or_sections`.
|
||||
|
||||
Args:
|
||||
tensor (Tensor): tensor to split.
|
||||
split_size_or_sections (int) or (list(int)): size of a single chunk or
|
||||
list of sizes for each chunk
|
||||
dim (int): dimension along which to split the tensor.
|
||||
|
||||
Example::
|
||||
|
||||
>>> a = torch.arange(10).reshape(5,2)
|
||||
>>> a
|
||||
tensor([[0, 1],
|
||||
[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]])
|
||||
>>> torch.split(a, 2)
|
||||
(tensor([[0, 1],
|
||||
[2, 3]]),
|
||||
tensor([[4, 5],
|
||||
[6, 7]]),
|
||||
tensor([[8, 9]]))
|
||||
>>> torch.split(a, [1,4])
|
||||
(tensor([[0, 1]]),
|
||||
tensor([[2, 3],
|
||||
[4, 5],
|
||||
[6, 7],
|
||||
[8, 9]]))
|
||||
"""
|
||||
if has_torch_function_unary(tensor):
|
||||
return handle_torch_function(
|
||||
split, (tensor,), tensor, split_size_or_sections, dim=dim)
|
||||
# Overwriting reason:
|
||||
# This dispatches to two ATen functions depending on the type of
|
||||
# split_size_or_sections. The branching code is in _tensor.py, which we
|
||||
# call here.
|
||||
return tensor.split(split_size_or_sections, dim)
|
||||
|
||||
|
||||
def einsum(*args: Any) -> Tensor:
|
||||
r"""einsum(equation, *operands) -> Tensor
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue