diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index a246de3a9e5..e49903bac71 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -1810,8 +1810,8 @@ def forward(self, L_x_ : torch.Tensor): getitem_4 = map_impl[3] getitem_5 = map_impl[4] getitem_6 = map_impl[5] - value = map_impl[6]; map_impl = None - return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""", + getitem_7 = map_impl[6]; map_impl = None + return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""", ) self.assertExpectedInline( body_graph, @@ -2632,8 +2632,8 @@ class GraphModule(torch.nn.Module): wrap_body_0 = self.wrap_body_0 wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None - value: "f32[3]" = wrap[0]; wrap = None - return (value,) + getitem: "f32[3]" = wrap[0]; wrap = None + return (getitem,) class wrap_body_0(torch.nn.Module): def forward(self, l_x_: "f32[3]"): @@ -4209,8 +4209,8 @@ class GraphModule(torch.nn.Module): child_1: "f32[5]" = child.sin() child_2: "f32[5]" = child.cos(); child = None - value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) - value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) + _unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1) + _unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1) _grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None _saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None @@ -4219,7 +4219,7 @@ class GraphModule(torch.nn.Module): _autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None - return (value, value_1, getitem) + return (_unwrap_for_grad, _unwrap_for_grad_1, getitem) """, ) diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py index 056c9a2ba04..58461626272 100644 --- a/test/functorch/test_control_flow.py +++ b/test/functorch/test_control_flow.py @@ -4088,9 +4088,9 @@ def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_py while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None getitem = while_loop[0] getitem_1 = while_loop[1] - value = while_loop[2] - value_1 = while_loop[3]; while_loop = None - return (getitem, getitem_1, value, value_1)""", # noqa: B950 + getitem_2 = while_loop[2] + getitem_3 = while_loop[3]; while_loop = None + return (getitem, getitem_1, getitem_2, getitem_3)""", # noqa: B950 ) def _wrap_with_functionalize(self, fn, func_type): diff --git a/torch/_dynamo/polyfills/__init__.py b/torch/_dynamo/polyfills/__init__.py index 17c733dec08..9837e4af900 100644 --- a/torch/_dynamo/polyfills/__init__.py +++ b/torch/_dynamo/polyfills/__init__.py @@ -9,7 +9,7 @@ Python polyfills for common builtins. # mypy: allow-untyped-defs from itertools import repeat as _repeat -from typing import Any, Callable, List, MutableMapping, Sequence, TYPE_CHECKING +from typing import Any, Callable, List, Sequence, TYPE_CHECKING import torch @@ -146,30 +146,6 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs): return obj -# Used with something like dict(obj) -def construct_dict(cls, /, *args, **kwargs): - dst = cls.__new__(cls) - - if args: - src = args[0] - - # Ensure that the overridden __iter__ method is invoked - if isinstance(src, (dict, MutableMapping)): - for key in src: - # This will inline the __getitem__ of the src object - dst[key] = src[key] - else: - # likely a sequence like tuple of pairs - for key, value in src: - dst[key] = value - - if kwargs: - for key in kwargs: - dst[key] = kwargs[key] - - return dst - - def foreach_map_fn(*args): op = args[0] new_args: List[Any] = [] diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index c2cb26efae1..1db051565e0 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -9,14 +9,14 @@ import math import operator import types from collections import defaultdict, OrderedDict -from collections.abc import KeysView +from collections.abc import KeysView, MutableMapping from typing import Dict, List, TYPE_CHECKING import torch from torch import sym_float, sym_int from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from .. import config, polyfills, variables +from .. import config, variables from ..exc import ( AttributeMutationError, unimplemented, @@ -38,6 +38,7 @@ from ..utils import ( check_numpy_ndarray_args, check_unspec_or_constant_args, check_unspec_python_args, + does_not_override_dict_iter_methods, extract_fake_example_value, get_fake_value, guard_if_dyn, @@ -1025,10 +1026,6 @@ class BuiltinVariable(VariableTracker): return tx.output.side_effects.track_object_new_from_user_defined_class( args[0] ) - if self.fn is dict and name == "__new__": - assert len(args) == 1 - assert len(kwargs) == 0 - return ConstDictVariable({}, dict, mutation_type=ValueMutationNew()) if self.fn is dict and name == "fromkeys": return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs) return super().call_method(tx, name, args, kwargs) @@ -1373,11 +1370,73 @@ class BuiltinVariable(VariableTracker): @staticmethod def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs): - return tx.inline_user_function_return( - VariableTracker.build(tx, polyfills.construct_dict), - [VariableTracker.build(tx, user_cls), *args], - kwargs, - ) + if not kwargs: + if not args: + args = ({},) + assert len(args) == 1 + arg = args[0] + if isinstance(arg, dict): + return ConstDictVariable( + arg, user_cls, mutation_type=ValueMutationNew() + ) + elif isinstance(arg, variables.ConstDictVariable): + return arg.clone( + user_cls=user_cls, source=None, mutation_type=ValueMutationNew() + ) + elif isinstance( + arg, + ( + ListVariable, + TupleVariable, + ListIteratorVariable, + variables.IteratorVariable, + ), + ): + items = dict( + x.force_unpack_var_sequence(tx) + for x in arg.force_unpack_var_sequence(tx) + ) + return ConstDictVariable( + items, user_cls, mutation_type=ValueMutationNew() + ) + elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping): + # This handles all other `MutableMapping` instances; for + # example, TensorDict which derives from MutableMapping. + # + # TODO(#142414) `hasattr(arg, 'value')` is a local workaround + # for lack of generall multiple inheritance in Dynamo. We can't + # use `isinstance(arg, MutableMappingVariable)` here because + # `arg` could be, e.g., a `UnspecializedNNModuleVariable` when + # `arg.value` has multiple inheritace. + if does_not_override_dict_iter_methods(type(arg.value)): + # In this case, `arg.value.items()` uses the default impls, + # which are implemented in C and cannot be traced, so we + # will have to manually construct the items. This is safe + # because we know they are side-effect free. + # + # Mutation tracked by Dynamo isn't reflected in `arg.value`, + # so we can't handle such cases by just calling + # `arg.value.items()` + if tx.output.side_effects.has_pending_mutation(arg): + unimplemented( + f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated" + ) + new_dict = dict(arg.value.items()) + return VariableTracker.build(tx, new_dict) + else: + func_var = arg.var_getattr(tx, "items") + if not isinstance(func_var, variables.UserFunctionVariable): + unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}") + out = tx.inline_user_function_return(func_var, args, kwargs) + if isinstance(out, ConstDictVariable): + return out + return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out) + elif not args and kwargs: + items = {ConstantVariable.create(k): v for k, v in kwargs.items()} + return variables.ConstDictVariable( + items, user_cls=user_cls, mutation_type=ValueMutationNew() + ) + unimplemented(f"{user_cls.__name__}(): {args} {kwargs}") @staticmethod def call_custom_dict_fromkeys( diff --git a/torch/_dynamo/variables/user_defined.py b/torch/_dynamo/variables/user_defined.py index fa8a8ae9eef..db20e981e07 100644 --- a/torch/_dynamo/variables/user_defined.py +++ b/torch/_dynamo/variables/user_defined.py @@ -217,7 +217,6 @@ class UserDefinedClassVariable(UserDefinedVariable): source and not inspect.ismethoddescriptor(obj) and not is_wrapper_or_member_descriptor(obj) - and obj is not dict.__new__ ): return VariableTracker.build(tx, obj, source) @@ -322,12 +321,6 @@ class UserDefinedClassVariable(UserDefinedVariable): return variables.ConstantVariable(self.value == args[0].value) elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"): return variables.ConstantVariable(self.value != args[0].value) - elif name == "__new__" and self.value is collections.OrderedDict: - assert len(args) == 1 - assert len(kwargs) == 0 - return variables.ConstDictVariable( - {}, collections.OrderedDict, mutation_type=ValueMutationNew() - ) return super().call_method(tx, name, args, kwargs) @@ -339,6 +332,7 @@ class UserDefinedClassVariable(UserDefinedVariable): ) -> "VariableTracker": from ..side_effects import SideEffects from .builder import wrap_fx_proxy + from .builtin import BuiltinVariable constant_args = check_constant_args(args, kwargs) @@ -358,10 +352,8 @@ class UserDefinedClassVariable(UserDefinedVariable): return NullContextVariable() elif self.value is collections.OrderedDict: - return tx.inline_user_function_return( - VariableTracker.build(tx, polyfills.construct_dict), - [self, *args], - kwargs, + return BuiltinVariable.call_custom_dict( + tx, collections.OrderedDict, *args, **kwargs ) elif ( self.value is collections.defaultdict @@ -1426,14 +1418,6 @@ class UserDefinedDictVariable(UserDefinedObjectVariable): return self._dict_vt.call_method(tx, name, args, kwargs) return super().call_method(tx, name, args, kwargs) - def unpack_var_sequence(self, tx): - if type(self.value).__iter__ in ( - dict.__iter__, - collections.OrderedDict.__iter__, - ): - return self._dict_vt.unpack_var_sequence(tx) - raise NotImplementedError - class MutableMappingVariable(UserDefinedObjectVariable): _nonvar_fields = UserDefinedObjectVariable._nonvar_fields