diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index 21be24a2678..71297447cf0 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9963,7 +9963,7 @@ def ___make_guard_fn(): "c": ( x, 3.0, - collections.deque([0.0, -x]), + collections.deque([0.0, -x, 1, 2], maxlen=3), ), "d": collections.OrderedDict( { @@ -9995,7 +9995,7 @@ def ___make_guard_fn(): "c": ( x, 3.0, - [0.0, -x], + collections.deque([0.0, -x, 1, 2], maxlen=3), ), "d": collections.OrderedDict( { @@ -10011,6 +10011,7 @@ def ___make_guard_fn(): x * y, 3.0, y - 2, + 1, torch.zeros(2, 2), 2 * y, -y, @@ -10043,7 +10044,7 @@ def ___make_guard_fn(): "c": ( x, 3.0, - [0.0, -x], + collections.deque([0.0, -x, 1, 2], maxlen=3), ), "d": collections.OrderedDict( { @@ -10054,7 +10055,7 @@ def ___make_guard_fn(): } tree2 = collections.OrderedDict( [ - ("c", (y, 3.0, [-y, 10.0])), + ("c", (y, 3.0, collections.deque([1, -y, 10.0]))), ("a", [y, y + 1]), ("b", y + 2), ( diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 44f053452ad..667cdb7c0ba 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -452,12 +452,33 @@ class ListVariable(CommonListMethodsVariable): class DequeVariable(CommonListMethodsVariable): + def __init__(self, items, maxlen=None, **kwargs) -> None: + if maxlen is None: + maxlen = ConstantVariable.create(None) + assert ( + maxlen.is_python_constant() + ), f"maxlen must be a constant, got: {maxlen.debug_repr()}" + self.maxlen = maxlen + if self.maxlen.as_python_constant() is not None: + items = list(items)[-maxlen.as_python_constant() :] + super().__init__(items, **kwargs) + def python_type(self): return collections.deque def debug_repr(self): + if self.maxlen.as_python_constant() is None: + return self.debug_repr_helper( + "deque([", "], maxlen=" + self.maxlen.debug_repr() + ")" + ) return self.debug_repr_helper("deque([", "])") + def as_python_constant(self): + return self.python_type()( + [x.as_python_constant() for x in self.items], + maxlen=self.maxlen.as_python_constant(), + ) + def reconstruct(self, codegen: "PyCodegen") -> None: assert "deque" not in codegen.tx.f_globals codegen.add_push_null( @@ -466,12 +487,14 @@ class DequeVariable(CommonListMethodsVariable): ) ) codegen.foreach(self.items) - codegen.extend_output( - [ - create_instruction("BUILD_LIST", arg=len(self.items)), - *create_call_function(1, False), - ] - ) + codegen.extend_output([create_instruction("BUILD_LIST", arg=len(self.items))]) + codegen(self.maxlen) + codegen.extend_output(codegen.create_call_function_kw(2, ("maxlen",), False)) + + def var_getattr(self, tx: "InstructionTranslator", name): + if name == "maxlen": + return self.maxlen + return super().var_getattr(tx, name) def call_method( self, @@ -494,33 +517,37 @@ class DequeVariable(CommonListMethodsVariable): tx.output.side_effects.mutation(self) self.items[key.as_python_constant()] = value return ConstantVariable.create(None) - elif ( + + if ( name == "extendleft" and self.mutable_local and args[0].has_force_unpack_var_sequence(tx) ): assert not kwargs - (arg,) = args prefix = arg.force_unpack_var_sequence(tx) prefix.reverse() tx.output.side_effects.mutation(self) self.items = prefix + list(self.items) - return ConstantVariable.create(None) + result = ConstantVariable.create(None) elif name == "popleft" and self.mutable_local: assert not args assert not kwargs item = self.items[0] tx.output.side_effects.mutation(self) self.items = self.items[1:] - return item + result = item elif name == "appendleft" and self.mutable_local: assert not kwargs tx.output.side_effects.mutation(self) self.items = [args[0]] + list(self.items) - return ConstantVariable.create(None) + result = ConstantVariable.create(None) else: - return super().call_method(tx, name, args, kwargs) + result = super().call_method(tx, name, args, kwargs) + + if self.maxlen.as_python_constant() is not None: + self.items = list(self.items)[-self.maxlen.as_python_constant() :] + return result class TupleVariable(BaseListVariable): diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index ff0c2bdf6a5..893e1fa0446 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -374,14 +374,31 @@ class UserDefinedClassVariable(UserDefinedVariable): if self.value.__optional_keys__: unimplemented("TypedDict with optional keys not supported") return variables.BuiltinVariable(dict).call_dict(tx, *args, **kwargs) - elif self.value is collections.deque and not kwargs: - if len(args) == 0: - items = [] - elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): - items = args[0].force_unpack_var_sequence(tx) + elif self.value is collections.deque: + maxlen = variables.ConstantVariable.create(None) + if not kwargs: + if len(args) == 0: + items = [] + elif len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + elif len(args) == 2 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + maxlen = args[1] + else: + unimplemented("deque() with more than 2 arg not supported") + elif tuple(kwargs) == ("maxlen",): + maxlen = kwargs["maxlen"] + if len(args) == 0: + items = [] + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + items = args[0].force_unpack_var_sequence(tx) + else: + unimplemented("deque() with more than 1 arg not supported") else: - unimplemented("deque() with more than 1 arg not supported") - return variables.lists.DequeVariable(items, mutable_local=MutableLocal()) + unimplemented("deque() with invalid kwargs not supported") + return variables.lists.DequeVariable( + items, maxlen=maxlen, mutable_local=MutableLocal() + ) elif self.value is functools.partial: if not args: unimplemented("functools.partial malformed") diff --git a/torch/utils/_pytree.py b/torch/utils/_pytree.py index a1b836594a8..de9fbebbe37 100644 --- a/torch/utils/_pytree.py +++ b/torch/utils/_pytree.py @@ -994,7 +994,7 @@ def tree_map_( """ leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] - tuple(map(func, *flat_args)) # consume and exhaust the iterable + deque(map(func, *flat_args), maxlen=0) # consume and exhaust the iterable return tree