Add fake_impl for unique_consecutive (#145649)

Summary:
It's fairly similar to torch.unique and torch.unique_dim.

Test Plan:
New test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145649
Approved by: https://github.com/ezyang, https://github.com/eellison
This commit is contained in:
rzou 2025-01-29 15:27:26 +00:00 committed by PyTorch MergeBot
parent 1e57154af3
commit 2e5886dcc4
4 changed files with 44 additions and 4 deletions

View file

@ -2643,6 +2643,18 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(r.dtype, torch.int64)
self.assertEqual(cnts.frame_count, 1)
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
def test_unique_consecutive(self):
x = torch.tensor([1, 1, 2, 2, 1, 3])
def fn(x):
return torch.unique_consecutive(x)
expected = fn(x)
opt_fn = torch.compile(fn, fullgraph=True, backend="eager")
result = opt_fn(x)
self.assertEqual(result, expected)
def test_numpy_unique_f16(self):
def fn():
x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16)

View file

@ -2011,7 +2011,6 @@ symbolic_tensor_failures = {
xfail('nn.functional.cross_entropy', ''), # aten.size.default - couldn't find symbolic meta function/decomposition
xfail('nn.functional.ctc_loss'), # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
xfail('quantile', ''), # Could not run 'aten::equal' with arguments from the 'Meta' backend.
xfail('unique_consecutive', ''), # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
xfail('max_pool2d_with_indices_backward', ''), # Expected a value of type 'List[int]' for argument 'kernel_size' but...
}

View file

@ -274,7 +274,15 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
def _unique(
fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
fake_mode,
func,
arg,
dim,
sorted=True,
return_inverse=False,
return_counts=False,
*,
unique_consecutive=False,
):
if (
fake_mode.shape_env is None
@ -283,8 +291,10 @@ def _unique(
# Without symints/symfloats, cannot handle this
raise DynamicOutputShapeException(func)
nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo
# Do not use a memo for unique_dim
if dim is not None or (nnz := arg.unique_memo) is None:
if dim is not None or nnz is None:
# Avoid importing sympy at a module level
from torch.fx.experimental.symbolic_shapes import (
_constrain_range_for_size,
@ -313,7 +323,10 @@ def _unique(
_constrain_range_for_size(nnz, max=maxval)
if dim is None:
arg.unique_memo = nnz
if unique_consecutive:
arg.unique_consecutive_memo = nnz
else:
arg.unique_memo = nnz
if dim is None:
ret = [arg.new_empty((nnz,))]
@ -359,6 +372,20 @@ def unique_dim(
)
@register_op_impl(aten.unique_consecutive.default)
def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None):
return _unique(
fake_mode,
func,
arg,
dim,
False,
return_inverse,
return_counts,
unique_consecutive=True,
)
@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
if output_size is None:

View file

@ -619,6 +619,7 @@ class FakeTensor(Tensor):
nonzero_memo = SymNumberMemoDescriptor()
item_memo = SymNumberMemoDescriptor()
unique_memo = SymNumberMemoDescriptor()
unique_consecutive_memo = SymNumberMemoDescriptor()
# We expect nested_int_memo to be None when an offsets is a graph
# intermediate, or an input that has never been associated with a
@ -721,6 +722,7 @@ class FakeTensor(Tensor):
self.nonzero_memo = None
self.item_memo = None
self.unique_memo = None
self.unique_consecutive_memo = None
self.nested_int_memo = None
if FakeTensorConfig.debug: