mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo][lists] Support list subclasses
[ghstack-poisoned]
This commit is contained in:
parent
97b36c6513
commit
7270a4a6de
11 changed files with 193 additions and 68 deletions
|
|
@ -4499,6 +4499,60 @@ class DefaultsTests(torch._dynamo.test_case.TestCase):
|
|||
self.assertEqual(ref, res)
|
||||
self.assertTrue(isinstance(res, tuple))
|
||||
|
||||
def test_udf_list(self):
|
||||
class MyList(list): # noqa: SLOT001
|
||||
def len_mulitply_2(self):
|
||||
return len(self) * 2
|
||||
|
||||
def __contains__(self, val):
|
||||
# Ensure that overridden method is traced
|
||||
self.checked = True
|
||||
return super().__contains__(val)
|
||||
|
||||
def fn(x, lst):
|
||||
if 3 in lst:
|
||||
x = torch.cos(x)
|
||||
else:
|
||||
x = torch.sin(x)
|
||||
return x * lst.len_mulitply_2()
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref_lst = MyList([1, 2, 3])
|
||||
ref = fn(x, ref_lst)
|
||||
res_lst = MyList([1, 2, 3])
|
||||
res = opt_fn(x, res_lst)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertTrue(ref_lst.checked)
|
||||
self.assertTrue(res_lst.checked)
|
||||
|
||||
def test_udf_list_reconstruction(self):
|
||||
class MyList(list): # noqa: SLOT001
|
||||
# def __new__(cls, *args, **kwargs):
|
||||
# return super().__new__(cls, *args, **kwargs)
|
||||
pass
|
||||
|
||||
def fn(x, klass):
|
||||
x = x * 2
|
||||
sc_list = list.__new__(klass)
|
||||
sc_list.append(x)
|
||||
if isinstance(sc_list, MyList):
|
||||
sc_list.attr = 3
|
||||
return sc_list
|
||||
|
||||
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
||||
x = torch.randn(4)
|
||||
ref = fn(x, MyList)
|
||||
res = opt_fn(x, MyList)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertTrue(isinstance(res, MyList))
|
||||
self.assertEqual(ref.attr, res.attr)
|
||||
|
||||
ref = fn(x, list)
|
||||
res = opt_fn(x, list)
|
||||
self.assertEqual(ref, res)
|
||||
self.assertTrue(isinstance(res, list))
|
||||
|
||||
def test_sys_recursionlimit(self):
|
||||
def fn(x):
|
||||
return x.sin() * sys.getrecursionlimit()
|
||||
|
|
|
|||
|
|
@ -48,6 +48,11 @@ def _manual_dict_setitem(dict_from, dict_to, mro_index):
|
|||
dict_class.__setitem__(dict_to, k, v)
|
||||
|
||||
|
||||
def _manual_list_setitem(list_from, list_to):
|
||||
list.clear(list_to)
|
||||
list.extend(list_to, list_from)
|
||||
|
||||
|
||||
class SideEffects:
|
||||
"""
|
||||
Track side effects (list mutation, setattr, etc) that need to be
|
||||
|
|
@ -213,6 +218,8 @@ class SideEffects:
|
|||
dict.__getattribute__,
|
||||
int.__getattribute__,
|
||||
str.__getattribute__,
|
||||
list.__getattribute__,
|
||||
tuple.__getattribute__,
|
||||
)
|
||||
|
||||
def is_attribute_mutation(self, item):
|
||||
|
|
@ -233,8 +240,8 @@ class SideEffects:
|
|||
return False
|
||||
if isinstance(item.mutation_type, (AttributeMutationNew, ValueMutationNew)):
|
||||
return True
|
||||
if self.is_attribute_mutation(item):
|
||||
return item in self.store_attr_mutations
|
||||
if self.is_attribute_mutation(item) and item in self.store_attr_mutations:
|
||||
return True
|
||||
return item.mutation_type.is_modified
|
||||
|
||||
def _track_obj(
|
||||
|
|
@ -318,6 +325,8 @@ class SideEffects:
|
|||
variable_cls = variables.UserDefinedDictVariable
|
||||
elif issubclass(user_cls, tuple):
|
||||
variable_cls = variables.UserDefinedTupleVariable
|
||||
elif issubclass(user_cls, list):
|
||||
variable_cls = variables.UserDefinedListVariable
|
||||
elif issubclass(user_cls, MutableMapping):
|
||||
variable_cls = variables.MutableMappingVariable
|
||||
elif is_frozen_dataclass(user_cls):
|
||||
|
|
@ -822,6 +831,41 @@ class SideEffects:
|
|||
create_instruction("POP_TOP"),
|
||||
]
|
||||
)
|
||||
elif isinstance(var, variables.UserDefinedListVariable):
|
||||
# Update the list to the updated items. Be careful in
|
||||
# calling the list methods and not the overridden methods.
|
||||
varname_map = {}
|
||||
for name in _manual_list_setitem.__code__.co_varnames:
|
||||
varname_map[name] = cg.tx.output.new_var()
|
||||
|
||||
cg(var.source) # type: ignore[attr-defined]
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=varname_map["list_to"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
cg(var._list_vt, allow_cache=False) # Don't codegen via source
|
||||
cg.extend_output(
|
||||
[
|
||||
create_instruction(
|
||||
"STORE_FAST", argval=varname_map["list_from"]
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
list_update_insts = bytecode_from_template(
|
||||
_manual_list_setitem, varname_map=varname_map
|
||||
)
|
||||
|
||||
suffixes.append(
|
||||
[
|
||||
*list_update_insts,
|
||||
create_instruction("POP_TOP"),
|
||||
]
|
||||
)
|
||||
|
||||
# Applying mutations involves two steps: 1) Push all
|
||||
# reconstructed objects onto the stack. 2) Call STORE_ATTR to
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ manual_torch_name_rule_map = {
|
|||
# torch.fx map utils
|
||||
"torch.fx.node.map_aggregate": UserFunctionVariable,
|
||||
"torch.fx.node.map_arg": UserFunctionVariable,
|
||||
"torch.fx.immutable_collections._no_mutation": UserFunctionVariable,
|
||||
# symbol operators implemented in Python
|
||||
"torch.sym_not": TorchInGraphFunctionVariable,
|
||||
"torch.sym_float": TorchInGraphFunctionVariable,
|
||||
|
|
|
|||
|
|
@ -2356,6 +2356,7 @@ dict_methods = {
|
|||
|
||||
tuple_new = tuple.__new__
|
||||
tuple_methods = {method for method in tuple.__dict__.values() if callable(method)}
|
||||
list_methods = {method for method in list.__dict__.values() if callable(method)}
|
||||
|
||||
|
||||
def builtin_dict_keys(d):
|
||||
|
|
|
|||
|
|
@ -66,7 +66,6 @@ from .iter import (
|
|||
from .lazy import LazyVariableTracker
|
||||
from .lists import (
|
||||
BaseListVariable,
|
||||
FxImmutableListVariable,
|
||||
ListIteratorVariable,
|
||||
ListVariable,
|
||||
NamedTupleVariable,
|
||||
|
|
@ -121,6 +120,7 @@ from .user_defined import (
|
|||
RemovableHandleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedListVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedTupleVariable,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -127,6 +127,7 @@ class AttributeMutation(MutationType):
|
|||
|
||||
def __init__(self, typ: SourceType):
|
||||
super().__init__(typ)
|
||||
self.is_modified = False
|
||||
|
||||
|
||||
class AttributeMutationExisting(AttributeMutation):
|
||||
|
|
|
|||
|
|
@ -225,6 +225,7 @@ from .user_defined import (
|
|||
SourcelessGraphModuleVariable,
|
||||
UserDefinedClassVariable,
|
||||
UserDefinedDictVariable,
|
||||
UserDefinedListVariable,
|
||||
UserDefinedObjectVariable,
|
||||
UserDefinedTupleVariable,
|
||||
)
|
||||
|
|
@ -1268,6 +1269,22 @@ class VariableBuilder:
|
|||
value, tuple_vt=tuple_vt, source=self.source
|
||||
)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif isinstance(value, list):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
|
||||
|
||||
# NB - Be careful in not triggering user code. Guards also work on
|
||||
# the underlying list data structure.
|
||||
output = [
|
||||
LazyVariableTracker.create(
|
||||
list.__getitem__(value, i),
|
||||
source=GetItemSource(self.get_source(), i),
|
||||
)
|
||||
for i in range(list.__len__(value))
|
||||
]
|
||||
list_vt = ListVariable(output, mutation_type=ValueMutationNew())
|
||||
result = UserDefinedListVariable(value, list_vt=list_vt, source=self.source)
|
||||
return self.tx.output.side_effects.track_object_existing(value, result)
|
||||
elif issubclass(type(value), MutableMapping):
|
||||
self.install_guards(GuardBuilder.TYPE_MATCH)
|
||||
return MutableMappingVariable(value, source=self.source)
|
||||
|
|
|
|||
|
|
@ -1136,6 +1136,16 @@ class BuiltinVariable(VariableTracker):
|
|||
result.set_underlying_tuple_vt(tuple_vt)
|
||||
return result
|
||||
|
||||
if self.fn is list:
|
||||
list_vt = ListVariable([], mutation_type=ValueMutationNew())
|
||||
if isinstance(args[0], BuiltinVariable) and args[0].fn is list:
|
||||
return list_vt
|
||||
return tx.output.side_effects.track_new_user_defined_object(
|
||||
self,
|
||||
args[0],
|
||||
args[1:],
|
||||
)
|
||||
|
||||
if self.fn is object and name == "__init__":
|
||||
# object.__init__ is a no-op
|
||||
return variables.ConstantVariable(None)
|
||||
|
|
|
|||
|
|
@ -477,6 +477,13 @@ class ListVariable(CommonListMethodsVariable):
|
|||
self.items[:] = [x for x, *_ in sorted_items_with_keys]
|
||||
return ConstantVariable.create(None)
|
||||
|
||||
if name == "__init__" and self.is_mutable():
|
||||
assert not kwargs
|
||||
if len(args) == 1 and args[0].has_force_unpack_var_sequence(tx):
|
||||
(arg,) = args
|
||||
tx.output.side_effects.mutation(self)
|
||||
self.items[:] = arg.force_unpack_var_sequence(tx)
|
||||
return ConstantVariable.create(None)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def var_getattr(self, tx, name):
|
||||
|
|
@ -497,58 +504,6 @@ class ListVariable(CommonListMethodsVariable):
|
|||
return variables.ConstantVariable.create(hasattr([], name))
|
||||
|
||||
|
||||
class FxImmutableListVariable(ListVariable):
|
||||
def __init__(self, items, **kwargs) -> None:
|
||||
super().__init__(items, **kwargs)
|
||||
self.mutable_methods = {
|
||||
"__delitem__",
|
||||
"__iadd__",
|
||||
"__imul__",
|
||||
"__setitem__",
|
||||
"append",
|
||||
"clear",
|
||||
"extend",
|
||||
"insert",
|
||||
"pop",
|
||||
"remove",
|
||||
"reverse",
|
||||
"sort",
|
||||
}
|
||||
|
||||
def python_type(self):
|
||||
return torch.fx.immutable_collections.immutable_list
|
||||
|
||||
def reconstruct(self, codegen: "PyCodegen") -> None:
|
||||
# load torch.fx.immutable_collections.immutable_list
|
||||
codegen.add_push_null(
|
||||
lambda: codegen.extend_output(
|
||||
[
|
||||
codegen.create_load_python_module(torch.fx.immutable_collections),
|
||||
codegen.create_load_attr("immutable_list"),
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Construct the list
|
||||
super().reconstruct(codegen)
|
||||
|
||||
# Construct the immutable_list
|
||||
codegen.extend_output(create_call_function(1, False))
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: list["VariableTracker"],
|
||||
kwargs: dict[str, "VariableTracker"],
|
||||
) -> "VariableTracker":
|
||||
if name in self.mutable_methods:
|
||||
# immutable fx list raises NotImplementedError
|
||||
raise_observed_exception(NotImplementedError, tx)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
||||
class DequeVariable(CommonListMethodsVariable):
|
||||
def __init__(self, items, maxlen=None, **kwargs) -> None:
|
||||
if maxlen is None:
|
||||
|
|
|
|||
|
|
@ -26,6 +26,7 @@ from ..utils import (
|
|||
cmp_name_to_op_mapping,
|
||||
identity,
|
||||
is_tensor_base_attr_getter,
|
||||
list_methods,
|
||||
proxy_args_kwargs,
|
||||
set_example_value,
|
||||
tuple_methods,
|
||||
|
|
@ -219,6 +220,11 @@ class SuperVariable(VariableTracker):
|
|||
and inner_fn in tuple_methods
|
||||
):
|
||||
return self.objvar._tuple_vt.call_method(tx, name, args, kwargs)
|
||||
elif (
|
||||
isinstance(self.objvar, variables.UserDefinedListVariable)
|
||||
and inner_fn in list_methods
|
||||
):
|
||||
return self.objvar._list_vt.call_method(tx, name, args, kwargs)
|
||||
elif inner_fn is object.__getattribute__:
|
||||
# object.__getattribute__ has no side-effects. We can directly call
|
||||
# __getattribute__ to access the attribute.
|
||||
|
|
|
|||
|
|
@ -53,6 +53,7 @@ from ..utils import (
|
|||
is_utils_checkpoint,
|
||||
is_wrapper_or_member_descriptor,
|
||||
istype,
|
||||
list_methods,
|
||||
namedtuple_fields,
|
||||
object_has_getattribute,
|
||||
proxy_args_kwargs,
|
||||
|
|
@ -156,6 +157,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
object.__new__,
|
||||
dict.__new__,
|
||||
tuple.__new__,
|
||||
list.__new__,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -455,18 +457,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
elif self.value is torch.cuda.device and not kwargs and len(args) == 1:
|
||||
assert args[0].is_python_constant()
|
||||
return variables.CUDADeviceVariable.create(tx, args[0].as_python_constant())
|
||||
elif (
|
||||
self.value is torch.fx.immutable_collections.immutable_list
|
||||
and len(args) == 1
|
||||
and isinstance(args[0], variables.ListVariable)
|
||||
):
|
||||
arg = args[0]
|
||||
if arg.source:
|
||||
install_guard(arg.source.make_guard(GuardBuilder.SEQUENCE_LENGTH))
|
||||
return variables.FxImmutableListVariable(
|
||||
list(arg.unpack_var_sequence(tx)),
|
||||
mutation_type=ValueMutationNew(),
|
||||
)
|
||||
elif (
|
||||
issubclass(type(self.value), type)
|
||||
and hasattr(
|
||||
|
|
@ -1484,7 +1474,10 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
|
|||
) -> "VariableTracker":
|
||||
method = self._maybe_get_baseclass_method(name)
|
||||
if method in self._dict_methods:
|
||||
return self._dict_vt.call_method(tx, name, args, kwargs)
|
||||
out = self._dict_vt.call_method(tx, name, args, kwargs)
|
||||
if tx.output.side_effects.is_modified(self._dict_vt):
|
||||
self.mutation_type.is_modified = True
|
||||
return out
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
|
|
@ -1496,6 +1489,49 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class UserDefinedListVariable(UserDefinedObjectVariable):
|
||||
"""
|
||||
Represents user defined objects that are subclasses of lists.
|
||||
|
||||
Internally, it uses a ListVariable to represent the list part of the
|
||||
variable tracker. For everything else, it falls back to
|
||||
UserDefinedObjectVariable.
|
||||
"""
|
||||
|
||||
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
|
||||
|
||||
def __init__(self, value, list_vt=None, **kwargs):
|
||||
super().__init__(value, **kwargs)
|
||||
self._list_vt = list_vt
|
||||
if self._list_vt is None:
|
||||
assert (
|
||||
self.source is None
|
||||
), "list_vt must be constructed by builder.py when source is present"
|
||||
self._list_vt = variables.ListVariable([], mutation_type=ValueMutationNew())
|
||||
|
||||
def call_method(
|
||||
self,
|
||||
tx,
|
||||
name,
|
||||
args: "list[VariableTracker]",
|
||||
kwargs: "dict[str, VariableTracker]",
|
||||
) -> "VariableTracker":
|
||||
assert self._list_vt is not None
|
||||
method = self._maybe_get_baseclass_method(name)
|
||||
if method in list_methods:
|
||||
out = self._list_vt.call_method(tx, name, args, kwargs)
|
||||
if tx.output.side_effects.is_modified(self._list_vt):
|
||||
self.mutation_type.is_modified = True
|
||||
return out
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
assert self._list_vt is not None
|
||||
if type(self.value).__iter__ is list.__iter__:
|
||||
return self._list_vt.unpack_var_sequence(tx)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class UserDefinedTupleVariable(UserDefinedObjectVariable):
|
||||
"""
|
||||
Represents user defined objects that are subclasses of tuple.
|
||||
|
|
|
|||
Loading…
Reference in a new issue