From 0506e95433ef39cfa698bf3cf23669fe3e877538 Mon Sep 17 00:00:00 2001
From: William Wen
Date: Tue, 30 Apr 2024 10:58:38 -0700
Subject: [PATCH] [dynamo] support inactive context managers across graph
 breaks (#125203)

Fix https://github.com/pytorch/pytorch/issues/124900.

When we reconstruct `ContextWrappingVariable`s, we only reconstruct the
context class, not the object. Normally, contexts are active (entered via
`with ctx:`) and we initialize the context object in the resume function.
But for inactive contexts (contexts declared ahead of time, before the
`with` block), we did not reconstruct them properly in the optimized
bytecode or the resume function. So this PR adds initialization for
inactive contexts in the resume function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125203
Approved by: https://github.com/jansel
---
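Note (illustrative, not part of the upstream commit): the prologue the resume
function gains is roughly equivalent to the following Python, where `ctx` is
an inactive context variable that was live across the graph break and
`target_values` stands in for the saved constructor arguments; both names are
placeholders, and the real work is done in emitted bytecode rather than
source:

    # Before this patch, after a graph break `ctx` held the context *class*
    # (e.g. torch.set_grad_enabled), so `with ctx:` failed in the resume
    # function. The new prologue effectively rebuilds the object first:
    ctx = ctx(*target_values)  # e.g. ctx = torch.set_grad_enabled(True)
    # ...the rest of the resume function runs, eventually entering `with ctx:`
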
 test/dynamo/test_ctx_manager.py   | 16 ++++++++++++++++
 torch/_dynamo/resume_execution.py | 28 ++++++++++++++++++++++++++++
 torch/_dynamo/symbolic_convert.py | 23 +++++++++++++++++++++--
 torch/_dynamo/variables/lazy.py   |  8 ++++++++
 4 files changed, 73 insertions(+), 2 deletions(-)

diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py
index b3e286edabc..cc6e39de4d1 100644
--- a/test/dynamo/test_ctx_manager.py
+++ b/test/dynamo/test_ctx_manager.py
@@ -1304,6 +1304,22 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(fn(x), opt_fn(x))
         self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
 
+    def test_inactive_context_graph_break(self):
+        def fn(x):
+            x = x + 1
+            ctx = torch.set_grad_enabled(True)
+            torch._dynamo.graph_break()
+            with ctx:
+                x = x + 1
+            return x
+
+        x = torch.zeros(10, requires_grad=False)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
+        self.assertEqual(cnts.frame_count, 2)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py
index 545bb0f5c9f..969a679c9e9 100644
--- a/torch/_dynamo/resume_execution.py
+++ b/torch/_dynamo/resume_execution.py
@@ -317,6 +317,17 @@ def _filter_iter(l1, l2, cond):
     return res
 
 
+def _load_tuple_and_call(tup):
+    insts = []
+    if sys.version_info >= (3, 11):
+        insts.append(create_instruction("PUSH_NULL"))
+        insts.append(create_instruction("SWAP", arg=2))
+    for val in tup:
+        insts.append(create_instruction("LOAD_CONST", argval=val))
+    insts.extend(create_call_function(len(tup), False))
+    return insts
+
+
 class ContinueExecutionCache:
     cache = ExactWeakKeyDictionary()
     generated_code_metadata = ExactWeakKeyDictionary()
@@ -341,6 +352,8 @@ class ContinueExecutionCache:
         argnames: Tuple[str],
         argnames_null: Tuple[str],
         setup_fns: Tuple[ReenterWith],
+        stack_ctx_vars: Tuple[int, Tuple[Any]],
+        argnames_ctx_vars: Tuple[str, Tuple[Any]],
         null_idxes: Tuple[int],
     ) -> types.CodeType:
         assert offset is not None
@@ -359,6 +372,8 @@ class ContinueExecutionCache:
             argnames,
             argnames_null,
             setup_fns,
+            stack_ctx_vars,
+            argnames_ctx_vars,
             null_idxes,
         )
 
@@ -420,6 +435,7 @@ class ContinueExecutionCache:
         # map old hook targets to new targets generated by the hook
         old_hook_target_remap = {}
         null_idxes_i = 0
+        stack_ctx_vars_d = dict(stack_ctx_vars)  # type: ignore[var-annotated,arg-type]
         for i in range(nstack):
             while (
                 null_idxes_i < len(null_idxes)
@@ -437,6 +453,12 @@ class ContinueExecutionCache:
                 old_hook_target = offset_to_inst[hook_target_offset]
                 meta.prefix_block_target_offset_remap.append(hook_target_offset)
                 old_hook_target_remap[old_hook_target] = exn_target
+            real_i = i + null_idxes_i
+            if real_i in stack_ctx_vars_d:
+                # current stack variable is a context var -
+                # load args for context variable and construct it
+                prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i]))
+
         if is_py311_plus:
             # reverse the mapping since targets of later/nested contexts are inserted
             # into the mapping later, but show up earlier in the prefix.
@@ -446,6 +468,12 @@ class ContinueExecutionCache:
 
         assert not hooks
 
+        # initialize inactive context vars in argnames
+        for name, vals in argnames_ctx_vars:
+            prefix.append(create_instruction("LOAD_FAST", argval=name))
+            prefix.extend(_load_tuple_and_call(vals))
+            prefix.append(create_instruction("STORE_FAST", argval=name))
+
         # 3.12+: store NULL into variables that were NULL
         if argnames_null:
             assert sys.version_info >= (3, 12)
diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py
index 5e9eb41cc04..aef6d32bac0 100644
--- a/torch/_dynamo/symbolic_convert.py
+++ b/torch/_dynamo/symbolic_convert.py
@@ -2282,6 +2282,23 @@ class InstructionTranslator(InstructionTranslatorBase):
         if sys.version_info < (3, 12):
             assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
 
+        # Handle inactive context variables - inactive context variables
+        # are reconstructed to be the class, NOT the object.
+        # So the resume function needs to construct the context object
+        # from the class and the context object's target values.
+        # e.g. torch.set_grad_enabled(True) will be reconstructed as
+        # torch.set_grad_enabled
+        stack_ctx_vars = []
+        for i, var in enumerate(self.stack):
+            if type.__instancecheck__(ContextWrappingVariable, var):
+                stack_ctx_vars.append((i, tuple(var.target_values)))  # type: ignore[attr-defined]
+        argnames_ctx_vars = []
+        for name in argnames:
+            if type.__instancecheck__(
+                ContextWrappingVariable, var := self.symbolic_locals[name]
+            ):
+                argnames_ctx_vars.append((name, tuple(var.target_values)))  # type: ignore[attr-defined]
+
         cg = PyCodegen(self)
 
         # Python does not allow null to be an arg to a function, so
@@ -2293,12 +2310,12 @@ class InstructionTranslator(InstructionTranslatorBase):
         if sys.version_info >= (3, 11):
             # find indices of NullVariables
             for i, var in enumerate(self.stack):
-                if isinstance(var, NullVariable):
+                if type.__instancecheck__(NullVariable, var):
                     null_idxes.append(i)
             # generate bytecode to pop the nulls
             null_cnt = 0
             for i, var in enumerate(reversed(self.stack)):
-                if isinstance(var, NullVariable):
+                if type.__instancecheck__(NullVariable, var):
                     for j in range(2, i + 2 - null_cnt):
                         cg.append_output(create_instruction("SWAP", arg=j))
                     cg.extend_output(cg.pop_null())
@@ -2320,6 +2337,8 @@ class InstructionTranslator(InstructionTranslatorBase):
             argnames,
             argnames_null,
             tuple(b.resume_fn() for b in self.block_stack),
+            tuple(stack_ctx_vars),
+            tuple(argnames_ctx_vars),
             tuple(null_idxes),
         )
 
diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py
index c3dc781029d..4c68c7bf78a 100644
--- a/torch/_dynamo/variables/lazy.py
+++ b/torch/_dynamo/variables/lazy.py
@@ -19,9 +19,17 @@ class LazyCache:
         assert self.vt is None
         from ..symbolic_convert import InstructionTranslator
         from .builder import VariableBuilder
+        from .ctx_manager import ContextWrappingVariable, NullContextVariable
+        from .misc import NullVariable
 
         tx = InstructionTranslator.current_tx()
         self.vt = VariableBuilder(tx, self.source)(self.value)
+
+        # we do not expect wrapping these variables in lazy VTs
+        assert not isinstance(
+            self.vt, (NullVariable, ContextWrappingVariable)
+        ) or isinstance(self.vt, NullContextVariable)
+
         del self.value
         del self.source
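
Note (my illustration, not code from this patch): `_load_tuple_and_call`
pushes PUSH_NULL and SWAP 2 on 3.11+ because the new calling convention
requires a NULL below the callable before the arguments. A minimal runnable
sketch of that convention, using a hypothetical `call_via_local` function,
where `ctx_class` stands in for the reconstructed context class:

    import dis

    # CPython compiles a call through a local by pushing NULL itself,
    # then the callable, then the arguments:
    def call_via_local(ctx_class):
        return ctx_class(True)

    dis.dis(call_via_local)
    # On 3.11 this prints PUSH_NULL, LOAD_FAST ctx_class, LOAD_CONST True,
    # PRECALL 1, CALL 1. The helper in this patch handles the inverted
    # case where the callable is already on the stack, so it pushes NULL
    # and SWAPs it underneath before loading the constant arguments.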