mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
[dynamo] fix crash when context manager is passed to a function (#125321)
Fix https://github.com/pytorch/pytorch/issues/125274. The main change is to reconstruct `ContextWrappingVariables` as objects in general, but to replace them with their class on the caller side when generating the resume function.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125321
Approved by: https://github.com/jansel
This commit is contained in:
parent 59abd1dccb
commit f2ab96a57e
6 changed files with 151 additions and 46 deletions
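For orientation, a minimal, self-contained repro sketch of the failure mode this PR addresses; it mirrors the new test_inactive_context_graph_break_stack test in the diff below and uses nothing beyond what that test exercises:

import torch
import torch._dynamo
import torch._dynamo.testing


def gn(ctx):
    # graph break while the (inactive) context manager object is a function argument
    torch._dynamo.graph_break()
    return ctx


def fn(x):
    x = x + 1
    ctx = gn(torch.set_grad_enabled(True))
    # a second graph break is expected at the `with` below
    with ctx:
        x = x + 1
    return x


cnts = torch._dynamo.testing.CompileCounter()
opt_fn = torch.compile(fn, backend=cnts)
x = torch.zeros(10, requires_grad=False)
# Before this fix, compiling this pattern crashed Dynamo (see issue 125274);
# with it, eager and compiled results agree.
print(fn(x).requires_grad, opt_fn(x).requires_grad)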
@@ -1304,7 +1304,7 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(fn(x), opt_fn(x))
         self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
 
-    def test_inactive_context_graph_break(self):
+    def test_inactive_context_graph_break_local(self):
         def fn(x):
             x = x + 1
             ctx = torch.set_grad_enabled(True)
@@ -1320,6 +1320,42 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
         self.assertEqual(cnts.frame_count, 2)
+
+    def test_inactive_context_graph_break_stack(self):
+        def gn(ctx):
+            torch._dynamo.graph_break()
+            return ctx
+
+        def fn(x):
+            x = x + 1
+            ctx = gn(torch.set_grad_enabled(True))
+            # we expect a graph break on next line as well
+            with ctx:
+                x = x + 1
+            return x
+
+        x = torch.zeros(10, requires_grad=False)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
+
+    def test_inactive_context_graph_break_stack2(self):
+        def gn(x, ctx, y, z, dummy):
+            with ctx:
+                return x * y * z
+
+        def fn(x):
+            x = x + 1
+            x = gn(x, torch.set_grad_enabled(True), 2, 3, torch._dynamo.graph_break())
+            return x
+
+        x = torch.zeros(10, requires_grad=False)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
+        self.assertEqual(cnts.frame_count, 2)
 
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

@@ -244,6 +244,53 @@ def create_load_method(name) -> Instruction:
     return create_instruction("LOAD_METHOD", argval=name)
 
 
+def create_setup_with(target) -> Instruction:
+    opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH"
+    return create_instruction(opname, target=target)
+
+
+def create_swap(n) -> List[Instruction]:
+    if sys.version_info >= (3, 11):
+        return [create_instruction("SWAP", arg=n)]
+    # in Python < 3.11, SWAP is a macro that expands to multiple instructions
+    if n == 1:
+        return []
+    """
+    e.g. swap "a" and "b" in this stack:
+    0 a 1 2 3 b
+    0 a [1 2 3 b]
+    0 a [1 2 3 b] [1 2 3 b]
+    0 a [1 2 3 b] [1 2 3 b] -1
+    0 a [1 2 3 b] b
+    0 b a [1 2 3 b]
+    0 b a [1 2 3 b] [1 2 3 b]
+    0 b [1 2 3 b] a [1 2 3 b]
+    0 b [1 2 3 b] a [1 2 3 b] -1
+    0 b [1 2 3 a]
+    0 b [1 2 3 a] [1 2 3 a]
+    0 b [1 2 3 a] [1 2 3 a] reverse
+    0 b [a 3 2 1] None
+    0 b [a 3 2 1]
+    0 b 1 2 3 a
+    """
+    return [
+        create_instruction("BUILD_LIST", arg=n - 1),
+        create_instruction("DUP_TOP"),
+        create_instruction("LOAD_CONST", argval=-1),
+        create_instruction("BINARY_SUBSCR"),
+        create_instruction("ROT_THREE"),
+        create_instruction("DUP_TOP"),
+        create_instruction("ROT_THREE"),
+        create_instruction("LOAD_CONST", argval=-1),
+        create_instruction("STORE_SUBSCR"),
+        create_instruction("DUP_TOP"),
+        create_load_method("reverse"),
+        *create_call_method(0),
+        create_instruction("POP_TOP"),
+        create_instruction("UNPACK_SEQUENCE", arg=n - 1),
+    ]
+
+
 def lnotab_writer(
     lineno: int, byteno: int = 0
 ) -> Tuple[List[int], Callable[[int, int], None]]:
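As a readability aid (not part of the diff), the stack trace in the create_swap docstring can be sanity-checked with a stand-alone simulation of what the pre-3.11 expansion does to the value stack, modelling the stack as a plain Python list with the top of stack at the end; the helper name and list representation are illustrative only:

def simulated_swap(stack, n):
    # Net effect of the instruction sequence emitted by create_swap(n) on Python < 3.11:
    # swap the top of stack with the value n slots deep.
    if n == 1:
        return stack
    lst = stack[-(n - 1):]       # BUILD_LIST(n - 1): pop the top n - 1 values into a list
    stack = stack[: -(n - 1)]
    tos = lst[-1]                # DUP_TOP / LOAD_CONST -1 / BINARY_SUBSCR: read the old TOS
    deep = stack.pop()           # the value n slots deep, now exposed
    lst[-1] = deep               # LOAD_CONST -1 / STORE_SUBSCR: overwrite the list's last slot
    lst.reverse()                # LOAD_METHOD "reverse" / CALL_METHOD / POP_TOP
    stack.append(tos)            # the old TOS ends up n slots deep
    stack.extend(reversed(lst))  # UNPACK_SEQUENCE(n - 1): push the items back in order
    return stack


print(simulated_swap([0, "a", 1, 2, 3, "b"], 5))  # [0, 'b', 1, 2, 3, 'a'], matching the docstring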
@@ -982,6 +1029,17 @@ def get_const_index(code_options, val) -> int:
 def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None):
     # compute instruction arg from argval if arg is not provided
     names = {name: idx for idx, name in enumerate(code_options["co_names"])}
+
+    def get_name_index(name) -> int:
+        try:
+            idx = names[name]
+        except KeyError:
+            # Add a missing item to co_names
+            idx = names[name] = len(names)
+            code_options["co_names"] = (*code_options["co_names"], name)
+            assert len(code_options["co_names"]) == len(names)
+        return idx
+
     if sys.version_info < (3, 11):
         assert varname_from_oparg is None
         varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])}
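The get_name_index helper added above is a memoized lookup into co_names that appends missing names on demand. A stand-alone illustration of the same idea (names and co_names here are local stand-ins, and co_names is a list rather than the code object's tuple):

def get_name_index(name, names, co_names):
    # Look up a name's index, appending a missing entry so later
    # instructions can reference it by index.
    try:
        idx = names[name]
    except KeyError:
        idx = names[name] = len(names)
        co_names.append(name)
        assert len(co_names) == len(names)
    return idx


co_names = ["torch", "set_grad_enabled"]
names = {n: i for i, n in enumerate(co_names)}
print(get_name_index("set_grad_enabled", names, co_names))  # 1 (already present)
print(get_name_index("graph_break", names, co_names))       # 2 (newly appended)
print(co_names)  # ['torch', 'set_grad_enabled', 'graph_break']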
@@ -1016,27 +1074,27 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
             assert instructions[i].arg is not None
             assert instructions[i].argval is not _NotProvided
             if sys.version_info >= (3, 11):
-                instructions[i].arg = (names[instructions[i].argval] << 1) + (
+                instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + (
                     cast(int, instructions[i].arg) % 2
                 )
             else:
-                instructions[i].arg = names[instructions[i].argval]
+                instructions[i].arg = get_name_index(instructions[i].argval)
         elif instructions[i].opname == "LOAD_ATTR":
             # 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL
             assert instructions[i].arg is not None
             assert instructions[i].argval is not _NotProvided
             if sys.version_info >= (3, 12):
-                instructions[i].arg = (names[instructions[i].argval] << 1) + (
+                instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + (
                     cast(int, instructions[i].arg) % 2
                 )
             else:
-                instructions[i].arg = names[instructions[i].argval]
+                instructions[i].arg = get_name_index(instructions[i].argval)
         elif instructions[i].opname == "LOAD_SUPER_ATTR":
             assert instructions[i].arg is not None
             assert instructions[i].argval is not _NotProvided
             # Copy low bit, force second bit on for explicit super (the "+ 2")
             instructions[i].arg = (
-                (names[instructions[i].argval] << 2)
+                (get_name_index(instructions[i].argval) << 2)
                 + (cast(int, instructions[i].arg) % 2)
                 + 2
             )
@@ -1045,14 +1103,7 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N
                 instructions[i].arg = varnames[instructions[i].argval]
         elif instructions[i].opcode in HAS_NAME:
             if should_compute_arg():
-                name = instructions[i].argval
-                try:
-                    instructions[i].arg = names[name]
-                except KeyError:
-                    # Add a missing item to co_names
-                    instructions[i].arg = names[name] = len(names)
-                    code_options["co_names"] = (*code_options["co_names"], name)
-                    assert len(code_options["co_names"]) == len(names)
+                instructions[i].arg = get_name_index(instructions[i].argval)
         elif instructions[i].opcode in HAS_FREE:
             if should_compute_arg():
                 instructions[i].arg = freenames[instructions[i].argval]

@@ -51,6 +51,7 @@ class ReenterWith:
         finally:
             exit context
         """
+        # NOTE: we assume that TOS is a context manager CLASS!
         load_args = []
         if self.target_values:
             load_args = [
@@ -156,6 +157,7 @@ class ReenterWith:
         with ctx(args):
             (rest)
         """
+        # NOTE: we assume that TOS is a context manager CLASS!
         load_args = []
         if self.target_values:
             load_args = [
@@ -455,8 +457,8 @@ class ContinueExecutionCache:
                         old_hook_target_remap[old_hook_target] = exn_target
                 real_i = i + null_idxes_i
                 if real_i in stack_ctx_vars_d:
-                    # current stack variable is a context var -
-                    # load args for context variable and construct it
+                    # NOTE: we assume that current stack var is a context manager CLASS!
+                    # Load args for context variable and construct it
                     prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i]))
 
             if is_py311_plus:
@@ -468,6 +470,7 @@ class ContinueExecutionCache:
 
         assert not hooks
 
+        # NOTE: we assume that local var is a context manager CLASS!
         # initialize inactive context vars in argnames
         for name, vals in argnames_ctx_vars:
             prefix.append(create_instruction("LOAD_FAST", argval=name))

@@ -18,7 +18,7 @@ import traceback
 import types
 import typing
 import weakref
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type
+from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type
 from unittest.mock import patch
 
 import torch
@@ -37,6 +37,7 @@ from .bytecode_transformation import (
     create_call_function,
     create_instruction,
     create_jump_absolute,
+    create_swap,
     get_code_keys,
     Instruction,
     is_generator,
@@ -560,10 +561,10 @@ def break_graph_if_unsupported(*, push):
             self.output.compile_subgraph(self, reason=reason)
             cg = PyCodegen(self)
             cleanup: List[Instruction] = []
-            # Reconstruct the context variables in the block stack
+            # Reconstruct the context variable CLASS in the block stack
             for b in self.block_stack:
                 assert b.with_context is not None
-                cg(b.with_context)
+                b.with_context.reconstruct_type(cg)
                 cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup))
             self.output.add_output_instructions(cg.get_instructions())
             del cg
@@ -2285,24 +2286,32 @@ class InstructionTranslator(InstructionTranslatorBase):
         if sys.version_info < (3, 12):
             assert len(argnames_null) == 0, "variables should not be NULL in < 3.12"
 
-        # Handle inactive context variables - inactive context variables
-        # are reconstructed to be the class, NOT the object.
-        # So the resume function needs to construct the context object
-        # from the class and the context object's target values.
-        # e.g. torch.set_grad_enabled(True) will be reconstructed as
-        # torch.set_grad_enabled
+        cg = PyCodegen(self)
+
+        # Handle inactive context variables.
+        # The resume function assumes that context variables are the class, NOT the object.
+        # e.g. torch.set_grad_enabled(True) will be reconstructed as torch.set_grad_enabled
         stack_ctx_vars = []
         for i, var in enumerate(self.stack):
             if type.__instancecheck__(ContextWrappingVariable, var):
-                stack_ctx_vars.append((i, tuple(var.target_values)))  # type: ignore[attr-defined]
+                ctx = cast(ContextWrappingVariable, var)
+                stack_ctx_vars.append((i, tuple(ctx.target_values)))
+                # Replace the current stack var with the context class
+                ctx.reconstruct_type(cg)
+                cg.extend_output(create_swap(len(self.stack) - i + 1))
+                cg.append_output(create_instruction("POP_TOP"))
+
         argnames_ctx_vars = []
         for name in argnames:
             if type.__instancecheck__(
                 ContextWrappingVariable, var := self.symbolic_locals[name]
             ):
-                argnames_ctx_vars.append((name, tuple(var.target_values)))  # type: ignore[attr-defined]
-
-        cg = PyCodegen(self)
+                ctx = cast(ContextWrappingVariable, var)
+                argnames_ctx_vars.append((name, tuple(ctx.target_values)))
+                # Replace the local with the context class
+                cg.append_output(create_instruction("LOAD_FAST", argval=name))
+                ctx.reconstruct_type(cg)
+                cg.append_output(create_instruction("STORE_FAST", argval=name))
 
         # Python does not allow null to be an arg to a function, so
         # we remove nulls from the stack and restore them in the

@@ -1,6 +1,7 @@
 # mypy: ignore-errors
 import dataclasses
 import inspect
+import sys
 import warnings
 from typing import Callable, Dict, List, Optional
 
@@ -8,7 +9,11 @@ import torch._C
 from torch._guards import Guard
 
 from .. import variables
-from ..bytecode_transformation import create_call_function, create_instruction
+from ..bytecode_transformation import (
+    create_call_function,
+    create_instruction,
+    create_setup_with,
+)
 from ..device_interface import get_interface_for_device
 from ..exc import unimplemented, Unsupported
 from ..guards import GuardBuilder, install_guard
@@ -77,11 +82,21 @@ class ContextWrappingVariable(VariableTracker):
         self.state.cleanup_assert()
         return variables.ConstantVariable.create(None)
 
-    def reconstruct(self, codegen):
+    def reconstruct_type(self, codegen):
         codegen(
             AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
         )
+
+    def reconstruct(self, codegen):
+        if sys.version_info >= (3, 11):
+            codegen.append_output(create_instruction("PUSH_NULL"))
+        self.reconstruct_type(codegen)
+        target_values = self.target_values
+        if not target_values:
+            target_values = ()
+        codegen.extend_output([codegen.create_load_const(val) for val in target_values])
+        codegen.extend_output(create_call_function(len(target_values), False))
 
     def module_name(self):
         raise NotImplementedError("module_name called on base")
 
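In Python terms, the split above means reconstruct_type() emits only the context manager class/function, while reconstruct() rebuilds a full object from it plus the saved target_values. A rough illustration of the effect (not generated bytecode; target_values here is the assumed saved argument for set_grad_enabled(True)):

import torch

ctx_class = torch.set_grad_enabled  # what reconstruct_type() pushes
target_values = (True,)             # ContextWrappingVariable.target_values in this example

ctx = ctx_class(*target_values)     # create_call_function(len(target_values), False)
with ctx:                           # the rebuilt object behaves like the original context manager
    assert torch.is_grad_enabled()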
@@ -963,18 +978,16 @@ class WithExitFunctionVariable(VariableTracker):
         # Note here we reconstruct the context manager rather than the
         # exit function. The handler generated by BlockStackEntry
         # will re-enter the context in the resume function.
-        codegen(
-            AttrSource(
-                codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
-            )
-        )
-
+        self.ctx.reconstruct_type(codegen)
         if codegen.tx.output.partial_convert:
             if sys.version_info >= (3, 11):
                 codegen.append_output(create_instruction("PUSH_NULL"))
                 codegen.append_output(create_instruction("SWAP", arg=2))
             codegen.extend_output(
                 [codegen.create_load_const(val) for val in self.ctx.target_values]
             )
             codegen.extend_output(
-                create_call_function(len(self.ctx.target_values), True)
+                create_call_function(len(self.ctx.target_values), False)
             )
-            codegen.append_output(create_instruction("SETUP_WITH", target=self.target))
+            codegen.append_output(create_setup_with(self.target))
             codegen.append_output(create_instruction("POP_TOP"))

@@ -19,17 +19,10 @@ class LazyCache:
         assert self.vt is None
         from ..symbolic_convert import InstructionTranslator
         from .builder import VariableBuilder
-        from .ctx_manager import ContextWrappingVariable, NullContextVariable
-        from .misc import NullVariable
 
         tx = InstructionTranslator.current_tx()
         self.vt = VariableBuilder(tx, self.source)(self.value)
 
-        # we do not expect wrapping these variables in lazy VTs
-        assert not isinstance(
-            self.vt, (NullVariable, ContextWrappingVariable)
-        ) or isinstance(self.vt, NullContextVariable)
-
         del self.value
         del self.source