Revert "[dynamo][dicts] Consolidate dict(..) construction (#144342)"

This reverts commit a54a784b82.

Reverted https://github.com/pytorch/pytorch/pull/144342 on behalf of https://github.com/kit1980 due to breaking internal builds, see D68125388 ([comment](https://github.com/pytorch/pytorch/pull/144342#issuecomment-2597184167))
This commit is contained in:
PyTorch MergeBot 2025-01-17 00:32:09 +00:00
parent 2ea394ba29
commit 5e6e6200bf
5 changed files with 84 additions and 65 deletions

View file

@ -1810,8 +1810,8 @@ def forward(self, L_x_ : torch.Tensor):
getitem_4 = map_impl[3]
getitem_5 = map_impl[4]
getitem_6 = map_impl[5]
value = map_impl[6]; map_impl = None
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""",
getitem_7 = map_impl[6]; map_impl = None
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""",
)
self.assertExpectedInline(
body_graph,
@ -2632,8 +2632,8 @@ class GraphModule(torch.nn.Module):
wrap_body_0 = self.wrap_body_0
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
value: "f32[3]" = wrap[0]; wrap = None
return (value,)
getitem: "f32[3]" = wrap[0]; wrap = None
return (getitem,)
class wrap_body_0(torch.nn.Module):
def forward(self, l_x_: "f32[3]"):
@ -4209,8 +4209,8 @@ class GraphModule(torch.nn.Module):
child_1: "f32[5]" = child.sin()
child_2: "f32[5]" = child.cos(); child = None
value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
@ -4219,7 +4219,7 @@ class GraphModule(torch.nn.Module):
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
return (value, value_1, getitem)
return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
""",
)

View file

@ -4088,9 +4088,9 @@ def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_py
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None
getitem = while_loop[0]
getitem_1 = while_loop[1]
value = while_loop[2]
value_1 = while_loop[3]; while_loop = None
return (getitem, getitem_1, value, value_1)""", # noqa: B950
getitem_2 = while_loop[2]
getitem_3 = while_loop[3]; while_loop = None
return (getitem, getitem_1, getitem_2, getitem_3)""", # noqa: B950
)
def _wrap_with_functionalize(self, fn, func_type):

View file

@ -9,7 +9,7 @@ Python polyfills for common builtins.
# mypy: allow-untyped-defs
from itertools import repeat as _repeat
from typing import Any, Callable, List, MutableMapping, Sequence, TYPE_CHECKING
from typing import Any, Callable, List, Sequence, TYPE_CHECKING
import torch
@ -146,30 +146,6 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
return obj
# Polyfill used when Dynamo traces something like dict(obj).
def construct_dict(cls, /, *args, **kwargs):
    """Polyfill for ``dict(...)``-style construction of ``cls``.

    Creates the instance via ``cls.__new__`` (deliberately skipping
    ``__init__``) and populates it through ``__setitem__`` so that any
    overridden mapping hooks on the source object are traced by Dynamo.
    """
    new_obj = cls.__new__(cls)
    if args:
        source = args[0]
        if isinstance(source, (dict, MutableMapping)):
            # Iterate keys explicitly so an overridden __iter__ (and the
            # source's __getitem__) is invoked and inlined during tracing.
            for key in source:
                new_obj[key] = source[key]
        else:
            # Otherwise assume a sequence of (key, value) pairs.
            for key, value in source:
                new_obj[key] = value
    for key, value in kwargs.items():
        new_obj[key] = value
    return new_obj
def foreach_map_fn(*args):
op = args[0]
new_args: List[Any] = []

View file

@ -9,14 +9,14 @@ import math
import operator
import types
from collections import defaultdict, OrderedDict
from collections.abc import KeysView
from collections.abc import KeysView, MutableMapping
from typing import Dict, List, TYPE_CHECKING
import torch
from torch import sym_float, sym_int
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
from .. import config, polyfills, variables
from .. import config, variables
from ..exc import (
AttributeMutationError,
unimplemented,
@ -38,6 +38,7 @@ from ..utils import (
check_numpy_ndarray_args,
check_unspec_or_constant_args,
check_unspec_python_args,
does_not_override_dict_iter_methods,
extract_fake_example_value,
get_fake_value,
guard_if_dyn,
@ -1025,10 +1026,6 @@ class BuiltinVariable(VariableTracker):
return tx.output.side_effects.track_object_new_from_user_defined_class(
args[0]
)
if self.fn is dict and name == "__new__":
assert len(args) == 1
assert len(kwargs) == 0
return ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
if self.fn is dict and name == "fromkeys":
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
return super().call_method(tx, name, args, kwargs)
@ -1373,11 +1370,73 @@ class BuiltinVariable(VariableTracker):
@staticmethod
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
    """Build a dict-like VariableTracker for a traced ``user_cls(*args, **kwargs)``.

    Dispatches on the shape of the constructor arguments: a literal dict,
    an existing ConstDictVariable, a sequence of pairs, or a MutableMapping
    object. Falls back to ``unimplemented`` for unsupported combinations.
    NOTE(review): reconstructed from a diff rendering that interleaved the
    old and new bodies; verify against the upstream file.
    """
    if not kwargs:
        if not args:
            # dict() with no arguments constructs an empty mapping.
            args = ({},)
        assert len(args) == 1
        arg = args[0]
        if isinstance(arg, dict):
            return ConstDictVariable(
                arg, user_cls, mutation_type=ValueMutationNew()
            )
        elif isinstance(arg, variables.ConstDictVariable):
            # Copy-construct: clone and retarget the class, dropping the
            # source since this is a brand-new object.
            return arg.clone(
                user_cls=user_cls, source=None, mutation_type=ValueMutationNew()
            )
        elif isinstance(
            arg,
            (
                ListVariable,
                TupleVariable,
                ListIteratorVariable,
                variables.IteratorVariable,
            ),
        ):
            # Likely a sequence of (key, value) pairs.
            items = dict(
                x.force_unpack_var_sequence(tx)
                for x in arg.force_unpack_var_sequence(tx)
            )
            return ConstDictVariable(
                items, user_cls, mutation_type=ValueMutationNew()
            )
        elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping):
            # This handles all other `MutableMapping` instances; for
            # example, TensorDict which derives from MutableMapping.
            #
            # TODO(#142414) `hasattr(arg, 'value')` is a local workaround
            # for lack of general multiple inheritance in Dynamo. We can't
            # use `isinstance(arg, MutableMappingVariable)` here because
            # `arg` could be, e.g., a `UnspecializedNNModuleVariable` when
            # `arg.value` has multiple inheritance.
            if does_not_override_dict_iter_methods(type(arg.value)):
                # In this case, `arg.value.items()` uses the default impls,
                # which are implemented in C and cannot be traced, so we
                # will have to manually construct the items. This is safe
                # because we know they are side-effect free.
                #
                # Mutation tracked by Dynamo isn't reflected in `arg.value`,
                # so we can't handle such cases by just calling
                # `arg.value.items()`
                if tx.output.side_effects.has_pending_mutation(arg):
                    unimplemented(
                        f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
                    )
                new_dict = dict(arg.value.items())
                return VariableTracker.build(tx, new_dict)
            else:
                # Overridden iteration: inline the user-defined items().
                func_var = arg.var_getattr(tx, "items")
                if not isinstance(func_var, variables.UserFunctionVariable):
                    unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}")
                out = tx.inline_user_function_return(func_var, args, kwargs)
                if isinstance(out, ConstDictVariable):
                    return out
                # Re-dispatch on whatever items() produced.
                return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out)
    elif not args and kwargs:
        # dict(k1=v1, k2=v2, ...) — keys are compile-time constants.
        items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
        return variables.ConstDictVariable(
            items, user_cls=user_cls, mutation_type=ValueMutationNew()
        )
    unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
@staticmethod
def call_custom_dict_fromkeys(

View file

@ -217,7 +217,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
source
and not inspect.ismethoddescriptor(obj)
and not is_wrapper_or_member_descriptor(obj)
and obj is not dict.__new__
):
return VariableTracker.build(tx, obj, source)
@ -322,12 +321,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.ConstantVariable(self.value == args[0].value)
elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
return variables.ConstantVariable(self.value != args[0].value)
elif name == "__new__" and self.value is collections.OrderedDict:
assert len(args) == 1
assert len(kwargs) == 0
return variables.ConstDictVariable(
{}, collections.OrderedDict, mutation_type=ValueMutationNew()
)
return super().call_method(tx, name, args, kwargs)
@ -339,6 +332,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
) -> "VariableTracker":
from ..side_effects import SideEffects
from .builder import wrap_fx_proxy
from .builtin import BuiltinVariable
constant_args = check_constant_args(args, kwargs)
@ -358,10 +352,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
return NullContextVariable()
elif self.value is collections.OrderedDict:
return tx.inline_user_function_return(
VariableTracker.build(tx, polyfills.construct_dict),
[self, *args],
kwargs,
return BuiltinVariable.call_custom_dict(
tx, collections.OrderedDict, *args, **kwargs
)
elif (
self.value is collections.defaultdict
@ -1426,14 +1418,6 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
return self._dict_vt.call_method(tx, name, args, kwargs)
return super().call_method(tx, name, args, kwargs)
def unpack_var_sequence(self, tx):
    """Delegate unpacking to the tracked dict VT, but only when iteration
    is guaranteed to use the builtin dict/OrderedDict behavior (i.e. the
    user class does not override ``__iter__``)."""
    builtin_dict_iters = (
        dict.__iter__,
        collections.OrderedDict.__iter__,
    )
    if type(self.value).__iter__ not in builtin_dict_iters:
        raise NotImplementedError
    return self._dict_vt.unpack_var_sequence(tx)
class MutableMappingVariable(UserDefinedObjectVariable):
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields