[dynamo] Validate mutation_type and source in VariableTracker.__init__ (#141717)

As the title says. The validation also uncovered a few invalid use cases: the ones that caused errors are fixed in separate patches stacked before this one, and the rest are fixed in this patch. This patch also moves a few `.source` mutations from after construction to variable construction, to increase the coverage of the validation.

Fixes #133027.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141717
Approved by: https://github.com/jansel
ghstack dependencies: #141713, #141714, #141715, #141902, #141716
parent 0efd184685
commit ff73e2e679

6 changed files with 35 additions and 24 deletions
torch/_dynamo/variables/base.py

@@ -479,6 +479,23 @@ class VariableTracker(metaclass=VariableTrackerMeta):
         self.source = source
         self.mutation_type = mutation_type
 
+        # NOTE sometimes mutation_type is set afterwards for implementation
+        # convenience, we don't validate those cases at the moment.
+        if mutation_type is not None:
+            if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)):
+                # If this fails, it's either
+                # 1. one mistakenly passed in a source
+                # 2. `mutation_type` is incorrect
+                assert source is None
+            else:
+                assert isinstance(
+                    mutation_type, (ValueMutationExisting, AttributeMutationExisting)
+                )
+                # If this fails, it's either
+                # 1. one forgot to pass in a source
+                # 2. `mutation_type` is incorrect
+                assert source is not None
+
 
 def typestr(*objs):
     if len(objs) == 1:
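The new checks encode a single invariant: a variable whose mutation_type says it was created during tracing (ValueMutationNew / AttributeMutationNew) must not carry a source, while one that wraps a pre-existing Python object (ValueMutationExisting / AttributeMutationExisting) must carry one, because the source is what guards and reconstruction refer back to. A standalone sketch of the rule, using toy stand-ins rather than the real torch._dynamo classes:

    # Toy stand-ins; the real classes live in torch/_dynamo/variables/base.py
    # and torch/_dynamo/source.py.
    class ValueMutationNew: ...
    class AttributeMutationNew: ...
    class ValueMutationExisting: ...
    class AttributeMutationExisting: ...
    class Source: ...

    def validate(mutation_type, source):
        # Mirrors the __init__ check in the hunk above.
        if mutation_type is None:
            return  # set later for implementation convenience; not validated
        if isinstance(mutation_type, (ValueMutationNew, AttributeMutationNew)):
            # an object created inside the compiled region has no source
            assert source is None
        else:
            assert isinstance(
                mutation_type, (ValueMutationExisting, AttributeMutationExisting)
            )
            # a pre-existing object needs a source to guard on and rebuild from
            assert source is not None

    validate(ValueMutationNew(), None)           # ok: created during tracing
    validate(ValueMutationExisting(), Source())  # ok: wraps a user object
    # validate(ValueMutationNew(), Source())     # raises: the invalid pairing

The `mutation_type is None` escape hatch matches the NOTE in the hunk: some call sites still assign mutation_type after construction, and those are deliberately left unvalidated for now.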
torch/_dynamo/variables/builder.py

@@ -431,11 +431,6 @@ class VariableBuilder:
         install_guard(*[source.make_guard(guard) for guard in guards], skip=1)
         return {}
 
-    def set_source_and_track_mutable(self, value, var):
-        assert isinstance(var, VariableTracker)
-        var.source = self.source
-        return self.tx.output.side_effects.track_mutable(value, var)
-
     @classmethod
     def _type_dispatch(cls):
         return cls._type_dispatch_impl(config.trace_numpy)
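Deleting this helper is the "moves a few `.source` mutations to variable construction" part of the commit message: every former caller (the hunks below, plus the dicts.py and torch.py changes further down) now passes source to the constructor and calls side_effects.track_mutable directly, so __init__ sees source and mutation_type together and can validate them. A minimal sketch of the before/after pattern, with SomeVariable as a hypothetical stand-in for a VariableTracker subclass:

    class SomeVariable:
        # hypothetical stand-in for a VariableTracker subclass
        def __init__(self, items, source=None):
            # validation like the base.py assert can only see a source
            # that arrives here, at construction time
            self.items = items
            self.source = source

    # old pattern (removed): construct first, attach the source afterwards,
    # bypassing anything __init__ checks
    var = SomeVariable([1, 2])
    var.source = "L['x']"  # string standing in for a real Source object

    # new pattern: the source flows through __init__
    var = SomeVariable([1, 2], source="L['x']")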
@@ -607,7 +602,6 @@ class VariableBuilder:
         elif CustomizedDictVariable.is_matching_cls_hf(type(value)):
             self.install_guards(GuardBuilder.TYPE_MATCH)
             result = CustomizedDictVariable.wrap(self, value)
-            result.source = self.source
             return self.tx.output.side_effects.track_object_existing(value, result)
         elif istype(value, (dict, collections.defaultdict, collections.OrderedDict)):
             self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
@@ -671,7 +665,7 @@ class VariableBuilder:
                 result, user_cls=type(value), source=self.source
             )
 
-            return self.set_source_and_track_mutable(value, result)
+            return self.tx.output.side_effects.track_mutable(value, result)
         elif isinstance(value, torch.nn.Module):
             return self.wrap_module(value)
         elif ConstantVariable.is_literal(value):  # non-atomic literals
@@ -1137,7 +1131,7 @@ class VariableBuilder:
             )
         elif RestrictedListSubclassVariable.is_matching_cls(type(value)):
             self.install_guards(GuardBuilder.SEQUENCE_LENGTH)
-            return self.set_source_and_track_mutable(
+            return self.tx.output.side_effects.track_mutable(
                 value,
                 RestrictedListSubclassVariable(
                     [
@@ -1148,6 +1142,7 @@ class VariableBuilder:
                     ],
                     user_cls=type(value),
                     user_cls_source=AttrSource(self.source, "__class__"),
+                    source=self.source,
                 ),
             )
         elif TorchScriptObjectVariable.is_matching_cls(type(value)):
@@ -1326,9 +1321,9 @@ class VariableBuilder:
             )
             tensor_list_proxy.node.meta["grapharg"] = grapharg
 
-        result = BaseListVariable.cls_for_instance(value)(output)
+        result = BaseListVariable.cls_for_instance(value)(output, source=self.source)
         if istype(value, (list, collections.deque)):
-            return self.set_source_and_track_mutable(value, result)
+            return self.tx.output.side_effects.track_mutable(value, result)
         return result
 
     def wrap_tuple_iterator(self, value: tuple_iterator):
@@ -1339,11 +1334,8 @@ class VariableBuilder:
             )
             for i in range(tuple_iterator_len(value))
         ]
-        result = TupleIteratorVariable(
-            output, mutation_type=ValueMutationNew(), source=self.source
-        )
-
-        return self.set_source_and_track_mutable(value, result)
+        result = TupleIteratorVariable(output, source=self.source)
+        return self.tx.output.side_effects.track_mutable(value, result)
 
     def wrap_range_iterator(self, value: range_iterator):
         self.install_guards(GuardBuilder.RANGE_ITERATOR_MATCH)
@@ -1512,7 +1504,7 @@ class VariableBuilder:
         self.install_guards(GuardBuilder.CONSTANT_MATCH)
         result = ConstantVariable.create(value=value, source=self.source)
         if isinstance(value, (list, set)):
-            return self.set_source_and_track_mutable(value, result)
+            return self.tx.output.side_effects.track_mutable(value, result)
         return result
 
     def assert_not_wrapped_by_this_graph(self, value: torch.Tensor):
@@ -2403,7 +2395,7 @@ def handle_traced_output(example_value, tx, proxy, options, subclass_type, target_cls):
     elif istype(example_value, tuple):
         return TupleVariable(unpacked, **options)
     elif istype(example_value, (list, immutable_list)):
-        return ListVariable(unpacked, mutation_type=ValueMutationNew(), **options)
+        return ListVariable(unpacked, **options)
     else:
         assert example_value.__class__.__module__ == "torch.return_types" or hasattr(
             example_value, "_fields"
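Two hunks above go further than swapping in track_mutable: wrap_tuple_iterator and the traced-list branch of handle_traced_output also stop hard-coding mutation_type=ValueMutationNew(). Both construct variables for Python objects that existed before tracing and that carry (or, via **options, may carry) a source, so under the new assertion that pairing would fail; assigning the mutation type is instead left to side_effects.track_mutable. A simplified model of what track_mutable is assumed to do here (not the real implementation in torch/_dynamo/side_effects.py):

    class SideEffects:
        # simplified, assumed model for illustration only
        def __init__(self):
            self.id_to_variable = {}

        def track_mutable(self, item, variable):
            # `item` existed before tracing, so its variable gets an
            # "existing" mutation type, consistent with having a source
            variable.mutation_type = ValueMutationExisting()
            self.id_to_variable[id(item)] = variable
            return variable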
torch/_dynamo/variables/builtin.py
@@ -1375,7 +1375,9 @@ class BuiltinVariable(VariableTracker):
                 arg, user_cls, mutation_type=ValueMutationNew()
             )
         elif isinstance(arg, variables.ConstDictVariable):
-            return arg.clone(user_cls=user_cls, mutation_type=ValueMutationNew())
+            return arg.clone(
+                user_cls=user_cls, source=None, mutation_type=ValueMutationNew()
+            )
         elif isinstance(
             arg,
             (
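This is one of the invalid use cases the validation uncovered: clone copies every field that is not explicitly overridden, so the clone of a sourced dict would inherit that source while claiming ValueMutationNew(), exactly the pairing the new assert rejects. Since dict(existing_dict) builds a new object during tracing, the inherited source must be dropped explicitly. A toy model of the copy-with-overrides behavior (hypothetical; the real clone is a method on VariableTracker):

    import copy

    def clone(var, **overrides):
        # fields not named in `overrides` are inherited from `var`,
        # including `source`, which is why `source=None` has to be
        # passed explicitly when the clone is a brand-new object
        new = copy.copy(var)
        for name, value in overrides.items():
            setattr(new, name, value)
        return new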
torch/_dynamo/variables/dicts.py
@@ -845,7 +845,7 @@ class CustomizedDictVariable(ConstDictVariable):
             if val is not None:
                 key = ConstantVariable.create(key)
                 items[key] = var
-        return cls(items, user_cls)
+        return cls(items, user_cls, source=builder.source)
 
     def __init__(self, items, user_cls, **options) -> None:
         super().__init__(items, user_cls, **options)
torch/_dynamo/variables/misc.py
@@ -721,11 +721,11 @@ class AutogradFunctionVariable(VariableTracker):
 
     def call_backward(self, tx: "InstructionTranslator", args, kwargs):
         fn = self.fn_cls.backward
-        self.source = AttrSource(self.source, "backward")
         assert type(args[0].value) is torch._dynamo.external_utils.FakeBackwardCFunction
         assert isinstance(fn, types.FunctionType)
 
-        return variables.UserFunctionVariable(fn, source=self.source).call_function(
+        fn_source = AttrSource(self.source, "backward")
+        return variables.UserFunctionVariable(fn, source=fn_source).call_function(
             tx, args, kwargs
         )
 
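Another post-construction .source mutation removed. call_backward used to overwrite the AutogradFunctionVariable's own source, so every call appended another .backward and the variable no longer pointed at the original function class; deriving a local fn_source leaves self.source intact. An illustrative toy, with strings standing in for real Source/AttrSource objects:

    class FnVariable:
        # toy stand-in for AutogradFunctionVariable
        def __init__(self, source):
            self.source = source

        def call_backward_old(self):
            # old pattern: mutate self.source in place
            self.source = self.source + ".backward"
            return self.source

        def call_backward_new(self):
            # new pattern: derive a local, leave self.source alone
            return self.source + ".backward"

    v = FnVariable("L['MyFn']")
    v.call_backward_old()
    v.call_backward_old()
    print(v.source)  # L['MyFn'].backward.backward -- compounded by mutation

    w = FnVariable("L['MyFn']")
    w.call_backward_new()
    w.call_backward_new()
    print(w.source)  # L['MyFn'] -- unchanged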
torch/_dynamo/variables/torch.py
@@ -1117,6 +1117,9 @@ Either create the tensor outside the compiled region, or do not set the tensor t
                 (data.as_proxy(), placeholder.as_proxy()),
                 {},
             ),
+            # In reconstruct() we should use the original parameter. The one
+            # returned by the graph will be an alias.
+            source=placeholder.source,
         )
         assert isinstance(result, variables.TensorVariable)
         result.class_type = torch.nn.Parameter
@@ -1127,9 +1130,6 @@ Either create the tensor outside the compiled region, or do not set the tensor t
         # has_grad_fn field to False to workaround the issue.
         result.has_grad_fn = False
 
-        # In reconstruct() should use the original parameter. The one returned by the graph will be an alias.
-        result.source = placeholder.source
-
         # TODO(jansel): if the new param falls out of scope, currently it won't get freed until
         # the end of the graph. We should fix this.
         return result