mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] support inactive context managers across graph breaks (#125203)
Fix https://github.com/pytorch/pytorch/issues/124900. When we reconstruct `ContextWrappingVariables`s, we only reconstruct the context class, not the object. Normally, contexts are active (via `with ctx:`) and we initialize the context object in the resume function. But for the case of inactive contexts (contexts declared ahead of time before the `with` block), we do not reconstruct them properly in the optimized bytecode or resume function. So this PR adds initialization for inactive contexts in the resume function. Pull Request resolved: https://github.com/pytorch/pytorch/pull/125203 Approved by: https://github.com/jansel
This commit is contained in:
parent
1b9d353e4f
commit
0506e95433
4 changed files with 73 additions and 2 deletions
|
|
@ -1304,6 +1304,22 @@ class GraphModule(torch.nn.Module):
|
|||
self.assertEqual(fn(x), opt_fn(x))
|
||||
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
|
||||
|
||||
def test_inactive_context_graph_break(self):
    # Regression test for https://github.com/pytorch/pytorch/issues/124900:
    # a context manager object created *before* its `with` block (i.e. still
    # inactive) must survive a graph break and be reconstructed as the
    # context object, not just its class.
    def fn(x):
        x = x + 1
        grad_ctx = torch.set_grad_enabled(True)
        torch._dynamo.graph_break()
        with grad_ctx:
            x = x + 1
        return x

    inp = torch.zeros(10, requires_grad=False)
    counter = torch._dynamo.testing.CompileCounter()
    compiled_fn = torch.compile(fn, backend=counter)
    # Values and grad-requiredness must match eager execution.
    self.assertEqual(fn(inp), compiled_fn(inp))
    self.assertEqual(fn(inp).requires_grad, compiled_fn(inp).requires_grad)
    # The graph break splits compilation into exactly two frames.
    self.assertEqual(counter.frame_count, 2)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from torch._dynamo.test_case import run_tests
|
||||
|
|
|
|||
|
|
@ -317,6 +317,17 @@ def _filter_iter(l1, l2, cond):
|
|||
return res
|
||||
|
||||
|
||||
def _load_tuple_and_call(tup):
    """Emit bytecode that calls the callable on top-of-stack with the values
    of *tup* pushed as constant positional arguments.

    On Python 3.11+, a NULL must sit beneath the callable before a call, so a
    PUSH_NULL/SWAP pair is emitted first to satisfy that calling convention.
    """
    output = []
    if sys.version_info >= (3, 11):
        output.extend(
            [
                create_instruction("PUSH_NULL"),
                create_instruction("SWAP", arg=2),
            ]
        )
    output.extend(create_instruction("LOAD_CONST", argval=v) for v in tup)
    output.extend(create_call_function(len(tup), False))
    return output
|
||||
|
||||
|
||||
class ContinueExecutionCache:
|
||||
cache = ExactWeakKeyDictionary()
|
||||
generated_code_metadata = ExactWeakKeyDictionary()
|
||||
|
|
@ -341,6 +352,8 @@ class ContinueExecutionCache:
|
|||
argnames: Tuple[str],
|
||||
argnames_null: Tuple[str],
|
||||
setup_fns: Tuple[ReenterWith],
|
||||
stack_ctx_vars: Tuple[int, Tuple[Any]],
|
||||
argnames_ctx_vars: Tuple[str, Tuple[Any]],
|
||||
null_idxes: Tuple[int],
|
||||
) -> types.CodeType:
|
||||
assert offset is not None
|
||||
|
|
@ -359,6 +372,8 @@ class ContinueExecutionCache:
|
|||
argnames,
|
||||
argnames_null,
|
||||
setup_fns,
|
||||
stack_ctx_vars,
|
||||
argnames_ctx_vars,
|
||||
null_idxes,
|
||||
)
|
||||
|
||||
|
|
@ -420,6 +435,7 @@ class ContinueExecutionCache:
|
|||
# map old hook targets to new targets generated by the hook
|
||||
old_hook_target_remap = {}
|
||||
null_idxes_i = 0
|
||||
stack_ctx_vars_d = dict(stack_ctx_vars) # type: ignore[var-annotated,arg-type]
|
||||
for i in range(nstack):
|
||||
while (
|
||||
null_idxes_i < len(null_idxes)
|
||||
|
|
@ -437,6 +453,12 @@ class ContinueExecutionCache:
|
|||
old_hook_target = offset_to_inst[hook_target_offset]
|
||||
meta.prefix_block_target_offset_remap.append(hook_target_offset)
|
||||
old_hook_target_remap[old_hook_target] = exn_target
|
||||
real_i = i + null_idxes_i
|
||||
if real_i in stack_ctx_vars_d:
|
||||
# current stack variable is a context var -
|
||||
# load args for context variable and construct it
|
||||
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i]))
|
||||
|
||||
if is_py311_plus:
|
||||
# reverse the mapping since targets of later/nested contexts are inserted
|
||||
# into the mapping later, but show up earlier in the prefix.
|
||||
|
|
@ -446,6 +468,12 @@ class ContinueExecutionCache:
|
|||
|
||||
assert not hooks
|
||||
|
||||
# initialize inactive context vars in argnames
|
||||
for name, vals in argnames_ctx_vars:
|
||||
prefix.append(create_instruction("LOAD_FAST", argval=name))
|
||||
prefix.extend(_load_tuple_and_call(vals))
|
||||
prefix.append(create_instruction("STORE_FAST", argval=name))
|
||||
|
||||
# 3.12+: store NULL into variables that were NULL
|
||||
if argnames_null:
|
||||
assert sys.version_info >= (3, 12)
|
||||
|
|
|
|||
|
|
@ -2282,6 +2282,23 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
if sys.version_info < (3, 12):
|
||||
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
|
||||
|
||||
# Handle inactive context variables - inactive context variables
|
||||
# are reconstructed to be the class, NOT the object.
|
||||
# So the resume function needs to construct the context object
|
||||
# from the class and the context object's target values.
|
||||
# e.g. torch.set_grad_enabled(True) will be reconstructed as
|
||||
# torch.set_grad_enabled
|
||||
stack_ctx_vars = []
|
||||
for i, var in enumerate(self.stack):
|
||||
if type.__instancecheck__(ContextWrappingVariable, var):
|
||||
stack_ctx_vars.append((i, tuple(var.target_values))) # type: ignore[attr-defined]
|
||||
argnames_ctx_vars = []
|
||||
for name in argnames:
|
||||
if type.__instancecheck__(
|
||||
ContextWrappingVariable, var := self.symbolic_locals[name]
|
||||
):
|
||||
argnames_ctx_vars.append((name, tuple(var.target_values))) # type: ignore[attr-defined]
|
||||
|
||||
cg = PyCodegen(self)
|
||||
|
||||
# Python does not allow null to be an arg to a function, so
|
||||
|
|
@ -2293,12 +2310,12 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
if sys.version_info >= (3, 11):
|
||||
# find indices of NullVariables
|
||||
for i, var in enumerate(self.stack):
|
||||
if isinstance(var, NullVariable):
|
||||
if type.__instancecheck__(NullVariable, var):
|
||||
null_idxes.append(i)
|
||||
# generate bytecode to pop the nulls
|
||||
null_cnt = 0
|
||||
for i, var in enumerate(reversed(self.stack)):
|
||||
if isinstance(var, NullVariable):
|
||||
if type.__instancecheck__(NullVariable, var):
|
||||
for j in range(2, i + 2 - null_cnt):
|
||||
cg.append_output(create_instruction("SWAP", arg=j))
|
||||
cg.extend_output(cg.pop_null())
|
||||
|
|
@ -2320,6 +2337,8 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
argnames,
|
||||
argnames_null,
|
||||
tuple(b.resume_fn() for b in self.block_stack),
|
||||
tuple(stack_ctx_vars),
|
||||
tuple(argnames_ctx_vars),
|
||||
tuple(null_idxes),
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -19,9 +19,17 @@ class LazyCache:
|
|||
assert self.vt is None
|
||||
from ..symbolic_convert import InstructionTranslator
|
||||
from .builder import VariableBuilder
|
||||
from .ctx_manager import ContextWrappingVariable, NullContextVariable
|
||||
from .misc import NullVariable
|
||||
|
||||
tx = InstructionTranslator.current_tx()
|
||||
self.vt = VariableBuilder(tx, self.source)(self.value)
|
||||
|
||||
# we do not expect wrapping these variables in lazy VTs
|
||||
assert not isinstance(
|
||||
self.vt, (NullVariable, ContextWrappingVariable)
|
||||
) or isinstance(self.vt, NullContextVariable)
|
||||
|
||||
del self.value
|
||||
del self.source
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue