diff --git a/test/dynamo/test_functions.py b/test/dynamo/test_functions.py index 1d8ca18df61..57d04d2027f 100644 --- a/test/dynamo/test_functions.py +++ b/test/dynamo/test_functions.py @@ -4499,6 +4499,60 @@ class DefaultsTests(torch._dynamo.test_case.TestCase): self.assertEqual(ref, res) self.assertTrue(isinstance(res, tuple)) + def test_udf_list(self): + class MyList(list): # noqa: SLOT001 + def len_mulitply_2(self): + return len(self) * 2 + + def __contains__(self, val): + # Ensure that overridden method is traced + self.checked = True + return super().__contains__(val) + + def fn(x, lst): + if 3 in lst: + x = torch.cos(x) + else: + x = torch.sin(x) + return x * lst.len_mulitply_2() + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref_lst = MyList([1, 2, 3]) + ref = fn(x, ref_lst) + res_lst = MyList([1, 2, 3]) + res = opt_fn(x, res_lst) + self.assertEqual(ref, res) + self.assertTrue(ref_lst.checked) + self.assertTrue(res_lst.checked) + + def test_udf_list_reconstruction(self): + class MyList(list): # noqa: SLOT001 + # def __new__(cls, *args, **kwargs): + # return super().__new__(cls, *args, **kwargs) + pass + + def fn(x, klass): + x = x * 2 + sc_list = list.__new__(klass) + sc_list.append(x) + if isinstance(sc_list, MyList): + sc_list.attr = 3 + return sc_list + + opt_fn = torch.compile(fn, backend="eager", fullgraph=True) + x = torch.randn(4) + ref = fn(x, MyList) + res = opt_fn(x, MyList) + self.assertEqual(ref, res) + self.assertTrue(isinstance(res, MyList)) + self.assertEqual(ref.attr, res.attr) + + ref = fn(x, list) + res = opt_fn(x, list) + self.assertEqual(ref, res) + self.assertTrue(isinstance(res, list)) + def test_sys_recursionlimit(self): def fn(x): return x.sin() * sys.getrecursionlimit() diff --git a/torch/_dynamo/side_effects.py b/torch/_dynamo/side_effects.py index 93ec5aa05e2..854807fd4cb 100644 --- a/torch/_dynamo/side_effects.py +++ b/torch/_dynamo/side_effects.py @@ -48,6 +48,11 @@ def _manual_dict_setitem(dict_from, dict_to, mro_index): dict_class.__setitem__(dict_to, k, v) +def _manual_list_setitem(list_from, list_to): + list.clear(list_to) + list.extend(list_to, list_from) + + class SideEffects: """ Track side effects (list mutation, setattr, etc) that need to be @@ -213,6 +218,8 @@ class SideEffects: dict.__getattribute__, int.__getattribute__, str.__getattribute__, + list.__getattribute__, + tuple.__getattribute__, ) def is_attribute_mutation(self, item): @@ -233,8 +240,8 @@ class SideEffects: return False if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)): return True - if self.is_attribute_mutation(item): - return item in self.store_attr_mutations + if self.is_attribute_mutation(item) and item in self.store_attr_mutations: + return True return item.mutation_type.is_modified def _track_obj( @@ -318,6 +325,8 @@ class SideEffects: variable_cls = variables.UserDefinedDictVariable elif issubclass(user_cls, tuple): variable_cls = variables.UserDefinedTupleVariable + elif issubclass(user_cls, list): + variable_cls = variables.UserDefinedListVariable elif issubclass(user_cls, MutableMapping): variable_cls = variables.MutableMappingVariable elif is_frozen_dataclass(user_cls): @@ -822,6 +831,41 @@ class SideEffects: create_instruction("POP_TOP"), ] ) + elif isinstance(var, variables.UserDefinedListVariable): + # Update the list to the updated items. Be careful in + # calling the list methods and not the overridden methods. + varname_map = {} + for name in _manual_list_setitem.__code__.co_varnames: + varname_map[name] = cg.tx.output.new_var() + + cg(var.source) # type: ignore[attr-defined] + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["list_to"] + ) + ] + ) + + cg(var._list_vt, allow_cache=False) # Don't codegen via source + cg.extend_output( + [ + create_instruction( + "STORE_FAST", argval=varname_map["list_from"] + ) + ] + ) + + list_update_insts = bytecode_from_template( + _manual_list_setitem, varname_map=varname_map + ) + + suffixes.append( + [ + *list_update_insts, + create_instruction("POP_TOP"), + ] + ) # Applying mutations involves two steps: 1) Push all # reconstructed objects onto the stack. 2) Call STORE_ATTR to diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index d06ffccfc4c..7d881a6ed9c 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -174,6 +174,7 @@ manual_torch_name_rule_map = { # torch.fx map utils "torch.fx.node.map_aggregate": UserFunctionVariable, "torch.fx.node.map_arg": UserFunctionVariable, + "torch.fx.immutable_collections._no_mutation": UserFunctionVariable, # symbol operators implemented in Python "torch.sym_not": TorchInGraphFunctionVariable, "torch.sym_float": TorchInGraphFunctionVariable, diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index de834e2c594..6130c1168f4 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -2356,6 +2356,7 @@ dict_methods = { tuple_new = tuple.__new__ tuple_methods = {method for method in tuple.__dict__.values() if callable(method)} +list_methods = {method for method in list.__dict__.values() if callable(method)} def builtin_dict_keys(d): diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index ba7a10267e2..fbb26ac0358 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -66,7 +66,6 @@ from .iter import ( from .lazy import LazyVariableTracker from .lists import ( BaseListVariable, - FxImmutableListVariable, ListIteratorVariable, ListVariable, NamedTupleVariable, @@ -121,6 +120,7 @@ from .user_defined import ( RemovableHandleVariable, UserDefinedClassVariable, UserDefinedDictVariable, + UserDefinedListVariable, UserDefinedObjectVariable, UserDefinedTupleVariable, ) diff --git a/torch/_dynamo/variables/base.py b/torch/_dynamo/variables/base.py index e6fffdc40f2..56bf5192ed2 100644 --- a/torch/_dynamo/variables/base.py +++ b/torch/_dynamo/variables/base.py @@ -127,6 +127,7 @@ class AttributeMutation(MutationType): def __init__(self, typ: SourceType): super().__init__(typ) + self.is_modified = False class AttributeMutationExisting(AttributeMutation): diff --git a/torch/_dynamo/variables/builder.py b/torch/_dynamo/variables/builder.py index f43bb435c08..c732d2bc4e9 100644 --- a/torch/_dynamo/variables/builder.py +++ b/torch/_dynamo/variables/builder.py @@ -225,6 +225,7 @@ from .user_defined import ( SourcelessGraphModuleVariable, UserDefinedClassVariable, UserDefinedDictVariable, + UserDefinedListVariable, UserDefinedObjectVariable, UserDefinedTupleVariable, ) @@ -1268,6 +1269,22 @@ class VariableBuilder: value, tuple_vt=tuple_vt, source=self.source ) return self.tx.output.side_effects.track_object_existing(value, result) + elif isinstance(value, list): + self.install_guards(GuardBuilder.TYPE_MATCH) + self.install_guards(GuardBuilder.SEQUENCE_LENGTH) + + # NB - Be careful in not triggering user code. Guards also work on + # the underlying list data structure. + output = [ + LazyVariableTracker.create( + list.__getitem__(value, i), + source=GetItemSource(self.get_source(), i), + ) + for i in range(list.__len__(value)) + ] + list_vt = ListVariable(output, mutation_type=ValueMutationNew()) + result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source) + return self.tx.output.side_effects.track_object_existing(value, result) elif issubclass(type(value), MutableMapping): self.install_guards(GuardBuilder.TYPE_MATCH) return MutableMappingVariable(value, source=self.source) diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 15f384eeeaa..6727e64790f 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -1136,6 +1136,16 @@ class BuiltinVariable(VariableTracker): result.set_underlying_tuple_vt(tuple_vt) return result + if self.fn is list: + list_vt = ListVariable([], mutation_type=ValueMutationNew()) + if isinstance(args[0], BuiltinVariable) and args[0].fn is list: + return list_vt + return tx.output.side_effects.track_new_user_defined_object( + self, + args[0], + args[1:], + ) + if self.fn is object and name == "__init__": # object.__init__ is a no-op return variables.ConstantVariable(None) diff --git a/torch/_dynamo/variables/lists.py b/torch/_dynamo/variables/lists.py index 438529cac69..87d9a0de235 100644 --- a/torch/_dynamo/variables/lists.py +++ b/torch/_dynamo/variables/lists.py @@ -477,6 +477,13 @@ class ListVariable(CommonListMethodsVariable): self.items[:] = [x for x, *_ in sorted_items_with_keys] return ConstantVariable.create(None) + if name == "__init__" and self.is_mutable(): + assert not kwargs + if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx): + (arg,) = args + tx.output.side_effects.mutation(self) + self.items[:] = arg.force_unpack_var_sequence(tx) + return ConstantVariable.create(None) return super().call_method(tx, name, args, kwargs) def var_getattr(self, tx, name): @@ -497,58 +504,6 @@ class ListVariable(CommonListMethodsVariable): return variables.ConstantVariable.create(hasattr([], name)) -class FxImmutableListVariable(ListVariable): - def __init__(self, items, **kwargs) -> None: - super().__init__(items, **kwargs) - self.mutable_methods = { - "__delitem__", - "__iadd__", - "__imul__", - "__setitem__", - "append", - "clear", - "extend", - "insert", - "pop", - "remove", - "reverse", - "sort", - } - - def python_type(self): - return torch.fx.immutable_collections.immutable_list - - def reconstruct(self, codegen: "PyCodegen") -> None: - # load torch.fx.immutable_collections.immutable_list - codegen.add_push_null( - lambda: codegen.extend_output( - [ - codegen.create_load_python_module(torch.fx.immutable_collections), - codegen.create_load_attr("immutable_list"), - ] - ) - ) - - # Construct the list - super().reconstruct(codegen) - - # Construct the immutable_list - codegen.extend_output(create_call_function(1, False)) - - def call_method( - self, - tx, - name, - args: list["VariableTracker"], - kwargs: dict[str, "VariableTracker"], - ) -> "VariableTracker": - if name in self.mutable_methods: - # immutable fx list raises NotImplementedError - raise_observed_exception(NotImplementedError, tx) - - return super().call_method(tx, name, args, kwargs) - - class DequeVariable(CommonListMethodsVariable): def __init__(self, items, maxlen=None, **kwargs) -> None: if maxlen is None: diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index d7c78bdaa93..86d9dccfb68 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -26,6 +26,7 @@ from ..utils import ( cmp_name_to_op_mapping, identity, is_tensor_base_attr_getter, + list_methods, proxy_args_kwargs, set_example_value, tuple_methods, @@ -219,6 +220,11 @@ class SuperVariable(VariableTracker): and inner_fn in tuple_methods ): return self.objvar._tuple_vt.call_method(tx, name, args, kwargs) + elif ( + isinstance(self.objvar, variables.UserDefinedListVariable) + and inner_fn in list_methods + ): + return self.objvar._list_vt.call_method(tx, name, args, kwargs) elif inner_fn is object.__getattribute__: # object.__getattribute__ has no side-effects. We can directly call # __getattribute__ to access the attribute. diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index 0ea1583af57..1d4899f6d3f 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -53,6 +53,7 @@ from ..utils import ( is_utils_checkpoint, is_wrapper_or_member_descriptor, istype, + list_methods, namedtuple_fields, object_has_getattribute, proxy_args_kwargs, @@ -156,6 +157,7 @@ class UserDefinedClassVariable(UserDefinedVariable): object.__new__, dict.__new__, tuple.__new__, + list.__new__, } @staticmethod @@ -455,18 +457,6 @@ class UserDefinedClassVariable(UserDefinedVariable): elif self.value is torch.cuda.device and not kwargs and len(args) == 1: assert args[0].is_python_constant() return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant()) - elif ( - self.value is torch.fx.immutable_collections.immutable_list - and len(args) == 1 - and isinstance(args[0], variables.ListVariable) - ): - arg = args[0] - if arg.source: - install_guard(arg.source.make_guard(GuardBuilder.SEQUENCE_LENGTH)) - return variables.FxImmutableListVariable( - list(arg.unpack_var_sequence(tx)), - mutation_type=ValueMutationNew(), - ) elif ( issubclass(type(self.value), type) and hasattr( @@ -1484,7 +1474,10 @@ class UserDefinedDictVariable(UserDefinedObjectVariable): ) -> "VariableTracker": method = self._maybe_get_baseclass_method(name) if method in self._dict_methods: - return self._dict_vt.call_method(tx, name, args, kwargs) + out = self._dict_vt.call_method(tx, name, args, kwargs) + if tx.output.side_effects.is_modified(self._dict_vt): + self.mutation_type.is_modified = True + return out return super().call_method(tx, name, args, kwargs) def unpack_var_sequence(self, tx): @@ -1496,6 +1489,49 @@ class UserDefinedDictVariable(UserDefinedObjectVariable): raise NotImplementedError +class UserDefinedListVariable(UserDefinedObjectVariable): + """ + Represents user defined objects that are subclasses of lists. + + Internally, it uses a ListVariable to represent the list part of the + variable tracker. For everything else, it falls back to + UserDefinedObjectVariable. + """ + + _nonvar_fields = UserDefinedObjectVariable._nonvar_fields + + def __init__(self, value, list_vt=None, **kwargs): + super().__init__(value, **kwargs) + self._list_vt = list_vt + if self._list_vt is None: + assert ( + self.source is None + ), "list_vt must be constructed by builder.py when source is present" + self._list_vt = variables.ListVariable([], mutation_type=ValueMutationNew()) + + def call_method( + self, + tx, + name, + args: "list[VariableTracker]", + kwargs: "dict[str, VariableTracker]", + ) -> "VariableTracker": + assert self._list_vt is not None + method = self._maybe_get_baseclass_method(name) + if method in list_methods: + out = self._list_vt.call_method(tx, name, args, kwargs) + if tx.output.side_effects.is_modified(self._list_vt): + self.mutation_type.is_modified = True + return out + return super().call_method(tx, name, args, kwargs) + + def unpack_var_sequence(self, tx): + assert self._list_vt is not None + if type(self.value).__iter__ is list.__iter__: + return self._list_vt.unpack_var_sequence(tx) + raise NotImplementedError + + class UserDefinedTupleVariable(UserDefinedObjectVariable): """ Represents user defined objects that are subclasses of tuple.