mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "Fix split decomp returning self (#140065)"
This reverts commit 9d99dceb53.
Reverted https://github.com/pytorch/pytorch/pull/140065 on behalf of https://github.com/ZainRizvi due to Diff been imported internally, but merged externally. And the internal diff has been updated so the diff and PR are now mismatched. Reverting this PR to get things back into a consistent state. See D65635070 ([comment](https://github.com/pytorch/pytorch/pull/140065#issuecomment-2465928027))
This commit is contained in:
parent
a02e88d19c
commit
7eb66173e2
2 changed files with 2 additions and 12 deletions
|
|
@ -21,7 +21,6 @@ import torch.utils._pytree as pytree
|
|||
|
||||
from torch import distributed as dist
|
||||
from torch._C._functorch import _add_batch_dim, get_unwrapped, is_batchedtensor
|
||||
from torch._dispatch.python import enable_python_dispatcher
|
||||
from torch._dynamo.testing import make_test_cls_with_patches, rand_strided
|
||||
from torch._guards import tracing, TracingContext
|
||||
from torch._higher_order_ops.scan import scan
|
||||
|
|
@ -190,6 +189,7 @@ class FakeTensorTest(TestCase):
|
|||
|
||||
self.assertEqual(torch.ones([10]), out[0])
|
||||
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_zero_dim(self):
|
||||
with FakeTensorMode() as mode:
|
||||
|
|
@ -431,16 +431,6 @@ class FakeTensorTest(TestCase):
|
|||
self.assertTrue(out[1].is_contiguous())
|
||||
self.checkMetaProps(out[0], out[1])
|
||||
|
||||
def test_split_return_self(self):
|
||||
def fn(x):
|
||||
return torch.functional.split(x, 0)[0]
|
||||
|
||||
with FakeTensorMode(), enable_python_dispatcher():
|
||||
out_fake = fn(torch.empty((0,)))
|
||||
|
||||
out_eager = fn(torch.empty((0,)))
|
||||
self.checkMetaProps(out_fake, out_eager)
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_cpu_fallback(self):
|
||||
with FakeTensorMode(allow_fallback_kernels=False):
|
||||
|
|
|
|||
|
|
@ -1431,7 +1431,7 @@ def split(self: Tensor, split_size: int, dim: int = 0) -> Tuple[Tensor, ...]:
|
|||
dim_size = input_sizes[dim]
|
||||
if split_size == 0:
|
||||
assert dim_size == 0
|
||||
return (self.detach(),)
|
||||
return (self,)
|
||||
chunks = (dim_size + split_size - 1) // split_size
|
||||
|
||||
# Avoid importing sympy at a module level
|
||||
|
|
|
|||
Loading…
Reference in a new issue