mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] support maxlen for collections.deque (#138194)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138194 Approved by: https://github.com/jansel, https://github.com/malfet
This commit is contained in:
parent
a4b35767cb
commit
9bbe4a67ad
4 changed files with 69 additions and 24 deletions
|
|
@ -9963,7 +9963,7 @@ def ___make_guard_fn():
|
|||
"c": (
|
||||
x,
|
||||
3.0,
|
||||
collections.deque([0.0, -x]),
|
||||
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
||||
),
|
||||
"d": collections.OrderedDict(
|
||||
{
|
||||
|
|
@ -9995,7 +9995,7 @@ def ___make_guard_fn():
|
|||
"c": (
|
||||
x,
|
||||
3.0,
|
||||
[0.0, -x],
|
||||
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
||||
),
|
||||
"d": collections.OrderedDict(
|
||||
{
|
||||
|
|
@ -10011,6 +10011,7 @@ def ___make_guard_fn():
|
|||
x * y,
|
||||
3.0,
|
||||
y - 2,
|
||||
1,
|
||||
torch.zeros(2, 2),
|
||||
2 * y,
|
||||
-y,
|
||||
|
|
@ -10043,7 +10044,7 @@ def ___make_guard_fn():
|
|||
"c": (
|
||||
x,
|
||||
3.0,
|
||||
[0.0, -x],
|
||||
collections.deque([0.0, -x, 1, 2], maxlen=3),
|
||||
),
|
||||
"d": collections.OrderedDict(
|
||||
{
|
||||
|
|
@ -10054,7 +10055,7 @@ def ___make_guard_fn():
|
|||
}
|
||||
tree2 = collections.OrderedDict(
|
||||
[
|
||||
("c", (y, 3.0, [-y, 10.0])),
|
||||
("c", (y, 3.0, collections.deque([1, -y, 10.0]))),
|
||||
("a", [y, y + 1]),
|
||||
("b", y + 2),
|
||||
(
|
||||
|
|
|
|||
|
|
@ -452,12 +452,33 @@ class ListVariable(CommonListMethodsVariable):
|
|||
|
||||
|
||||
class DequeVariable(CommonListMethodsVariable):
|
||||
def __init__(self, items, maxlen=None, **kwargs) -> None:
|
||||
if maxlen is None:
|
||||
maxlen = ConstantVariable.create(None)
|
||||
assert (
|
||||
maxlen.is_python_constant()
|
||||
), f"maxlen must be a constant, got: {maxlen.debug_repr()}"
|
||||
self.maxlen = maxlen
|
||||
if self.maxlen.as_python_constant() is not None:
|
||||
items = list(items)[-maxlen.as_python_constant() :]
|
||||
super().__init__(items, **kwargs)
|
||||
|
||||
def python_type(self):
|
||||
return collections.deque
|
||||
|
||||
def debug_repr(self):
|
||||
if self.maxlen.as_python_constant() is None:
|
||||
return self.debug_repr_helper(
|
||||
"deque([", "], maxlen=" + self.maxlen.debug_repr() + ")"
|
||||
)
|
||||
return self.debug_repr_helper("deque([", "])")
|
||||
|
||||
def as_python_constant(self):
|
||||
return self.python_type()(
|
||||
[x.as_python_constant() for x in self.items],
|
||||
maxlen=self.maxlen.as_python_constant(),
|
||||
)
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
assert "deque" not in codegen.tx.f_globals
|
||||
codegen.add_push_null(
|
||||
|
|
@ -466,12 +487,14 @@ class DequeVariable(CommonListMethodsVariable):
|
|||
)
|
||||
)
|
||||
codegen.foreach(self.items)
|
||||
codegen.extend_output(
|
||||
[
|
||||
create_instruction("BUILD_LIST", arg=len(self.items)),
|
||||
*create_call_function(1, False),
|
||||
]
|
||||
)
|
||||
codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))])
|
||||
codegen(self.maxlen)
|
||||
codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False))
|
||||
|
||||
def var_getattr(self, tx: "InstructionTranslator", name):
|
||||
if name == "maxlen":
|
||||
return self.maxlen
|
||||
return super().var_getattr(tx, name)
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
|
|
@ -494,33 +517,37 @@ class DequeVariable(CommonListMethodsVariable):
|
|||
tx.output.side_effects.mutation(self)
|
||||
self.items[key.as_python_constant()] = value
|
||||
return ConstantVariable.create(None)
|
||||
elif (
|
||||
|
||||
if (
|
||||
name == "extendleft"
|
||||
and self.mutable_local
|
||||
and args[0].has_force_unpack_var_sequence(tx)
|
||||
):
|
||||
assert not kwargs
|
||||
|
||||
(arg,) = args
|
||||
prefix = arg.force_unpack_var_sequence(tx)
|
||||
prefix.reverse()
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items = prefix + list(self.items)
|
||||
return ConstantVariable.create(None)
|
||||
result = ConstantVariable.create(None)
|
||||
elif name == "popleft" and self.mutable_local:
|
||||
assert not args
|
||||
assert not kwargs
|
||||
item = self.items[0]
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items = self.items[1:]
|
||||
return item
|
||||
result = item
|
||||
elif name == "appendleft" and self.mutable_local:
|
||||
assert not kwargs
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items = [args[0]] + list(self.items)
|
||||
return ConstantVariable.create(None)
|
||||
result = ConstantVariable.create(None)
|
||||
else:
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
result = super().call_method(tx, name, args, kwargs)
|
||||
|
||||
if self.maxlen.as_python_constant() is not None:
|
||||
self.items = list(self.items)[-self.maxlen.as_python_constant() :]
|
||||
return result
|
||||
|
||||
|
||||
class TupleVariable(BaseListVariable):
|
||||
|
|
|
|||
|
|
@ -374,14 +374,31 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
if self.value.__optional_keys__:
|
||||
unimplemented("TypedDict with optional keys not supported")
|
||||
return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs)
|
||||
elif self.value is collections.deque and not kwargs:
|
||||
if len(args) == 0:
|
||||
items = []
|
||||
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
elif self.value is collections.deque:
|
||||
maxlen = variables.ConstantVariable.create(None)
|
||||
if not kwargs:
|
||||
if len(args) == 0:
|
||||
items = []
|
||||
elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
elif len(args) == 2 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
maxlen = args[1]
|
||||
else:
|
||||
unimplemented("deque() with more than 2 arg not supported")
|
||||
elif tuple(kwargs) == ("maxlen",):
|
||||
maxlen = kwargs["maxlen"]
|
||||
if len(args) == 0:
|
||||
items = []
|
||||
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
items = args[0].force_unpack_var_sequence(tx)
|
||||
else:
|
||||
unimplemented("deque() with more than 1 arg not supported")
|
||||
else:
|
||||
unimplemented("deque() with more than 1 arg not supported")
|
||||
return variables.lists.DequeVariable(items, mutable_local=MutableLocal())
|
||||
unimplemented("deque() with invalid kwargs not supported")
|
||||
return variables.lists.DequeVariable(
|
||||
items, maxlen=maxlen, mutable_local=MutableLocal()
|
||||
)
|
||||
elif self.value is functools.partial:
|
||||
if not args:
|
||||
unimplemented("functools.partial malformed")
|
||||
|
|
|
|||
|
|
@ -994,7 +994,7 @@ def tree_map_(
|
|||
"""
|
||||
leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
|
||||
flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
|
||||
tuple(map(func, *flat_args)) # consume and exhaust the iterable
|
||||
deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable
|
||||
return tree
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue