[dynamo] support inactive context managers across graph breaks (#125203)

Fix https://github.com/pytorch/pytorch/issues/124900.

When we reconstruct `ContextWrappingVariables`s, we only reconstruct the context class, not the object. Normally, contexts are active (via `with ctx:`) and we initialize the context object in the resume function. But for the case of inactive contexts (contexts declared ahead of time before the `with` block), we do not reconstruct them properly in the optimized bytecode or resume function. So this PR adds initialization for inactive contexts in the resume function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125203
Approved by: https://github.com/jansel
This commit is contained in:
William Wen 2024-04-30 10:58:38 -07:00 committed by PyTorch MergeBot
parent 1b9d353e4f
commit 0506e95433
4 changed files with 73 additions and 2 deletions

View file

@ -1304,6 +1304,22 @@ class GraphModule(torch.nn.Module):
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
def test_inactive_context_graph_break(self):
def fn(x):
x = x + 1
ctx = torch.set_grad_enabled(True)
torch._dynamo.graph_break()
with ctx:
x = x + 1
return x
x = torch.zeros(10, requires_grad=False)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
self.assertEqual(cnts.frame_count, 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View file

@ -317,6 +317,17 @@ def _filter_iter(l1, l2, cond):
return res
def _load_tuple_and_call(tup):
insts = []
if sys.version_info >= (3, 11):
insts.append(create_instruction("PUSH_NULL"))
insts.append(create_instruction("SWAP", arg=2))
for val in tup:
insts.append(create_instruction("LOAD_CONST", argval=val))
insts.extend(create_call_function(len(tup), False))
return insts
class ContinueExecutionCache:
cache = ExactWeakKeyDictionary()
generated_code_metadata = ExactWeakKeyDictionary()
@ -341,6 +352,8 @@ class ContinueExecutionCache:
argnames: Tuple[str],
argnames_null: Tuple[str],
setup_fns: Tuple[ReenterWith],
stack_ctx_vars: Tuple[int, Tuple[Any]],
argnames_ctx_vars: Tuple[str, Tuple[Any]],
null_idxes: Tuple[int],
) -> types.CodeType:
assert offset is not None
@ -359,6 +372,8 @@ class ContinueExecutionCache:
argnames,
argnames_null,
setup_fns,
stack_ctx_vars,
argnames_ctx_vars,
null_idxes,
)
@ -420,6 +435,7 @@ class ContinueExecutionCache:
# map old hook targets to new targets generated by the hook
old_hook_target_remap = {}
null_idxes_i = 0
stack_ctx_vars_d = dict(stack_ctx_vars) # type: ignore[var-annotated,arg-type]
for i in range(nstack):
while (
null_idxes_i < len(null_idxes)
@ -437,6 +453,12 @@ class ContinueExecutionCache:
old_hook_target = offset_to_inst[hook_target_offset]
meta.prefix_block_target_offset_remap.append(hook_target_offset)
old_hook_target_remap[old_hook_target] = exn_target
real_i = i + null_idxes_i
if real_i in stack_ctx_vars_d:
# current stack variable is a context var -
# load args for context variable and construct it
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i]))
if is_py311_plus:
# reverse the mapping since targets of later/nested contexts are inserted
# into the mapping later, but show up earlier in the prefix.
@ -446,6 +468,12 @@ class ContinueExecutionCache:
assert not hooks
# initialize inactive context vars in argnames
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))
prefix.extend(_load_tuple_and_call(vals))
prefix.append(create_instruction("STORE_FAST", argval=name))
# 3.12+: store NULL into variables that were NULL
if argnames_null:
assert sys.version_info >= (3, 12)

View file

@ -2282,6 +2282,23 @@ class InstructionTranslator(InstructionTranslatorBase):
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# Handle inactive context variables - inactive context variables
# are reconstructed to be the class, NOT the object.
# So the resume function needs to construct the context object
# from the class and the context object's target values.
# e.g. torch.set_grad_enabled(True) will be reconstructed as
# torch.set_grad_enabled
stack_ctx_vars = []
for i, var in enumerate(self.stack):
if type.__instancecheck__(ContextWrappingVariable, var):
stack_ctx_vars.append((i, tuple(var.target_values))) # type: ignore[attr-defined]
argnames_ctx_vars = []
for name in argnames:
if type.__instancecheck__(
ContextWrappingVariable, var := self.symbolic_locals[name]
):
argnames_ctx_vars.append((name, tuple(var.target_values))) # type: ignore[attr-defined]
cg = PyCodegen(self)
# Python does not allow null to be an arg to a function, so
@ -2293,12 +2310,12 @@ class InstructionTranslator(InstructionTranslatorBase):
if sys.version_info >= (3, 11):
# find indices of NullVariables
for i, var in enumerate(self.stack):
if isinstance(var, NullVariable):
if type.__instancecheck__(NullVariable, var):
null_idxes.append(i)
# generate bytecode to pop the nulls
null_cnt = 0
for i, var in enumerate(reversed(self.stack)):
if isinstance(var, NullVariable):
if type.__instancecheck__(NullVariable, var):
for j in range(2, i + 2 - null_cnt):
cg.append_output(create_instruction("SWAP", arg=j))
cg.extend_output(cg.pop_null())
@ -2320,6 +2337,8 @@ class InstructionTranslator(InstructionTranslatorBase):
argnames,
argnames_null,
tuple(b.resume_fn() for b in self.block_stack),
tuple(stack_ctx_vars),
tuple(argnames_ctx_vars),
tuple(null_idxes),
)

View file

@ -19,9 +19,17 @@ class LazyCache:
assert self.vt is None
from ..symbolic_convert import InstructionTranslator
from .builder import VariableBuilder
from .ctx_manager import ContextWrappingVariable, NullContextVariable
from .misc import NullVariable
tx = InstructionTranslator.current_tx()
self.vt = VariableBuilder(tx, self.source)(self.value)
# we do not expect wrapping these variables in lazy VTs
assert not isinstance(
self.vt, (NullVariable, ContextWrappingVariable)
) or isinstance(self.vt, NullContextVariable)
del self.value
del self.source