From 7eb66173e2b321985efefd319d75882a56ff89ef Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Sat, 9 Nov 2024 00:16:26 +0000 Subject: [PATCH] Revert "Fix split decomp returning self (#140065)" This reverts commit 9d99dceb53884387665a2c273beca99a157193a5. Reverted https://github.com/pytorch/pytorch/pull/140065 on behalf of https://github.com/ZainRizvi due to Diff been imported internally, but merged externally. And the internal diff has been updated so the diff and PR are now mismatched. Reverting this PR to get things back into a consistent state. See D65635070 ([comment](https://github.com/pytorch/pytorch/pull/140065#issuecomment-2465928027)) --- test/test_fake_tensor.py | 12 +----------- torch/_decomp/decompositions.py | 2 +- 2 files changed, 2 insertions(+), 12 deletions(-) diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py index 6d515383a7d..e0e76419c9f 100644 --- a/test/test_fake_tensor.py +++ b/test/test_fake_tensor.py @@ -21,7 +21,6 @@ import torch.utils._pytree as pytree from torch import distributed as dist from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor -from torch._dispatch.python import enable_python_dispatcher from torch._dynamo.testing import make_test_cls_with_patches, rand_strided from torch._guards import tracing, TracingContext from torch._higher_order_ops.scan import scan @@ -190,6 +189,7 @@ class FakeTensorTest(TestCase): self.assertEqual(torch.ones([10]), out[0]) + @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_zero_dim(self): with FakeTensorMode() as mode: @@ -431,16 +431,6 @@ class FakeTensorTest(TestCase): self.assertTrue(out[1].is_contiguous()) self.checkMetaProps(out[0], out[1]) - def test_split_return_self(self): - def fn(x): - return torch.functional.split(x, 0)[0] - - with FakeTensorMode(), enable_python_dispatcher(): - out_fake = fn(torch.empty((0,))) - - out_eager = fn(torch.empty((0,))) - self.checkMetaProps(out_fake, out_eager) - @unittest.skipIf(not RUN_CUDA, "requires cuda") def test_cpu_fallback(self): with FakeTensorMode(allow_fallback_kernels=False): diff --git a/torch/_decomp/decompositions.py b/torch/_decomp/decompositions.py index 2a21444ca12..8822a3840aa 100644 --- a/torch/_decomp/decompositions.py +++ b/torch/_decomp/decompositions.py @@ -1431,7 +1431,7 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]: dim_size = input_sizes[dim] if split_size == 0: assert dim_size == 0 - return (self.detach(),) + return (self,) chunks = (dim_size + split_size - 1) // split_size # Avoid importing sympy at a module level