[dynamo] Validate mutation_type and source in VariableTracker.__init__ (#141717)

As title, this also uncovered a few invalid use cases; the cases that
cause errors are fixed in separate patches prior to this patch, and the
rest are fixed in this patch.

This patch also moves a few `.source` mutations to variable construction,
to increase the coverage of the validation.

Fixes #133027.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141717
Approved by: https://github.com/jansel
ghstack dependencies: #141713, #141714, #141715, #141902, #141716
This commit is contained in:
Ryan Guo 2024-12-02 13:36:17 -08:00 committed by PyTorch MergeBot
parent 0efd184685
commit ff73e2e679
6 changed files with 35 additions and 24 deletions

View file

@ -479,6 +479,23 @@ class VariableTracker(metaclass=VariableTrackerMeta):
self.source = source
self.mutation_type = mutation_type
# NOTE sometimes mutation_type is set afterwards for implementation
# convenience, we don't validate those cases at the moment.
if mutation_type is not None:
if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)):
# If this fails, it's either
# 1. one mistakenly passed in a source
# 2. `mutation_type` is incorrect
assert source is None
else:
assert isinstance(
mutation_type, (ValueMutationExisting, AttributeMutationExisting)
)
# If this fails, it's either
# 1. one forgot to pass in a source
# 2. `mutation_type` is incorrect
assert source is not None
def typestr(*objs):
if len(objs) == 1:

View file

@ -431,11 +431,6 @@ class VariableBuilder:
install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
return {}
def set_source_and_track_mutable(self, value, var):
assert isinstance(var, VariableTracker)
var.source = self.source
return self.tx.output.side_effects.track_mutable(value, var)
@classmethod
def _type_dispatch(cls):
return cls._type_dispatch_impl(config.trace_numpy)
@ -607,7 +602,6 @@ class VariableBuilder:
elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
self.install_guards(GuardBuilder.TYPE_MATCH)
result = CustomizedDictVariable.wrap(self, value)
result.source = self.source
return self.tx.output.side_effects.track_object_existing(value, result)
elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
@ -671,7 +665,7 @@ class VariableBuilder:
result, user_cls=type(value), source=self.source
)
return self.set_source_and_track_mutable(value, result)
return self.tx.output.side_effects.track_mutable(value, result)
elif isinstance(value, torch.nn.Module):
return self.wrap_module(value)
elif ConstantVariable.is_literal(value): # non-atomic literals
@ -1137,7 +1131,7 @@ class VariableBuilder:
)
elif RestrictedListSubclassVariable.is_matching_cls(type(value)):
self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
return self.set_source_and_track_mutable(
return self.tx.output.side_effects.track_mutable(
value,
RestrictedListSubclassVariable(
[
@ -1148,6 +1142,7 @@ class VariableBuilder:
],
user_cls=type(value),
user_cls_source=AttrSource(self.source, "__class__"),
source=self.source,
),
)
elif TorchScriptObjectVariable.is_matching_cls(type(value)):
@ -1326,9 +1321,9 @@ class VariableBuilder:
)
tensor_list_proxy.node.meta["grapharg"] = grapharg
result = BaseListVariable.cls_for_instance(value)(output)
result = BaseListVariable.cls_for_instance(value)(output, source=self.source)
if istype(value, (list, collections.deque)):
return self.set_source_and_track_mutable(value, result)
return self.tx.output.side_effects.track_mutable(value, result)
return result
def wrap_tuple_iterator(self, value: tuple_iterator):
@ -1339,11 +1334,8 @@ class VariableBuilder:
)
for i in range(tuple_iterator_len(value))
]
result = TupleIteratorVariable(
output, mutation_type=ValueMutationNew(), source=self.source
)
return self.set_source_and_track_mutable(value, result)
result = TupleIteratorVariable(output, source=self.source)
return self.tx.output.side_effects.track_mutable(value, result)
def wrap_range_iterator(self, value: range_iterator):
self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)
@ -1512,7 +1504,7 @@ class VariableBuilder:
self.install_guards(GuardBuilder.CONSTANT_MATCH)
result = ConstantVariable.create(value=value, source=self.source)
if isinstance(value, (list, set)):
return self.set_source_and_track_mutable(value, result)
return self.tx.output.side_effects.track_mutable(value, result)
return result
def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
@ -2403,7 +2395,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, targe
elif istype(example_value, tuple):
return TupleVariable(unpacked, **options)
elif istype(example_value, (list, immutable_list)):
return ListVariable(unpacked, mutation_type=ValueMutationNew(), **options)
return ListVariable(unpacked, **options)
else:
assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
example_value, "_fields"

View file

@ -1375,7 +1375,9 @@ class BuiltinVariable(VariableTracker):
arg, user_cls, mutation_type=ValueMutationNew()
)
elif isinstance(arg, variables.ConstDictVariable):
return arg.clone(user_cls=user_cls, mutation_type=ValueMutationNew())
return arg.clone(
user_cls=user_cls, source=None, mutation_type=ValueMutationNew()
)
elif isinstance(
arg,
(

View file

@ -845,7 +845,7 @@ class CustomizedDictVariable(ConstDictVariable):
if val is not None:
key = ConstantVariable.create(key)
items[key] = var
return cls(items, user_cls)
return cls(items, user_cls, source=builder.source)
def __init__(self, items, user_cls, **options) -> None:
super().__init__(items, user_cls, **options)

View file

@ -721,11 +721,11 @@ class AutogradFunctionVariable(VariableTracker):
def call_backward(self, tx: "InstructionTranslator", args, kwargs):
fn = self.fn_cls.backward
self.source = AttrSource(self.source, "backward")
assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
assert isinstance(fn, types.FunctionType)
return variables.UserFunctionVariable(fn, source=self.source).call_function(
fn_source = AttrSource(self.source, "backward")
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)

View file

@ -1117,6 +1117,9 @@ Either create the tensor outside the compiled region, or do not set the tensor t
(data.as_proxy(), placeholder.as_proxy()),
{},
),
# In reconstruct() we should use the original parameter. The one
# returned by the graph will be an alias.
source=placeholder.source,
)
assert isinstance(result, variables.TensorVariable)
result.class_type = torch.nn.Parameter
@ -1127,9 +1130,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
# has_grad_fn field to False to workaround the issue.
result.has_grad_fn = False
# In reconstruct() should use the original parameter. The one returned by the graph will be an alias.
result.source = placeholder.source
# TODO(jansel): if the new param falls out of scope, currently it won't get freed until
# the end of the graph. We should fix this.
return result