Revert "Fix split decomp returning self (#140065)"

This reverts commit 9d99dceb53.

Reverted https://github.com/pytorch/pytorch/pull/140065 on behalf of https://github.com/ZainRizvi because the diff was imported internally but merged externally, and the internal diff has since been updated, so the diff and PR are now mismatched. Reverting this PR to get things back into a consistent state. See D65635070 ([comment](https://github.com/pytorch/pytorch/pull/140065#issuecomment-2465928027))
PyTorch MergeBot 2024-11-09 00:16:26 +00:00
parent a02e88d19c
commit 7eb66173e2
2 changed files with 2 additions and 12 deletions

test/test_fake_tensor.py

@@ -21,7 +21,6 @@ import torch.utils._pytree as pytree
 from torch import distributed as dist
 from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
-from torch._dispatch.python import enable_python_dispatcher
 from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
 from torch._guards import tracing, TracingContext
 from torch._higher_order_ops.scan import scan
@@ -190,6 +189,7 @@ class FakeTensorTest(TestCase):
         self.assertEqual(torch.ones([10]), out[0])
+
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_zero_dim(self):
         with FakeTensorMode() as mode:
@@ -431,16 +431,6 @@ class FakeTensorTest(TestCase):
         self.assertTrue(out[1].is_contiguous())
         self.checkMetaProps(out[0], out[1])
-
-    def test_split_return_self(self):
-        def fn(x):
-            return torch.functional.split(x, 0)[0]
-
-        with FakeTensorMode(), enable_python_dispatcher():
-            out_fake = fn(torch.empty((0,)))
-            out_eager = fn(torch.empty((0,)))
-            self.checkMetaProps(out_fake, out_eager)
-
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_cpu_fallback(self):
         with FakeTensorMode(allow_fallback_kernels=False):
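
The removed test exercised the split decomposition under FakeTensorMode with the Python dispatcher enabled and compared metadata between the fake and eager outputs. A minimal standalone sketch of the same check (checkMetaProps is the test-harness helper; plain shape/dtype/stride asserts stand in for it here):

import torch
from torch._dispatch.python import enable_python_dispatcher
from torch._subclasses.fake_tensor import FakeTensorMode

def fn(x):
    return torch.functional.split(x, 0)[0]

# Eager reference output.
out_eager = fn(torch.empty((0,)))

# The same call under fake tensors; enable_python_dispatcher() routes
# the op through the Python decompositions, including the split decomp
# changed below.
with FakeTensorMode(), enable_python_dispatcher():
    out_fake = fn(torch.empty((0,)))

# checkMetaProps compares tensor metadata; a rough stand-in:
assert out_fake.shape == out_eager.shape
assert out_fake.dtype == out_eager.dtype
assert out_fake.stride() == out_eager.stride()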

torch/_decomp/decompositions.py

@@ -1431,7 +1431,7 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
     dim_size = input_sizes[dim]
     if split_size == 0:
         assert dim_size == 0
-        return (self.detach(),)
+        return (self,)
     chunks = (dim_size + split_size - 1) // split_size
     # Avoid importing sympy at a module level
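
For context, this is the one-line change being reverted: the fix in #140065 made the decomposition return self.detach() rather than self when split_size == 0, so the decomp would not hand back the very tensor object it received. A short sketch of the distinction in eager mode (object identity, not values, is what differs):

import torch

x = torch.empty((0,))

# Eager split returns views, never the input object itself.
out = torch.functional.split(x, 0)[0]
print(out is x)  # False: a fresh view object

# detach() likewise produces a new tensor object that shares storage,
# which is why the reverted fix used it in place of returning `self`.
y = x.detach()
print(y is x)    # False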