From 2e5886dcc4531a05d8bdd65ecce0933daa35b893 Mon Sep 17 00:00:00 2001 From: rzou Date: Wed, 29 Jan 2025 15:27:26 +0000 Subject: [PATCH] Add fake_impl for unique_consecutive (#145649) Summary: It's fairly similar to torch.unique and torch.unique_dim. Test Plan: New test Pull Request resolved: https://github.com/pytorch/pytorch/pull/145649 Approved by: https://github.com/ezyang, https://github.com/eellison --- test/dynamo/test_misc.py | 12 ++++++++++++ test/test_proxy_tensor.py | 1 - torch/_subclasses/fake_impls.py | 33 +++++++++++++++++++++++++++++--- torch/_subclasses/fake_tensor.py | 2 ++ 4 files changed, 44 insertions(+), 4 deletions(-) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ab02d5aae27..86a92584ab7 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -2643,6 +2643,18 @@ utils_device.CURRENT_DEVICE == None""".split( self.assertEqual(r.dtype, torch.int64) self.assertEqual(cnts.frame_count, 1) + @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True) + def test_unique_consecutive(self): + x = torch.tensor([1, 1, 2, 2, 1, 3]) + + def fn(x): + return torch.unique_consecutive(x) + + expected = fn(x) + opt_fn = torch.compile(fn, fullgraph=True, backend="eager") + result = opt_fn(x) + self.assertEqual(result, expected) + def test_numpy_unique_f16(self): def fn(): x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16) diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index 8ecd46fc879..7d9030aa036 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -2011,7 +2011,6 @@ symbolic_tensor_failures = { xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend. - xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but... } diff --git a/torch/_subclasses/fake_impls.py b/torch/_subclasses/fake_impls.py index 304bc284b0f..ad5c3c8b61c 100644 --- a/torch/_subclasses/fake_impls.py +++ b/torch/_subclasses/fake_impls.py @@ -274,7 +274,15 @@ def dyn_shape(fake_mode, func, *args, **kwargs): def _unique( - fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False + fake_mode, + func, + arg, + dim, + sorted=True, + return_inverse=False, + return_counts=False, + *, + unique_consecutive=False, ): if ( fake_mode.shape_env is None @@ -283,8 +291,10 @@ def _unique( # Without symints/symfloats, cannot handle this raise DynamicOutputShapeException(func) + nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo + # Do not use a memo for unique_dim - if dim is not None or (nnz := arg.unique_memo) is None: + if dim is not None or nnz is None: # Avoid importing sympy at a module level from torch.fx.experimental.symbolic_shapes import ( _constrain_range_for_size, @@ -313,7 +323,10 @@ def _unique( _constrain_range_for_size(nnz, max=maxval) if dim is None: - arg.unique_memo = nnz + if unique_consecutive: + arg.unique_consecutive_memo = nnz + else: + arg.unique_memo = nnz if dim is None: ret = [arg.new_empty((nnz,))] @@ -359,6 +372,20 @@ def unique_dim( ) +@register_op_impl(aten.unique_consecutive.default) +def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None): + return _unique( + fake_mode, + func, + arg, + dim, + False, + return_inverse, + return_counts, + unique_consecutive=True, + ) + + @register_op_impl(aten.repeat_interleave.Tensor) def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None): if output_size is None: diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index f49bc0d5830..adcf264e038 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -619,6 +619,7 @@ class FakeTensor(Tensor): nonzero_memo = SymNumberMemoDescriptor() item_memo = SymNumberMemoDescriptor() unique_memo = SymNumberMemoDescriptor() + unique_consecutive_memo = SymNumberMemoDescriptor() # We expect nested_int_memo to be None when an offsets is a graph # intermediate, or an input that has never been associated with a @@ -721,6 +722,7 @@ class FakeTensor(Tensor): self.nonzero_memo = None self.item_memo = None self.unique_memo = None + self.unique_consecutive_memo = None self.nested_int_memo = None if FakeTensorConfig.debug: