[dynamo] fix crash when context manager is passed to a function (#125321)

Fixes https://github.com/pytorch/pytorch/issues/125274. The main change is to reconstruct `ContextWrappingVariable`s as objects in the general case; on the caller side, when generating the resume function, they are replaced with their class instead.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125321
Approved by: https://github.com/jansel
This commit is contained in:
William Wen 2024-05-01 14:45:19 -07:00 committed by PyTorch MergeBot
parent 59abd1dccb
commit f2ab96a57e
6 changed files with 151 additions and 46 deletions

View file

@ -1304,7 +1304,7 @@ class GraphModule(torch.nn.Module):
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
def test_inactive_context_graph_break(self):
def test_inactive_context_graph_break_local(self):
def fn(x):
x = x + 1
ctx = torch.set_grad_enabled(True)
@ -1320,6 +1320,42 @@ class GraphModule(torch.nn.Module):
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
self.assertEqual(cnts.frame_count, 2)
def test_inactive_context_graph_break_stack(self):
# Regression test for pytorch/pytorch#125274: an inactive (not-yet-entered)
# context manager sitting on the stack across a graph break must not crash.
def gn(ctx):
torch._dynamo.graph_break()
return ctx
def fn(x):
x = x + 1
# ctx is constructed but NOT entered when gn's graph break fires,
# so dynamo must reconstruct it for the resume function
ctx = gn(torch.set_grad_enabled(True))
# we expect a graph break on next line as well
with ctx:
x = x + 1
return x
x = torch.zeros(10, requires_grad=False)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
# eager and compiled must agree on both values and grad-ness
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
def test_inactive_context_graph_break_stack2(self):
# Variant of the test above: the inactive context manager is an argument
# buried in the middle of a call whose last argument triggers the graph
# break, so it crosses the break as a stack entry among other values.
def gn(x, ctx, y, z, dummy):
with ctx:
return x * y * z
def fn(x):
x = x + 1
# graph break happens while evaluating gn's argument list
x = gn(x, torch.set_grad_enabled(True), 2, 3, torch._dynamo.graph_break())
return x
x = torch.zeros(10, requires_grad=False)
cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
self.assertEqual(fn(x), opt_fn(x))
self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
# exactly one graph break -> two compiled frames
self.assertEqual(cnts.frame_count, 2)
if __name__ == "__main__":
from torch._dynamo.test_case import run_tests

View file

@ -244,6 +244,53 @@ def create_load_method(name) -> Instruction:
return create_instruction("LOAD_METHOD", argval=name)
def create_setup_with(target) -> Instruction:
    """Create the with-block setup instruction for the running Python.

    Python 3.11 replaced SETUP_WITH with BEFORE_WITH; pick whichever the
    interpreter understands and point it at ``target``.
    """
    if sys.version_info >= (3, 11):
        opname = "BEFORE_WITH"
    else:
        opname = "SETUP_WITH"
    return create_instruction(opname, target=target)
def create_swap(n) -> List[Instruction]:
# Exchange the top of stack with the item n entries deep.
# 3.11+ has a native SWAP opcode taking the depth as its arg.
if sys.version_info >= (3, 11):
return [create_instruction("SWAP", arg=n)]
# in Python < 3.11, SWAP is a macro that expands to multiple instructions
if n == 1:
# swapping TOS with itself is a no-op
return []
"""
e.g. swap "a" and "b" in this stack:
0 a 1 2 3 b
0 a [1 2 3 b]
0 a [1 2 3 b] [1 2 3 b]
0 a [1 2 3 b] [1 2 3 b] -1
0 a [1 2 3 b] b
0 b a [1 2 3 b]
0 b a [1 2 3 b] [1 2 3 b]
0 b [1 2 3 b] a [1 2 3 b]
0 b [1 2 3 b] a [1 2 3 b] -1
0 b [1 2 3 a]
0 b [1 2 3 a] [1 2 3 a]
0 b [1 2 3 a] [1 2 3 a] reverse
0 b [a 3 2 1] None
0 b [a 3 2 1]
0 b 1 2 3 a
"""
# Emulate SWAP by packing the top n-1 items into a list, splicing the old
# TOS/deep item through it, reversing, then unpacking — see trace above.
return [
create_instruction("BUILD_LIST", arg=n - 1),
create_instruction("DUP_TOP"),
create_instruction("LOAD_CONST", argval=-1),
create_instruction("BINARY_SUBSCR"),
create_instruction("ROT_THREE"),
create_instruction("DUP_TOP"),
create_instruction("ROT_THREE"),
create_instruction("LOAD_CONST", argval=-1),
create_instruction("STORE_SUBSCR"),
create_instruction("DUP_TOP"),
create_load_method("reverse"),
*create_call_method(0),
create_instruction("POP_TOP"),
create_instruction("UNPACK_SEQUENCE", arg=n - 1),
]
def lnotab_writer(
lineno: int, byteno: int = 0
) -> Tuple[List[int], Callable[[int, int], None]]:
@ -982,6 +1029,17 @@ def get_const_index(code_options, val) -> int:
def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None):
# compute instruction arg from argval if arg is not provided
names = {name: idx for idx, name in enumerate(code_options["co_names"])}
def get_name_index(name) -> int:
# Resolve `name` to its index in co_names, appending it (and keeping the
# enclosing `names` cache in sync) when it is not already present.
try:
idx = names[name]
except KeyError:
# Add a missing item to co_names
idx = names[name] = len(names)
code_options["co_names"] = (*code_options["co_names"], name)
assert len(code_options["co_names"]) == len(names)
return idx
if sys.version_info < (3, 11):
assert varname_from_oparg is None
varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])}
@ -1016,27 +1074,27 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
assert instructions[i].arg is not None
assert instructions[i].argval is not _NotProvided
if sys.version_info >= (3, 11):
instructions[i].arg = (names[instructions[i].argval] << 1) + (
instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + (
cast(int, instructions[i].arg) % 2
)
else:
instructions[i].arg = names[instructions[i].argval]
instructions[i].arg = get_name_index(instructions[i].argval)
elif instructions[i].opname == "LOAD_ATTR":
# 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL
assert instructions[i].arg is not None
assert instructions[i].argval is not _NotProvided
if sys.version_info >= (3, 12):
instructions[i].arg = (names[instructions[i].argval] << 1) + (
instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + (
cast(int, instructions[i].arg) % 2
)
else:
instructions[i].arg = names[instructions[i].argval]
instructions[i].arg = get_name_index(instructions[i].argval)
elif instructions[i].opname == "LOAD_SUPER_ATTR":
assert instructions[i].arg is not None
assert instructions[i].argval is not _NotProvided
# Copy low bit, force second bit on for explicit super (the "+ 2")
instructions[i].arg = (
(names[instructions[i].argval] << 2)
(get_name_index(instructions[i].argval) << 2)
+ (cast(int, instructions[i].arg) % 2)
+ 2
)
@ -1045,14 +1103,7 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
instructions[i].arg = varnames[instructions[i].argval]
elif instructions[i].opcode in HAS_NAME:
if should_compute_arg():
name = instructions[i].argval
try:
instructions[i].arg = names[name]
except KeyError:
# Add a missing item to co_names
instructions[i].arg = names[name] = len(names)
code_options["co_names"] = (*code_options["co_names"], name)
assert len(code_options["co_names"]) == len(names)
instructions[i].arg = get_name_index(instructions[i].argval)
elif instructions[i].opcode in HAS_FREE:
if should_compute_arg():
instructions[i].arg = freenames[instructions[i].argval]

View file

@ -51,6 +51,7 @@ class ReenterWith:
finally:
exit context
"""
# NOTE: we assume that TOS is a context manager CLASS!
load_args = []
if self.target_values:
load_args = [
@ -156,6 +157,7 @@ class ReenterWith:
with ctx(args):
(rest)
"""
# NOTE: we assume that TOS is a context manager CLASS!
load_args = []
if self.target_values:
load_args = [
@ -455,8 +457,8 @@ class ContinueExecutionCache:
old_hook_target_remap[old_hook_target] = exn_target
real_i = i + null_idxes_i
if real_i in stack_ctx_vars_d:
# current stack variable is a context var -
# load args for context variable and construct it
# NOTE: we assume that current stack var is a context manager CLASS!
# Load args for context variable and construct it
prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i]))
if is_py311_plus:
@ -468,6 +470,7 @@ class ContinueExecutionCache:
assert not hooks
# NOTE: we assume that local var is a context manager CLASS!
# initialize inactive context vars in argnames
for name, vals in argnames_ctx_vars:
prefix.append(create_instruction("LOAD_FAST", argval=name))

View file

@ -18,7 +18,7 @@ import traceback
import types
import typing
import weakref
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type
from unittest.mock import patch
import torch
@ -37,6 +37,7 @@ from .bytecode_transformation import (
create_call_function,
create_instruction,
create_jump_absolute,
create_swap,
get_code_keys,
Instruction,
is_generator,
@ -560,10 +561,10 @@ def break_graph_if_unsupported(*, push):
self.output.compile_subgraph(self, reason=reason)
cg = PyCodegen(self)
cleanup: List[Instruction] = []
# Reconstruct the context variables in the block stack
# Reconstruct the context variable CLASS in the block stack
for b in self.block_stack:
assert b.with_context is not None
cg(b.with_context)
b.with_context.reconstruct_type(cg)
cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
self.output.add_output_instructions(cg.get_instructions())
del cg
@ -2285,24 +2286,32 @@ class InstructionTranslator(InstructionTranslatorBase):
if sys.version_info < (3, 12):
assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
# Handle inactive context variables - inactive context variables
# are reconstructed to be the class, NOT the object.
# So the resume function needs to construct the context object
# from the class and the context object's target values.
# e.g. torch.set_grad_enabled(True) will be reconstructed as
# torch.set_grad_enabled
cg = PyCodegen(self)
# Handle inactive context variables.
# The resume function assumes that context variables are the class, NOT the object.
# e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
stack_ctx_vars = []
for i, var in enumerate(self.stack):
if type.__instancecheck__(ContextWrappingVariable, var):
stack_ctx_vars.append((i, tuple(var.target_values))) # type: ignore[attr-defined]
ctx = cast(ContextWrappingVariable, var)
stack_ctx_vars.append((i, tuple(ctx.target_values)))
# Replace the current stack var with the context class
ctx.reconstruct_type(cg)
cg.extend_output(create_swap(len(self.stack) - i + 1))
cg.append_output(create_instruction("POP_TOP"))
argnames_ctx_vars = []
for name in argnames:
if type.__instancecheck__(
ContextWrappingVariable, var := self.symbolic_locals[name]
):
argnames_ctx_vars.append((name, tuple(var.target_values))) # type: ignore[attr-defined]
cg = PyCodegen(self)
ctx = cast(ContextWrappingVariable, var)
argnames_ctx_vars.append((name, tuple(ctx.target_values)))
# Replace the local with the context class
cg.append_output(create_instruction("LOAD_FAST", argval=name))
ctx.reconstruct_type(cg)
cg.append_output(create_instruction("STORE_FAST", argval=name))
# Python does not allow null to be an arg to a function, so
# we remove nulls from the stack and restore them in the

View file

@ -1,6 +1,7 @@
# mypy: ignore-errors
import dataclasses
import inspect
import sys
import warnings
from typing import Callable, Dict, List, Optional
@ -8,7 +9,11 @@ import torch._C
from torch._guards import Guard
from .. import variables
from ..bytecode_transformation import create_call_function, create_instruction
from ..bytecode_transformation import (
create_call_function,
create_instruction,
create_setup_with,
)
from ..device_interface import get_interface_for_device
from ..exc import unimplemented, Unsupported
from ..guards import GuardBuilder, install_guard
@ -77,11 +82,21 @@ class ContextWrappingVariable(VariableTracker):
self.state.cleanup_assert()
return variables.ConstantVariable.create(None)
def reconstruct(self, codegen):
def reconstruct_type(self, codegen):
# Push only the context manager CLASS (e.g. torch.set_grad_enabled),
# not a constructed instance; callers load args and call it themselves.
codegen(
AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
)
def reconstruct(self, codegen):
# Reconstruct the context manager as a live OBJECT: push the class via
# reconstruct_type, then call it with the saved target values.
if sys.version_info >= (3, 11):
# 3.11+ calling convention needs a NULL pushed below the callable
codegen.append_output(create_instruction("PUSH_NULL"))
self.reconstruct_type(codegen)
target_values = self.target_values
if not target_values:
# normalize a falsy/None target_values to an empty argument list
target_values = ()
codegen.extend_output([codegen.create_load_const(val) for val in target_values])
codegen.extend_output(create_call_function(len(target_values), False))
def module_name(self):
raise NotImplementedError("module_name called on base")
@ -963,18 +978,16 @@ class WithExitFunctionVariable(VariableTracker):
# Note here we reconstruct the context manager rather than the
# exit function. The handler generated by BlockStackEntry
# will re-enter the context in the resume function.
codegen(
AttrSource(
codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
)
)
self.ctx.reconstruct_type(codegen)
if codegen.tx.output.partial_convert:
if sys.version_info >= (3, 11):
codegen.append_output(create_instruction("PUSH_NULL"))
codegen.append_output(create_instruction("SWAP", arg=2))
codegen.extend_output(
[codegen.create_load_const(val) for val in self.ctx.target_values]
)
codegen.extend_output(
create_call_function(len(self.ctx.target_values), True)
create_call_function(len(self.ctx.target_values), False)
)
codegen.append_output(create_instruction("SETUP_WITH", target=self.target))
codegen.append_output(create_setup_with(self.target))
codegen.append_output(create_instruction("POP_TOP"))

View file

@ -19,17 +19,10 @@ class LazyCache:
assert self.vt is None
from ..symbolic_convert import InstructionTranslator
from .builder import VariableBuilder
from .ctx_manager import ContextWrappingVariable, NullContextVariable
from .misc import NullVariable
tx = InstructionTranslator.current_tx()
self.vt = VariableBuilder(tx, self.source)(self.value)
# we do not expect wrapping these variables in lazy VTs
assert not isinstance(
self.vt, (NullVariable, ContextWrappingVariable)
) or isinstance(self.vt, NullContextVariable)
del self.value
del self.source