[dynamo][lists] Support list subclasses

[ghstack-poisoned]
This commit is contained in:
Animesh Jain 2025-02-09 23:33:58 -08:00
parent 97b36c6513
commit 7270a4a6de
11 changed files with 193 additions and 68 deletions

View file

@ -4499,6 +4499,60 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
self.assertEqual(ref, res)
self.assertTrue(isinstance(res, tuple))
def test_udf_list(self):
    """Compiling a function that receives a list subclass must trace both
    the inherited list behavior and user-defined/overridden methods.

    The overridden ``__contains__`` sets a flag so the test can verify that
    Dynamo actually traced the override instead of the built-in method.
    """

    class MyList(list):  # noqa: SLOT001
        def len_multiply_2(self):
            return len(self) * 2

        def __contains__(self, val):
            # Ensure that the overridden method is traced
            self.checked = True
            return super().__contains__(val)

    def fn(x, lst):
        if 3 in lst:
            x = torch.cos(x)
        else:
            x = torch.sin(x)
        return x * lst.len_multiply_2()

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    x = torch.randn(4)

    ref_lst = MyList([1, 2, 3])
    ref = fn(x, ref_lst)
    res_lst = MyList([1, 2, 3])
    res = opt_fn(x, res_lst)
    self.assertEqual(ref, res)
    # The overridden __contains__ must have run for both the eager and the
    # compiled call, proving the override was traced rather than skipped.
    self.assertTrue(ref_lst.checked)
    self.assertTrue(res_lst.checked)
def test_udf_list_reconstruction(self):
    """Lists created via ``list.__new__(subclass)`` inside the compiled
    region must be reconstructed with the correct subclass, and any
    attributes set on the instance must survive reconstruction.
    """

    class MyList(list):  # noqa: SLOT001
        pass

    def fn(x, klass):
        x = x * 2
        sc_list = list.__new__(klass)
        sc_list.append(x)
        if isinstance(sc_list, MyList):
            # Attribute exists only on the subclass instance; it must be
            # preserved when Dynamo reconstructs the object.
            sc_list.attr = 3
        return sc_list

    opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
    x = torch.randn(4)

    # Subclass case: type, contents, and the extra attribute round-trip.
    ref = fn(x, MyList)
    res = opt_fn(x, MyList)
    self.assertEqual(ref, res)
    self.assertTrue(isinstance(res, MyList))
    self.assertEqual(ref.attr, res.attr)

    # Plain-list case: the isinstance branch is not taken.
    ref = fn(x, list)
    res = opt_fn(x, list)
    self.assertEqual(ref, res)
    self.assertTrue(isinstance(res, list))
def test_sys_recursionlimit(self):
def fn(x):
return x.sin() * sys.getrecursionlimit()

View file

@ -48,6 +48,11 @@ def _manual_dict_setitem(dict_from, dict_to, mro_index):
dict_class.__setitem__(dict_to, k, v)
def _manual_list_setitem(list_from, list_to):
    # Bytecode template (consumed via bytecode_from_template, with the
    # local names remapped through __code__.co_varnames): replaces the
    # contents of list_to with the contents of list_from. Deliberately
    # calls the unbound list methods so that overridden methods on list
    # subclasses are never invoked during mutation replay.
    list.clear(list_to)
    list.extend(list_to, list_from)
class SideEffects:
"""
Track side effects (list mutation, setattr, etc) that need to be
@ -213,6 +218,8 @@ class SideEffects:
dict.__getattribute__,
int.__getattribute__,
str.__getattribute__,
list.__getattribute__,
tuple.__getattribute__,
)
def is_attribute_mutation(self, item):
@ -233,8 +240,8 @@ class SideEffects:
return False
if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)):
return True
if self.is_attribute_mutation(item):
return item in self.store_attr_mutations
if self.is_attribute_mutation(item) and item in self.store_attr_mutations:
return True
return item.mutation_type.is_modified
def _track_obj(
@ -318,6 +325,8 @@ class SideEffects:
variable_cls = variables.UserDefinedDictVariable
elif issubclass(user_cls, tuple):
variable_cls = variables.UserDefinedTupleVariable
elif issubclass(user_cls, list):
variable_cls = variables.UserDefinedListVariable
elif issubclass(user_cls, MutableMapping):
variable_cls = variables.MutableMappingVariable
elif is_frozen_dataclass(user_cls):
@ -822,6 +831,41 @@ class SideEffects:
create_instruction("POP_TOP"),
]
)
elif isinstance(var, variables.UserDefinedListVariable):
# Write the mutated items back into the list. Be careful to call
# the built-in list methods and not any overridden methods.
varname_map = {}
for name in _manual_list_setitem.__code__.co_varnames:
varname_map[name] = cg.tx.output.new_var()
cg(var.source) # type: ignore[attr-defined]
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["list_to"]
)
]
)
cg(var._list_vt, allow_cache=False) # Don't codegen via source
cg.extend_output(
[
create_instruction(
"STORE_FAST", argval=varname_map["list_from"]
)
]
)
list_update_insts = bytecode_from_template(
_manual_list_setitem, varname_map=varname_map
)
suffixes.append(
[
*list_update_insts,
create_instruction("POP_TOP"),
]
)
# Applying mutations involves two steps: 1) Push all
# reconstructed objects onto the stack. 2) Call STORE_ATTR to

View file

@ -174,6 +174,7 @@ manual_torch_name_rule_map = {
# torch.fx map utils
"torch.fx.node.map_aggregate": UserFunctionVariable,
"torch.fx.node.map_arg": UserFunctionVariable,
"torch.fx.immutable_collections._no_mutation": UserFunctionVariable,
# symbol operators implemented in Python
"torch.sym_not": TorchInGraphFunctionVariable,
"torch.sym_float": TorchInGraphFunctionVariable,

View file

@ -2356,6 +2356,7 @@ dict_methods = {
tuple_new = tuple.__new__
tuple_methods = {method for method in tuple.__dict__.values() if callable(method)}
list_methods = {method for method in list.__dict__.values() if callable(method)}
def builtin_dict_keys(d):

View file

@ -66,7 +66,6 @@ from .iter import (
from .lazy import LazyVariableTracker
from .lists import (
BaseListVariable,
FxImmutableListVariable,
ListIteratorVariable,
ListVariable,
NamedTupleVariable,
@ -121,6 +120,7 @@ from .user_defined import (
RemovableHandleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedListVariable,
UserDefinedObjectVariable,
UserDefinedTupleVariable,
)

View file

@ -127,6 +127,7 @@ class AttributeMutation(MutationType):
def __init__(self, typ: SourceType):
    super().__init__(typ)
    # Becomes True once an attribute on the tracked object is actually
    # mutated during tracing; read as item.mutation_type.is_modified by
    # the side-effect machinery to decide whether replay is needed.
    self.is_modified = False
class AttributeMutationExisting(AttributeMutation):

View file

@ -225,6 +225,7 @@ from .user_defined import (
SourcelessGraphModuleVariable,
UserDefinedClassVariable,
UserDefinedDictVariable,
UserDefinedListVariable,
UserDefinedObjectVariable,
UserDefinedTupleVariable,
)
@ -1268,6 +1269,22 @@ class VariableBuilder:
value, tuple_vt=tuple_vt, source=self.source
)
return self.tx.output.side_effects.track_object_existing(value, result)
elif isinstance(value, list):
self.install_guards(GuardBuilder.TYPE_MATCH)
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
# NB - Be careful not to trigger user code. Guards also work on
# the underlying list data structure.
output = [
LazyVariableTracker.create(
list.__getitem__(value, i),
source=GetItemSource(self.get_source(), i),
)
for i in range(list.__len__(value))
]
list_vt = ListVariable(output, mutation_type=ValueMutationNew())
result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source)
return self.tx.output.side_effects.track_object_existing(value, result)
elif issubclass(type(value), MutableMapping):
self.install_guards(GuardBuilder.TYPE_MATCH)
return MutableMappingVariable(value, source=self.source)

View file

@ -1136,6 +1136,16 @@ class BuiltinVariable(VariableTracker):
result.set_underlying_tuple_vt(tuple_vt)
return result
if self.fn is list:
list_vt = ListVariable([], mutation_type=ValueMutationNew())
if isinstance(args[0], BuiltinVariable) and args[0].fn is list:
return list_vt
return tx.output.side_effects.track_new_user_defined_object(
self,
args[0],
args[1:],
)
if self.fn is object and name == "__init__":
# object.__init__ is a no-op
return variables.ConstantVariable(None)

View file

@ -477,6 +477,13 @@ class ListVariable(CommonListMethodsVariable):
self.items[:] = [x for x, *_ in sorted_items_with_keys]
return ConstantVariable.create(None)
if name == "__init__" and self.is_mutable():
assert not kwargs
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
(arg,) = args
tx.output.side_effects.mutation(self)
self.items[:] = arg.force_unpack_var_sequence(tx)
return ConstantVariable.create(None)
return super().call_method(tx, name, args, kwargs)
def var_getattr(self, tx, name):
@ -497,58 +504,6 @@ class ListVariable(CommonListMethodsVariable):
return variables.ConstantVariable.create(hasattr([], name))
class FxImmutableListVariable(ListVariable):
    """Tracks a ``torch.fx.immutable_collections.immutable_list``.

    Reads behave exactly like a regular ListVariable; any in-place
    mutation method raises NotImplementedError, mirroring the runtime
    behavior of immutable_list.
    """

    def __init__(self, items, **kwargs) -> None:
        super().__init__(items, **kwargs)
        # Names of the list methods that mutate the sequence in place.
        self.mutable_methods = {
            "append",
            "clear",
            "extend",
            "insert",
            "pop",
            "remove",
            "reverse",
            "sort",
            "__delitem__",
            "__iadd__",
            "__imul__",
            "__setitem__",
        }

    def python_type(self):
        return torch.fx.immutable_collections.immutable_list

    def reconstruct(self, codegen: "PyCodegen") -> None:
        # Push the immutable_list constructor (with the NULL required by
        # the calling convention) ...
        def _load_immutable_list():
            codegen.extend_output(
                [
                    codegen.create_load_python_module(torch.fx.immutable_collections),
                    codegen.create_load_attr("immutable_list"),
                ]
            )

        codegen.add_push_null(_load_immutable_list)
        # ... then build the plain list on the stack ...
        super().reconstruct(codegen)
        # ... and finally call immutable_list(<list>).
        codegen.extend_output(create_call_function(1, False))

    def call_method(
        self,
        tx,
        name,
        args: list["VariableTracker"],
        kwargs: dict[str, "VariableTracker"],
    ) -> "VariableTracker":
        if name in self.mutable_methods:
            # immutable fx list raises NotImplementedError
            raise_observed_exception(NotImplementedError, tx)
        return super().call_method(tx, name, args, kwargs)
class DequeVariable(CommonListMethodsVariable):
def __init__(self, items, maxlen=None, **kwargs) -> None:
if maxlen is None:

View file

@ -26,6 +26,7 @@ from ..utils import (
cmp_name_to_op_mapping,
identity,
is_tensor_base_attr_getter,
list_methods,
proxy_args_kwargs,
set_example_value,
tuple_methods,
@ -219,6 +220,11 @@ class SuperVariable(VariableTracker):
and inner_fn in tuple_methods
):
return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
elif (
isinstance(self.objvar, variables.UserDefinedListVariable)
and inner_fn in list_methods
):
return self.objvar._list_vt.call_method(tx, name, args, kwargs)
elif inner_fn is object.__getattribute__:
# object.__getattribute__ has no side-effects. We can directly call
# __getattribute__ to access the attribute.

View file

@ -53,6 +53,7 @@ from ..utils import (
is_utils_checkpoint,
is_wrapper_or_member_descriptor,
istype,
list_methods,
namedtuple_fields,
object_has_getattribute,
proxy_args_kwargs,
@ -156,6 +157,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
object.__new__,
dict.__new__,
tuple.__new__,
list.__new__,
}
@staticmethod
@ -455,18 +457,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
elif self.value is torch.cuda.device and not kwargs and len(args) == 1:
assert args[0].is_python_constant()
return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant())
elif (
self.value is torch.fx.immutable_collections.immutable_list
and len(args) == 1
and isinstance(args[0], variables.ListVariable)
):
arg = args[0]
if arg.source:
install_guard(arg.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
return variables.FxImmutableListVariable(
list(arg.unpack_var_sequence(tx)),
mutation_type=ValueMutationNew(),
)
elif (
issubclass(type(self.value), type)
and hasattr(
@ -1484,7 +1474,10 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
) -> "VariableTracker":
method = self._maybe_get_baseclass_method(name)
if method in self._dict_methods:
return self._dict_vt.call_method(tx, name, args, kwargs)
out = self._dict_vt.call_method(tx, name, args, kwargs)
if tx.output.side_effects.is_modified(self._dict_vt):
self.mutation_type.is_modified = True
return out
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx):
@ -1496,6 +1489,49 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
raise NotImplementedError
class UserDefinedListVariable(UserDefinedObjectVariable):
    """
    Represents user defined objects that are subclasses of lists.

    Internally, it uses a ListVariable to represent the list part of the
    variable tracker. For everything else, it falls back to
    UserDefinedObjectVariable.
    """

    _nonvar_fields = UserDefinedObjectVariable._nonvar_fields

    def __init__(self, value, list_vt=None, **kwargs):
        super().__init__(value, **kwargs)
        if list_vt is not None:
            self._list_vt = list_vt
        else:
            # Sourced objects must have had their list_vt built eagerly so
            # that guards are installed on the underlying list data.
            assert (
                self.source is None
            ), "list_vt must be constructed by builder.py when source is present"
            self._list_vt = variables.ListVariable([], mutation_type=ValueMutationNew())

    def call_method(
        self,
        tx,
        name,
        args: "list[VariableTracker]",
        kwargs: "dict[str, VariableTracker]",
    ) -> "VariableTracker":
        assert self._list_vt is not None
        resolved = self._maybe_get_baseclass_method(name)
        if resolved not in list_methods:
            # Not an inherited list method — treat as a plain user object.
            return super().call_method(tx, name, args, kwargs)
        result = self._list_vt.call_method(tx, name, args, kwargs)
        # Propagate mutation of the inner list so the side-effect pass
        # knows this object needs its list contents replayed.
        if tx.output.side_effects.is_modified(self._list_vt):
            self.mutation_type.is_modified = True
        return result

    def unpack_var_sequence(self, tx):
        assert self._list_vt is not None
        # Only safe when iteration was not overridden by the subclass.
        if type(self.value).__iter__ is not list.__iter__:
            raise NotImplementedError
        return self._list_vt.unpack_var_sequence(tx)
class UserDefinedTupleVariable(UserDefinedObjectVariable):
"""
Represents user defined objects that are subclasses of tuple.