mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Revert "[dynamo][dicts] Consolidate dict(..) construction (#144342)"
This reverts commit a54a784b82.
Reverted https://github.com/pytorch/pytorch/pull/144342 on behalf of https://github.com/kit1980 due to breaking internal builds, see D68125388 ([comment](https://github.com/pytorch/pytorch/pull/144342#issuecomment-2597184167))
This commit is contained in:
parent
2ea394ba29
commit
5e6e6200bf
5 changed files with 84 additions and 65 deletions
|
|
@ -1810,8 +1810,8 @@ def forward(self, L_x_ : torch.Tensor):
|
|||
getitem_4 = map_impl[3]
|
||||
getitem_5 = map_impl[4]
|
||||
getitem_6 = map_impl[5]
|
||||
value = map_impl[6]; map_impl = None
|
||||
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, value)""",
|
||||
getitem_7 = map_impl[6]; map_impl = None
|
||||
return (getitem_1, getitem_2, getitem_3, getitem_4, getitem_5, getitem_6, getitem_7)""",
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
body_graph,
|
||||
|
|
@ -2632,8 +2632,8 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
wrap_body_0 = self.wrap_body_0
|
||||
wrap = torch.ops.higher_order.wrap(wrap_body_0, l_x_); wrap_body_0 = l_x_ = None
|
||||
value: "f32[3]" = wrap[0]; wrap = None
|
||||
return (value,)
|
||||
getitem: "f32[3]" = wrap[0]; wrap = None
|
||||
return (getitem,)
|
||||
|
||||
class wrap_body_0(torch.nn.Module):
|
||||
def forward(self, l_x_: "f32[3]"):
|
||||
|
|
@ -4209,8 +4209,8 @@ class GraphModule(torch.nn.Module):
|
|||
child_1: "f32[5]" = child.sin()
|
||||
child_2: "f32[5]" = child.cos(); child = None
|
||||
|
||||
value: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
|
||||
value_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
|
||||
_unwrap_for_grad: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_1, 1)
|
||||
_unwrap_for_grad_1: "f32[5]" = torch._C._functorch._unwrap_for_grad(child_2, 1)
|
||||
|
||||
_grad_decrement_nesting = torch._C._functorch._grad_decrement_nesting(); _grad_decrement_nesting = None
|
||||
_saved_tensors_hooks_enable = torch._C._autograd._saved_tensors_hooks_enable(); _saved_tensors_hooks_enable = None
|
||||
|
|
@ -4219,7 +4219,7 @@ class GraphModule(torch.nn.Module):
|
|||
|
||||
_autograd_grad = torch._functorch.eager_transforms._autograd_grad([child_1, child_2], [child_3], [l_v_, child_4], retain_graph = True, create_graph = True); child_1 = child_2 = child_3 = l_v_ = child_4 = None
|
||||
getitem: "f32[5]" = _autograd_grad[0]; _autograd_grad = None
|
||||
return (value, value_1, getitem)
|
||||
return (_unwrap_for_grad, _unwrap_for_grad_1, getitem)
|
||||
""",
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -4088,9 +4088,9 @@ def forward(self, L_it_ : torch.Tensor, L_pytree_input_0_0_ : torch.Tensor, L_py
|
|||
while_loop = torch.ops.higher_order.while_loop(cond_fn_0, body_fn_0, (l_it_, l_pytree_input_0_0_, l_pytree_input_1_x_, l_pytree_input_1_y_), ()); cond_fn_0 = body_fn_0 = l_it_ = l_pytree_input_0_0_ = l_pytree_input_1_x_ = l_pytree_input_1_y_ = None
|
||||
getitem = while_loop[0]
|
||||
getitem_1 = while_loop[1]
|
||||
value = while_loop[2]
|
||||
value_1 = while_loop[3]; while_loop = None
|
||||
return (getitem, getitem_1, value, value_1)""", # noqa: B950
|
||||
getitem_2 = while_loop[2]
|
||||
getitem_3 = while_loop[3]; while_loop = None
|
||||
return (getitem, getitem_1, getitem_2, getitem_3)""", # noqa: B950
|
||||
)
|
||||
|
||||
def _wrap_with_functionalize(self, fn, func_type):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ Python polyfills for common builtins.
|
|||
# mypy: allow-untyped-defs
|
||||
|
||||
from itertools import repeat as _repeat
|
||||
from typing import Any, Callable, List, MutableMapping, Sequence, TYPE_CHECKING
|
||||
from typing import Any, Callable, List, Sequence, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
|
|
@ -146,30 +146,6 @@ def instantiate_user_defined_class_object(cls, /, *args, **kwargs):
|
|||
return obj
|
||||
|
||||
|
||||
# Used with something like dict(obj)
|
||||
def construct_dict(cls, /, *args, **kwargs):
|
||||
dst = cls.__new__(cls)
|
||||
|
||||
if args:
|
||||
src = args[0]
|
||||
|
||||
# Ensure that the overridden __iter__ method is invoked
|
||||
if isinstance(src, (dict, MutableMapping)):
|
||||
for key in src:
|
||||
# This will inline the __getitem__ of the src object
|
||||
dst[key] = src[key]
|
||||
else:
|
||||
# likely a sequence like tuple of pairs
|
||||
for key, value in src:
|
||||
dst[key] = value
|
||||
|
||||
if kwargs:
|
||||
for key in kwargs:
|
||||
dst[key] = kwargs[key]
|
||||
|
||||
return dst
|
||||
|
||||
|
||||
def foreach_map_fn(*args):
|
||||
op = args[0]
|
||||
new_args: List[Any] = []
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@ import math
|
|||
import operator
|
||||
import types
|
||||
from collections import defaultdict, OrderedDict
|
||||
from collections.abc import KeysView
|
||||
from collections.abc import KeysView, MutableMapping
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch import sym_float, sym_int
|
||||
from torch.utils._python_dispatch import is_traceable_wrapper_subclass
|
||||
|
||||
from .. import config, polyfills, variables
|
||||
from .. import config, variables
|
||||
from ..exc import (
|
||||
AttributeMutationError,
|
||||
unimplemented,
|
||||
|
|
@ -38,6 +38,7 @@ from ..utils import (
|
|||
check_numpy_ndarray_args,
|
||||
check_unspec_or_constant_args,
|
||||
check_unspec_python_args,
|
||||
does_not_override_dict_iter_methods,
|
||||
extract_fake_example_value,
|
||||
get_fake_value,
|
||||
guard_if_dyn,
|
||||
|
|
@ -1025,10 +1026,6 @@ class BuiltinVariable(VariableTracker):
|
|||
return tx.output.side_effects.track_object_new_from_user_defined_class(
|
||||
args[0]
|
||||
)
|
||||
if self.fn is dict and name == "__new__":
|
||||
assert len(args) == 1
|
||||
assert len(kwargs) == 0
|
||||
return ConstDictVariable({}, dict, mutation_type=ValueMutationNew())
|
||||
if self.fn is dict and name == "fromkeys":
|
||||
return BuiltinVariable.call_custom_dict_fromkeys(tx, dict, *args, **kwargs)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
|
@ -1373,11 +1370,73 @@ class BuiltinVariable(VariableTracker):
|
|||
|
||||
@staticmethod
|
||||
def call_custom_dict(tx: "InstructionTranslator", user_cls, *args, **kwargs):
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.construct_dict),
|
||||
[VariableTracker.build(tx, user_cls), *args],
|
||||
kwargs,
|
||||
)
|
||||
if not kwargs:
|
||||
if not args:
|
||||
args = ({},)
|
||||
assert len(args) == 1
|
||||
arg = args[0]
|
||||
if isinstance(arg, dict):
|
||||
return ConstDictVariable(
|
||||
arg, user_cls, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif isinstance(arg, variables.ConstDictVariable):
|
||||
return arg.clone(
|
||||
user_cls=user_cls, source=None, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif isinstance(
|
||||
arg,
|
||||
(
|
||||
ListVariable,
|
||||
TupleVariable,
|
||||
ListIteratorVariable,
|
||||
variables.IteratorVariable,
|
||||
),
|
||||
):
|
||||
items = dict(
|
||||
x.force_unpack_var_sequence(tx)
|
||||
for x in arg.force_unpack_var_sequence(tx)
|
||||
)
|
||||
return ConstDictVariable(
|
||||
items, user_cls, mutation_type=ValueMutationNew()
|
||||
)
|
||||
elif hasattr(arg, "value") and isinstance(arg.value, MutableMapping):
|
||||
# This handles all other `MutableMapping` instances; for
|
||||
# example, TensorDict which derives from MutableMapping.
|
||||
#
|
||||
# TODO(#142414) `hasattr(arg, 'value')` is a local workaround
|
||||
# for lack of generall multiple inheritance in Dynamo. We can't
|
||||
# use `isinstance(arg, MutableMappingVariable)` here because
|
||||
# `arg` could be, e.g., a `UnspecializedNNModuleVariable` when
|
||||
# `arg.value` has multiple inheritace.
|
||||
if does_not_override_dict_iter_methods(type(arg.value)):
|
||||
# In this case, `arg.value.items()` uses the default impls,
|
||||
# which are implemented in C and cannot be traced, so we
|
||||
# will have to manually construct the items. This is safe
|
||||
# because we know they are side-effect free.
|
||||
#
|
||||
# Mutation tracked by Dynamo isn't reflected in `arg.value`,
|
||||
# so we can't handle such cases by just calling
|
||||
# `arg.value.items()`
|
||||
if tx.output.side_effects.has_pending_mutation(arg):
|
||||
unimplemented(
|
||||
f"{user_cls.__name__}.items(): {args} {kwargs} - object is mutated"
|
||||
)
|
||||
new_dict = dict(arg.value.items())
|
||||
return VariableTracker.build(tx, new_dict)
|
||||
else:
|
||||
func_var = arg.var_getattr(tx, "items")
|
||||
if not isinstance(func_var, variables.UserFunctionVariable):
|
||||
unimplemented(f"{user_cls.__name__}.items(): {args} {kwargs}")
|
||||
out = tx.inline_user_function_return(func_var, args, kwargs)
|
||||
if isinstance(out, ConstDictVariable):
|
||||
return out
|
||||
return BuiltinVariable(user_cls).call_custom_dict(tx, user_cls, out)
|
||||
elif not args and kwargs:
|
||||
items = {ConstantVariable.create(k): v for k, v in kwargs.items()}
|
||||
return variables.ConstDictVariable(
|
||||
items, user_cls=user_cls, mutation_type=ValueMutationNew()
|
||||
)
|
||||
unimplemented(f"{user_cls.__name__}(): {args} {kwargs}")
|
||||
|
||||
@staticmethod
|
||||
def call_custom_dict_fromkeys(
|
||||
|
|
|
|||
|
|
@ -217,7 +217,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
source
|
||||
and not inspect.ismethoddescriptor(obj)
|
||||
and not is_wrapper_or_member_descriptor(obj)
|
||||
and obj is not dict.__new__
|
||||
):
|
||||
return VariableTracker.build(tx, obj, source)
|
||||
|
||||
|
|
@ -322,12 +321,6 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
return variables.ConstantVariable(self.value == args[0].value)
|
||||
elif name == "__ne__" and len(args) == 1 and hasattr(args[0], "value"):
|
||||
return variables.ConstantVariable(self.value != args[0].value)
|
||||
elif name == "__new__" and self.value is collections.OrderedDict:
|
||||
assert len(args) == 1
|
||||
assert len(kwargs) == 0
|
||||
return variables.ConstDictVariable(
|
||||
{}, collections.OrderedDict, mutation_type=ValueMutationNew()
|
||||
)
|
||||
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
|
|
@ -339,6 +332,7 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
) -> "VariableTracker":
|
||||
from ..side_effects import SideEffects
|
||||
from .builder import wrap_fx_proxy
|
||||
from .builtin import BuiltinVariable
|
||||
|
||||
constant_args = check_constant_args(args, kwargs)
|
||||
|
||||
|
|
@ -358,10 +352,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
|||
|
||||
return NullContextVariable()
|
||||
elif self.value is collections.OrderedDict:
|
||||
return tx.inline_user_function_return(
|
||||
VariableTracker.build(tx, polyfills.construct_dict),
|
||||
[self, *args],
|
||||
kwargs,
|
||||
return BuiltinVariable.call_custom_dict(
|
||||
tx, collections.OrderedDict, *args, **kwargs
|
||||
)
|
||||
elif (
|
||||
self.value is collections.defaultdict
|
||||
|
|
@ -1426,14 +1418,6 @@ class UserDefinedDictVariable(UserDefinedObjectVariable):
|
|||
return self._dict_vt.call_method(tx, name, args, kwargs)
|
||||
return super().call_method(tx, name, args, kwargs)
|
||||
|
||||
def unpack_var_sequence(self, tx):
|
||||
if type(self.value).__iter__ in (
|
||||
dict.__iter__,
|
||||
collections.OrderedDict.__iter__,
|
||||
):
|
||||
return self._dict_vt.unpack_var_sequence(tx)
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class MutableMappingVariable(UserDefinedObjectVariable):
|
||||
_nonvar_fields = UserDefinedObjectVariable._nonvar_fields
|
||||
|
|
|
|||
Loading…
Reference in a new issue