[dynamo] support maxlen for collections.deque (#138194)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/138194
Approved by: https://github.com/jansel, https://github.com/malfet
This commit is contained in:
Xuehai Pan 2024-10-30 13:55:04 +08:00 committed by PyTorch MergeBot
parent a4b35767cb
commit 9bbe4a67ad
4 changed files with 69 additions and 24 deletions

View file

@ -9963,7 +9963,7 @@ def ___make_guard_fn():
"c": (
x,
3.0,
collections.deque([0.0, -x]),
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
@ -9995,7 +9995,7 @@ def ___make_guard_fn():
"c": (
x,
3.0,
[0.0, -x],
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
@ -10011,6 +10011,7 @@ def ___make_guard_fn():
x * y,
3.0,
y - 2,
1,
torch.zeros(2, 2),
2 * y,
-y,
@ -10043,7 +10044,7 @@ def ___make_guard_fn():
"c": (
x,
3.0,
[0.0, -x],
collections.deque([0.0, -x, 1, 2], maxlen=3),
),
"d": collections.OrderedDict(
{
@ -10054,7 +10055,7 @@ def ___make_guard_fn():
}
tree2 = collections.OrderedDict(
[
("c", (y, 3.0, [-y, 10.0])),
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
("a", [y, y + 1]),
("b", y + 2),
(

View file

@ -452,12 +452,33 @@ class ListVariable(CommonListMethodsVariable):
class DequeVariable(CommonListMethodsVariable):
def __init__(self, items, maxlen=None, **kwargs) -> None:
if maxlen is None:
maxlen = ConstantVariable.create(None)
assert (
maxlen.is_python_constant()
), f"maxlen must be a constant, got: {maxlen.debug_repr()}"
self.maxlen = maxlen
if self.maxlen.as_python_constant() is not None:
items = list(items)[-maxlen.as_python_constant() :]
super().__init__(items, **kwargs)
def python_type(self):
return collections.deque
def debug_repr(self):
if self.maxlen.as_python_constant() is None:
return self.debug_repr_helper(
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
)
return self.debug_repr_helper("deque([", "])")
def as_python_constant(self):
return self.python_type()(
[x.as_python_constant() for x in self.items],
maxlen=self.maxlen.as_python_constant(),
)
def reconstruct(self, codegen: "PyCodegen") -> None:
assert "deque" not in codegen.tx.f_globals
codegen.add_push_null(
@ -466,12 +487,14 @@ class DequeVariable(CommonListMethodsVariable):
)
)
codegen.foreach(self.items)
codegen.extend_output(
[
create_instruction("BUILD_LIST", arg=len(self.items)),
*create_call_function(1, False),
]
)
codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
codegen(self.maxlen)
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
def var_getattr(self, tx: "InstructionTranslator", name):
if name == "maxlen":
return self.maxlen
return super().var_getattr(tx, name)
def call_method(
self,
@ -494,33 +517,37 @@ class DequeVariable(CommonListMethodsVariable):
tx.output.side_effects.mutation(self)
self.items[key.as_python_constant()] = value
return ConstantVariable.create(None)
elif (
if (
name == "extendleft"
and self.mutable_local
and args[0].has_force_unpack_var_sequence(tx)
):
assert not kwargs
(arg,) = args
prefix = arg.force_unpack_var_sequence(tx)
prefix.reverse()
tx.output.side_effects.mutation(self)
self.items = prefix + list(self.items)
return ConstantVariable.create(None)
result = ConstantVariable.create(None)
elif name == "popleft" and self.mutable_local:
assert not args
assert not kwargs
item = self.items[0]
tx.output.side_effects.mutation(self)
self.items = self.items[1:]
return item
result = item
elif name == "appendleft" and self.mutable_local:
assert not kwargs
tx.output.side_effects.mutation(self)
self.items = [args[0]] + list(self.items)
return ConstantVariable.create(None)
result = ConstantVariable.create(None)
else:
return super().call_method(tx, name, args, kwargs)
result = super().call_method(tx, name, args, kwargs)
if self.maxlen.as_python_constant() is not None:
self.items = list(self.items)[-self.maxlen.as_python_constant() :]
return result
class TupleVariable(BaseListVariable):

View file

@ -374,14 +374,31 @@ class UserDefinedClassVariable(UserDefinedVariable):
if self.value.__optional_keys__:
unimplemented("TypedDict with optional keys not supported")
return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs)
elif self.value is collections.deque and not kwargs:
if len(args) == 0:
items = []
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
elif self.value is collections.deque:
maxlen = variables.ConstantVariable.create(None)
if not kwargs:
if len(args) == 0:
items = []
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
elif len(args) == 2 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
maxlen = args[1]
else:
unimplemented("deque() with more than 2 arg not supported")
elif tuple(kwargs) == ("maxlen",):
maxlen = kwargs["maxlen"]
if len(args) == 0:
items = []
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
items = args[0].force_unpack_var_sequence(tx)
else:
unimplemented("deque() with more than 1 arg not supported")
else:
unimplemented("deque() with more than 1 arg not supported")
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
unimplemented("deque() with invalid kwargs not supported")
return variables.lists.DequeVariable(
items, maxlen=maxlen, mutable_local=MutableLocal()
)
elif self.value is functools.partial:
if not args:
unimplemented("functools.partial malformed")

View file

@ -994,7 +994,7 @@ def tree_map_(
"""
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
tuple(map(func, *flat_args)) # consume and exhaust the iterable
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
return tree