Add fake_impl for unique_consecutive (#145649)
Summary: The implementation is closely modeled on the existing fake impls for torch.unique and torch.unique_dim.

Test Plan: New test

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145649
Approved by: https://github.com/ezyang, https://github.com/eellison
This commit is contained in:
parent 1e57154af3
commit 2e5886dcc4

4 changed files with 44 additions and 4 deletions
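For context on why a fake impl is needed at all: the length of unique_consecutive's output depends on the input's values, not just its shape, so fake tensor propagation cannot compute a concrete output size and must hand back an unbacked symbolic size instead. A quick eager-mode illustration:

import torch

# Same input shape, different output lengths: the result size is data-dependent.
print(torch.unique_consecutive(torch.tensor([1, 1, 2, 2, 1, 3])))  # tensor([1, 2, 1, 3])
print(torch.unique_consecutive(torch.tensor([5, 5, 5, 5, 5, 5])))  # tensor([5])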
test/dynamo/test_misc.py
@@ -2643,6 +2643,18 @@ utils_device.CURRENT_DEVICE == None""".split(
         self.assertEqual(r.dtype, torch.int64)
         self.assertEqual(cnts.frame_count, 1)
 
+    @torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
+    def test_unique_consecutive(self):
+        x = torch.tensor([1, 1, 2, 2, 1, 3])
+
+        def fn(x):
+            return torch.unique_consecutive(x)
+
+        expected = fn(x)
+        opt_fn = torch.compile(fn, fullgraph=True, backend="eager")
+        result = opt_fn(x)
+        self.assertEqual(result, expected)
+
     def test_numpy_unique_f16(self):
         def fn():
             x = np.asarray([1, 1, 2, 2, 3], dtype=np.float16)
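The capture_dynamic_output_shape_ops patch on the test matters: with the flag at its default of False, dynamo refuses to capture ops whose output shape is data-dependent. A hedged sketch of that failure mode (the exception type and message are paraphrased from dynamo's usual behavior, not taken from this PR):

import torch

def fn(x):
    return torch.unique_consecutive(x)

# Without capture_dynamic_output_shape_ops=True, fullgraph compilation is
# expected to raise (roughly) torch._dynamo.exc.Unsupported:
# "dynamic shape operator: aten.unique_consecutive.default".
try:
    torch.compile(fn, fullgraph=True, backend="eager")(torch.tensor([1, 1, 2]))
except Exception as e:
    print(type(e).__name__, e)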
test/test_proxy_tensor.py
@@ -2011,7 +2011,6 @@ symbolic_tensor_failures = {
     xfail('nn.functional.cross_entropy', ''),  # aten.size.default - couldn't find symbolic meta function/decomposition
     xfail('nn.functional.ctc_loss'),  # aten._ctc_loss.Tensor - couldn't find symbolic meta function/decomposition
     xfail('quantile', ''),  # Could not run 'aten::equal' with arguments from the 'Meta' backend.
-    xfail('unique_consecutive', ''),  # aten.unique_consecutive.default - couldn't find symbolic meta function/decomposition
 
     xfail('max_pool2d_with_indices_backward', ''),  # Expected a value of type 'List[int]' for argument 'kernel_size' but...
 }
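Removing the xfail means the proxy-tensor opinfo test for unique_consecutive now passes under symbolic tracing. A minimal sketch of the kind of trace it exercises, assuming a build that includes this change (make_fx is the real API; the resulting graph is illustrative):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def f(x):
    return torch.unique_consecutive(x)

# Previously this failed with "aten.unique_consecutive.default - couldn't find
# symbolic meta function/decomposition"; with the fake impl it traces, and the
# output length shows up as an unbacked SymInt.
gm = make_fx(f, tracing_mode="symbolic")(torch.tensor([1, 1, 2]))
print(gm.graph)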
torch/_subclasses/fake_impls.py
@@ -274,7 +274,15 @@ def dyn_shape(fake_mode, func, *args, **kwargs):
 
 
 def _unique(
-    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
+    fake_mode,
+    func,
+    arg,
+    dim,
+    sorted=True,
+    return_inverse=False,
+    return_counts=False,
+    *,
+    unique_consecutive=False,
 ):
     if (
         fake_mode.shape_env is None
@@ -283,8 +291,10 @@ def _unique(
         # Without symints/symfloats, cannot handle this
         raise DynamicOutputShapeException(func)
 
+    nnz = arg.unique_consecutive_memo if unique_consecutive else arg.unique_memo
+
     # Do not use a memo for unique_dim
-    if dim is not None or (nnz := arg.unique_memo) is None:
+    if dim is not None or nnz is None:
         # Avoid importing sympy at a module level
         from torch.fx.experimental.symbolic_shapes import (
             _constrain_range_for_size,
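The memo exists so that calling the same op twice on the same fake tensor yields the same symbolic size, while torch.unique and torch.unique_consecutive (which can produce different lengths for the same input) each get their own symbol. A torch-free toy model of the pattern (all names here are illustrative, not PyTorch API):

import itertools

# Stand-in for ShapeEnv allocating a fresh unbacked symbol on each call.
_fresh = (f"u{i}" for i in itertools.count())

def memoized_len(memos, key):
    # memos: per-fake-tensor dict; key: "unique" or "unique_consecutive".
    if memos.get(key) is None:
        memos[key] = next(_fresh)
    return memos[key]

memos = {}
assert memoized_len(memos, "unique") == memoized_len(memos, "unique")  # stable across calls
assert memoized_len(memos, "unique") != memoized_len(memos, "unique_consecutive")  # distinct ops, distinct symbols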
@@ -313,7 +323,10 @@ def _unique(
         _constrain_range_for_size(nnz, max=maxval)
 
         if dim is None:
-            arg.unique_memo = nnz
+            if unique_consecutive:
+                arg.unique_consecutive_memo = nnz
+            else:
+                arg.unique_memo = nnz
 
     if dim is None:
         ret = [arg.new_empty((nnz,))]
@@ -359,6 +372,20 @@ def unique_dim(
     )
 
 
+@register_op_impl(aten.unique_consecutive.default)
+def _(fake_mode, func, arg, return_inverse=False, return_counts=False, dim=None):
+    return _unique(
+        fake_mode,
+        func,
+        arg,
+        dim,
+        False,
+        return_inverse,
+        return_counts,
+        unique_consecutive=True,
+    )
+
+
 @register_op_impl(aten.repeat_interleave.Tensor)
 def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
     if output_size is None:
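Note the positional False binds to _unique's sorted parameter: unique_consecutive only collapses adjacent duplicates and never sorts, so the flag is irrelevant here, and in the fake impl it does not affect the computed shape in any case.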
torch/_subclasses/fake_tensor.py
@@ -619,6 +619,7 @@ class FakeTensor(Tensor):
     nonzero_memo = SymNumberMemoDescriptor()
     item_memo = SymNumberMemoDescriptor()
     unique_memo = SymNumberMemoDescriptor()
+    unique_consecutive_memo = SymNumberMemoDescriptor()
 
     # We expect nested_int_memo to be None when an offsets is a graph
     # intermediate, or an input that has never been associated with a
@@ -721,6 +722,7 @@ class FakeTensor(Tensor):
         self.nonzero_memo = None
         self.item_memo = None
         self.unique_memo = None
+        self.unique_consecutive_memo = None
         self.nested_int_memo = None
 
         if FakeTensorConfig.debug:
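Both fake_tensor.py hunks mirror the existing unique_memo plumbing exactly: a SymNumberMemoDescriptor on the class, plus a per-instance reset to None when a FakeTensor is created, so a fresh fake tensor starts with no cached symbolic size.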