From f2ab96a57e5b19c72cf11bc098fd690d238c1d86 Mon Sep 17 00:00:00 2001
From: William Wen
Date: Wed, 1 May 2024 14:45:19 -0700
Subject: [PATCH] [dynamo] fix crash when context manager is passed to a
 function (#125321)

Fix https://github.com/pytorch/pytorch/issues/125274.

The main change is to reconstruct `ContextWrappingVariables` as objects in
general; only when generating the resume function do we replace them with
their classes on the caller side, so that the resume function can rebuild
each context object from its class and target values (see the illustrative
sketch appended after the diff).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125321
Approved by: https://github.com/jansel
---
 test/dynamo/test_ctx_manager.py          | 38 +++++++++++-
 torch/_dynamo/bytecode_transformation.py | 77 ++++++++++++++++++++----
 torch/_dynamo/resume_execution.py        |  7 ++-
 torch/_dynamo/symbolic_convert.py        | 35 +++++++----
 torch/_dynamo/variables/ctx_manager.py   | 33 +++++++---
 torch/_dynamo/variables/lazy.py          |  7 ---
 6 files changed, 151 insertions(+), 46 deletions(-)

diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py
index cc6e39de4d1..eab8fdb41aa 100644
--- a/test/dynamo/test_ctx_manager.py
+++ b/test/dynamo/test_ctx_manager.py
@@ -1304,7 +1304,7 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(fn(x), opt_fn(x))
         self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
 
-    def test_inactive_context_graph_break(self):
+    def test_inactive_context_graph_break_local(self):
         def fn(x):
             x = x + 1
             ctx = torch.set_grad_enabled(True)
@@ -1320,6 +1320,42 @@ class GraphModule(torch.nn.Module):
         self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
         self.assertEqual(cnts.frame_count, 2)
 
+    def test_inactive_context_graph_break_stack(self):
+        def gn(ctx):
+            torch._dynamo.graph_break()
+            return ctx
+
+        def fn(x):
+            x = x + 1
+            ctx = gn(torch.set_grad_enabled(True))
+            # we expect a graph break on the next line as well
+            with ctx:
+                x = x + 1
+                return x
+
+        x = torch.zeros(10, requires_grad=False)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
+
+    def test_inactive_context_graph_break_stack2(self):
+        def gn(x, ctx, y, z, dummy):
+            with ctx:
+                return x * y * z
+
+        def fn(x):
+            x = x + 1
+            x = gn(x, torch.set_grad_enabled(True), 2, 3, torch._dynamo.graph_break())
+            return x
+
+        x = torch.zeros(10, requires_grad=False)
+        cnts = torch._dynamo.testing.CompileCounter()
+        opt_fn = torch.compile(fn, backend=cnts)
+        self.assertEqual(fn(x), opt_fn(x))
+        self.assertEqual(fn(x).requires_grad, opt_fn(x).requires_grad)
+        self.assertEqual(cnts.frame_count, 2)
+
 
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests
diff --git a/torch/_dynamo/bytecode_transformation.py b/torch/_dynamo/bytecode_transformation.py
index 83f77626e0a..dec673b0e91 100644
--- a/torch/_dynamo/bytecode_transformation.py
+++ b/torch/_dynamo/bytecode_transformation.py
@@ -244,6 +244,53 @@ def create_load_method(name) -> Instruction:
     return create_instruction("LOAD_METHOD", argval=name)
 
 
+def create_setup_with(target) -> Instruction:
+    opname = "BEFORE_WITH" if sys.version_info >= (3, 11) else "SETUP_WITH"
+    return create_instruction(opname, target=target)
+
+
+def create_swap(n) -> List[Instruction]:
+    if sys.version_info >= (3, 11):
+        return [create_instruction("SWAP", arg=n)]
+    # in Python < 3.11, SWAP is a macro that expands to multiple instructions
+    if n == 1:
+        return []
+    """
+    e.g. swap "a" and "b" in this stack:
swap "a" and "b" in this stack: + 0 a 1 2 3 b + 0 a [1 2 3 b] + 0 a [1 2 3 b] [1 2 3 b] + 0 a [1 2 3 b] [1 2 3 b] -1 + 0 a [1 2 3 b] b + 0 b a [1 2 3 b] + 0 b a [1 2 3 b] [1 2 3 b] + 0 b [1 2 3 b] a [1 2 3 b] + 0 b [1 2 3 b] a [1 2 3 b] -1 + 0 b [1 2 3 a] + 0 b [1 2 3 a] [1 2 3 a] + 0 b [1 2 3 a] [1 2 3 a] reverse + 0 b [a 3 2 1] None + 0 b [a 3 2 1] + 0 b 1 2 3 a + """ + return [ + create_instruction("BUILD_LIST", arg=n - 1), + create_instruction("DUP_TOP"), + create_instruction("LOAD_CONST", argval=-1), + create_instruction("BINARY_SUBSCR"), + create_instruction("ROT_THREE"), + create_instruction("DUP_TOP"), + create_instruction("ROT_THREE"), + create_instruction("LOAD_CONST", argval=-1), + create_instruction("STORE_SUBSCR"), + create_instruction("DUP_TOP"), + create_load_method("reverse"), + *create_call_method(0), + create_instruction("POP_TOP"), + create_instruction("UNPACK_SEQUENCE", arg=n - 1), + ] + + def lnotab_writer( lineno: int, byteno: int = 0 ) -> Tuple[List[int], Callable[[int, int], None]]: @@ -982,6 +1029,17 @@ def get_const_index(code_options, val) -> int: def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=None): # compute instruction arg from argval if arg is not provided names = {name: idx for idx, name in enumerate(code_options["co_names"])} + + def get_name_index(name) -> int: + try: + idx = names[name] + except KeyError: + # Add a missing item to co_names + idx = names[name] = len(names) + code_options["co_names"] = (*code_options["co_names"], name) + assert len(code_options["co_names"]) == len(names) + return idx + if sys.version_info < (3, 11): assert varname_from_oparg is None varnames = {name: idx for idx, name in enumerate(code_options["co_varnames"])} @@ -1016,27 +1074,27 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N assert instructions[i].arg is not None assert instructions[i].argval is not _NotProvided if sys.version_info >= (3, 11): - instructions[i].arg = (names[instructions[i].argval] << 1) + ( + instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( cast(int, instructions[i].arg) % 2 ) else: - instructions[i].arg = names[instructions[i].argval] + instructions[i].arg = get_name_index(instructions[i].argval) elif instructions[i].opname == "LOAD_ATTR": # 3.12 LOAD_ATTR requires both arg and argval, like LOAD_GLOBAL assert instructions[i].arg is not None assert instructions[i].argval is not _NotProvided if sys.version_info >= (3, 12): - instructions[i].arg = (names[instructions[i].argval] << 1) + ( + instructions[i].arg = (get_name_index(instructions[i].argval) << 1) + ( cast(int, instructions[i].arg) % 2 ) else: - instructions[i].arg = names[instructions[i].argval] + instructions[i].arg = get_name_index(instructions[i].argval) elif instructions[i].opname == "LOAD_SUPER_ATTR": assert instructions[i].arg is not None assert instructions[i].argval is not _NotProvided # Copy low bit, force second bit on for explicit super (the "+ 2") instructions[i].arg = ( - (names[instructions[i].argval] << 2) + (get_name_index(instructions[i].argval) << 2) + (cast(int, instructions[i].arg) % 2) + 2 ) @@ -1045,14 +1103,7 @@ def fix_vars(instructions: List[Instruction], code_options, varname_from_oparg=N instructions[i].arg = varnames[instructions[i].argval] elif instructions[i].opcode in HAS_NAME: if should_compute_arg(): - name = instructions[i].argval - try: - instructions[i].arg = names[name] - except KeyError: - # Add a missing item to co_names - instructions[i].arg = names[name] = len(names) 
- code_options["co_names"] = (*code_options["co_names"], name) - assert len(code_options["co_names"]) == len(names) + instructions[i].arg = get_name_index(instructions[i].argval) elif instructions[i].opcode in HAS_FREE: if should_compute_arg(): instructions[i].arg = freenames[instructions[i].argval] diff --git a/torch/_dynamo/resume_execution.py b/torch/_dynamo/resume_execution.py index 969a679c9e9..ced0013cadb 100644 --- a/torch/_dynamo/resume_execution.py +++ b/torch/_dynamo/resume_execution.py @@ -51,6 +51,7 @@ class ReenterWith: finally: exit context """ + # NOTE: we assume that TOS is a context manager CLASS! load_args = [] if self.target_values: load_args = [ @@ -156,6 +157,7 @@ class ReenterWith: with ctx(args): (rest) """ + # NOTE: we assume that TOS is a context manager CLASS! load_args = [] if self.target_values: load_args = [ @@ -455,8 +457,8 @@ class ContinueExecutionCache: old_hook_target_remap[old_hook_target] = exn_target real_i = i + null_idxes_i if real_i in stack_ctx_vars_d: - # current stack variable is a context var - - # load args for context variable and construct it + # NOTE: we assume that current stack var is a context manager CLASS! + # Load args for context variable and construct it prefix.extend(_load_tuple_and_call(stack_ctx_vars_d[real_i])) if is_py311_plus: @@ -468,6 +470,7 @@ class ContinueExecutionCache: assert not hooks + # NOTE: we assume that local var is a context manager CLASS! # initialize inactive context vars in argnames for name, vals in argnames_ctx_vars: prefix.append(create_instruction("LOAD_FAST", argval=name)) diff --git a/torch/_dynamo/symbolic_convert.py b/torch/_dynamo/symbolic_convert.py index 1ebf5ec26fa..2085d0813ba 100644 --- a/torch/_dynamo/symbolic_convert.py +++ b/torch/_dynamo/symbolic_convert.py @@ -18,7 +18,7 @@ import traceback import types import typing import weakref -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type +from typing import Any, Callable, cast, Dict, List, Optional, Set, Tuple, Type from unittest.mock import patch import torch @@ -37,6 +37,7 @@ from .bytecode_transformation import ( create_call_function, create_instruction, create_jump_absolute, + create_swap, get_code_keys, Instruction, is_generator, @@ -560,10 +561,10 @@ def break_graph_if_unsupported(*, push): self.output.compile_subgraph(self, reason=reason) cg = PyCodegen(self) cleanup: List[Instruction] = [] - # Reconstruct the context variables in the block stack + # Reconstruct the context variable CLASS in the block stack for b in self.block_stack: assert b.with_context is not None - cg(b.with_context) + b.with_context.reconstruct_type(cg) cg.extend_output(b.resume_fn().try_except(cg.code_options, cleanup)) self.output.add_output_instructions(cg.get_instructions()) del cg @@ -2285,24 +2286,32 @@ class InstructionTranslator(InstructionTranslatorBase): if sys.version_info < (3, 12): assert len(argnames_null) == 0, "variables should not be NULL in < 3.12" - # Handle inactive context variables - inactive context variables - # are reconstructed to be the class, NOT the object. - # So the resume function needs to construct the context object - # from the class and the context object's target values. - # e.g. torch.set_grad_enabled(True) will be reconstructed as - # torch.set_grad_enabled + cg = PyCodegen(self) + + # Handle inactive context variables. + # The resume function assumes that context variables are the class, NOT the object. + # e.g. 
         stack_ctx_vars = []
         for i, var in enumerate(self.stack):
             if type.__instancecheck__(ContextWrappingVariable, var):
-                stack_ctx_vars.append((i, tuple(var.target_values)))  # type: ignore[attr-defined]
+                ctx = cast(ContextWrappingVariable, var)
+                stack_ctx_vars.append((i, tuple(ctx.target_values)))
+                # Replace the current stack var with the context class
+                ctx.reconstruct_type(cg)
+                cg.extend_output(create_swap(len(self.stack) - i + 1))
+                cg.append_output(create_instruction("POP_TOP"))
+
         argnames_ctx_vars = []
         for name in argnames:
             if type.__instancecheck__(
                 ContextWrappingVariable, var := self.symbolic_locals[name]
             ):
-                argnames_ctx_vars.append((name, tuple(var.target_values)))  # type: ignore[attr-defined]
-
-        cg = PyCodegen(self)
+                ctx = cast(ContextWrappingVariable, var)
+                argnames_ctx_vars.append((name, tuple(ctx.target_values)))
+                # Replace the local with the context class
+                cg.append_output(create_instruction("LOAD_FAST", argval=name))
+                ctx.reconstruct_type(cg)
+                cg.append_output(create_instruction("STORE_FAST", argval=name))
 
         # Python does not allow null to be an arg to a function, so
         # we remove nulls from the stack and restore them in the
diff --git a/torch/_dynamo/variables/ctx_manager.py b/torch/_dynamo/variables/ctx_manager.py
index fa6d7d4f717..637636f1e04 100644
--- a/torch/_dynamo/variables/ctx_manager.py
+++ b/torch/_dynamo/variables/ctx_manager.py
@@ -1,6 +1,7 @@
 # mypy: ignore-errors
 import dataclasses
 import inspect
+import sys
 import warnings
 from typing import Callable, Dict, List, Optional
 
@@ -8,7 +9,11 @@ import torch._C
 from torch._guards import Guard
 
 from .. import variables
-from ..bytecode_transformation import create_call_function, create_instruction
+from ..bytecode_transformation import (
+    create_call_function,
+    create_instruction,
+    create_setup_with,
+)
 from ..device_interface import get_interface_for_device
 from ..exc import unimplemented, Unsupported
 from ..guards import GuardBuilder, install_guard
@@ -77,11 +82,21 @@ class ContextWrappingVariable(VariableTracker):
             self.state.cleanup_assert()
         return variables.ConstantVariable.create(None)
 
-    def reconstruct(self, codegen):
+    def reconstruct_type(self, codegen):
         codegen(
             AttrSource(codegen.tx.import_source(self.module_name()), self.fn_name())
         )
 
+    def reconstruct(self, codegen):
+        if sys.version_info >= (3, 11):
+            codegen.append_output(create_instruction("PUSH_NULL"))
+        self.reconstruct_type(codegen)
+        target_values = self.target_values
+        if not target_values:
+            target_values = ()
+        codegen.extend_output([codegen.create_load_const(val) for val in target_values])
+        codegen.extend_output(create_call_function(len(target_values), False))
+
     def module_name(self):
         raise NotImplementedError("module_name called on base")
 
@@ -963,18 +978,16 @@ class WithExitFunctionVariable(VariableTracker):
         # Note here we reconstruct the context manager rather than the
         # exit function. The handler generated by BlockStackEntry
         # will re-enter the context in the resume function.
-        codegen(
-            AttrSource(
-                codegen.tx.import_source(self.ctx.module_name()), self.ctx.fn_name()
-            )
-        )
-
+        self.ctx.reconstruct_type(codegen)
         if codegen.tx.output.partial_convert:
+            if sys.version_info >= (3, 11):
+                codegen.append_output(create_instruction("PUSH_NULL"))
+                codegen.append_output(create_instruction("SWAP", arg=2))
             codegen.extend_output(
                 [codegen.create_load_const(val) for val in self.ctx.target_values]
             )
             codegen.extend_output(
-                create_call_function(len(self.ctx.target_values), True)
+                create_call_function(len(self.ctx.target_values), False)
             )
-            codegen.append_output(create_instruction("SETUP_WITH", target=self.target))
+            codegen.append_output(create_setup_with(self.target))
             codegen.append_output(create_instruction("POP_TOP"))
diff --git a/torch/_dynamo/variables/lazy.py b/torch/_dynamo/variables/lazy.py
index 4c68c7bf78a..fb4f5cfa76c 100644
--- a/torch/_dynamo/variables/lazy.py
+++ b/torch/_dynamo/variables/lazy.py
@@ -19,17 +19,10 @@ class LazyCache:
         assert self.vt is None
         from ..symbolic_convert import InstructionTranslator
         from .builder import VariableBuilder
-        from .ctx_manager import ContextWrappingVariable, NullContextVariable
-        from .misc import NullVariable
 
         tx = InstructionTranslator.current_tx()
         self.vt = VariableBuilder(tx, self.source)(self.value)
 
-        # we do not expect wrapping these variables in lazy VTs
-        assert not isinstance(
-            self.vt, (NullVariable, ContextWrappingVariable)
-        ) or isinstance(self.vt, NullContextVariable)
-
         del self.value
         del self.source
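---

Illustrative sketch (referenced in the commit message above). This is not
part of the diff; it pictures in plain Python the caller/resume-function
contract the patch establishes. The names `resume_fn`, `ctx_cls`, and
`target_values` are hypothetical stand-ins; the real resume function is
generated bytecode, and the break point below is simplified.

    import torch

    def fn(x):
        x = x + 1
        ctx = torch.set_grad_enabled(True)  # inactive context manager object
        torch._dynamo.graph_break()         # ctx is live across the break
        with ctx:
            return x + 1

    # At the graph break, the caller-side bytecode swaps the live object on
    # the stack (or in a local) for its class, torch.set_grad_enabled, and
    # records the target values (True,). The generated resume function then
    # rebuilds the object from that class, roughly like this:

    def resume_fn(x, ctx_cls=torch.set_grad_enabled, target_values=(True,)):
        ctx = ctx_cls(*target_values)  # reconstruct the context object
        with ctx:
            return x + 1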