[dynamo] Remove AutoDerefLocalSource and simplify cell handling (#141629)

This patch
1. removes `AutoDerefLocalSource` in favor of `LocalSource`, thereby
   removing its special handling in guards.
2. introduces a `LocalCellSource` for cells from the root frame, with
   only `reconstruct` implemented, to programmatically enforce that thse
   cells should never be used by other components like guards.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141629
Approved by: https://github.com/jansel
ghstack dependencies: #141628
This commit is contained in:
Ryan Guo 2024-11-27 11:53:47 -08:00 committed by PyTorch MergeBot
parent e14d8c980f
commit 2d708752f0
7 changed files with 41 additions and 87 deletions

View file

@ -453,7 +453,7 @@ class PyCodegen:
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
# requires that in the generated bytecode, these cells would keep
# their original local names, which we ensure via
# `LocalSource.is_root_frame_cell`.
# `CellVariable.local_name`.
for var in freevars:
assert var in self.cell_and_freevars()
output.append(self.create_load_closure(var))

View file

@ -44,7 +44,6 @@ from torch._C._dynamo.guards import (
check_type_id,
dict_version,
DictGuardManager,
GuardManager,
install_no_tensor_aliasing_guard,
install_object_aliasing_guard,
profile_guard_manager,
@ -83,7 +82,6 @@ from .eval_frame import set_guard_error_hook
from .source import (
AttrProxySource,
AttrSource,
AutoDerefLocalSource,
CallFunctionNoArgsSource,
ChainedSource,
ConstDictKeySource,
@ -968,14 +966,6 @@ class GuardBuilder(GuardBuilderBase):
example_value=example_value,
guard_manager_enum=guard_manager_enum,
)
elif istype(source, AutoDerefLocalSource):
# Guard checks run on f_locals, in which the python level
# auto-dereferenced cell objects are also dereferenced (e.g., rather
# than `f_locals` being `{ 'cell' : <cell object of int> }`, it'll
# be `{ 'cell' : <int> }`. So the guard manager is the same as the
# base guard manager.
assert isinstance(base_guard_manager, GuardManager) # tame mypy
out = base_guard_manager
elif istype(source, GlobalSource):
# Global manager accepts a dict but it is not a DictGuardManager
# because globals dict is big and we typically guard on a very

View file

@ -117,7 +117,7 @@ from .variables.builder import (
wrap_fx_proxy,
)
from .variables.lists import BaseListVariable
from .variables.misc import NullVariable
from .variables.misc import CellVariable, NullVariable
from .variables.nn_module import NNModuleVariable
from .variables.tensor import (
NumpyNdarrayVariable,
@ -1041,6 +1041,8 @@ class OutputGraph:
# while running test_subgraphs.py
if isinstance(v.source, LocalSource) and v.source.local_name == k:
continue # no need to restore initial state
if isinstance(v, CellVariable) and v.local_name == k:
continue # no need to restore initial state
# Do not load variable if it is NULL.
if sys.version_info >= (3, 12):
# Continuation function will load the NULL for v.

View file

@ -19,7 +19,7 @@ from .bytecode_transformation import (
)
from .codegen import PyCodegen
from .exc import unimplemented
from .source import GlobalSource, LocalSource, Source
from .source import GlobalSource, LocalCellSource, LocalSource, Source
from .utils import is_frozen_dataclass, nn_module_new, object_new
from .variables.base import (
AttributeMutation,
@ -431,13 +431,15 @@ class SideEffects:
# `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
# `make_cell` for the non-root-frame cells here.
# TODO generalize this so we never need to call `make_cell`.
if not var.is_root_frame_cell():
if var.local_name is None:
cg.add_push_null(
lambda: cg.load_import_from(utils.__name__, "make_cell")
)
cg.extend_output(create_call_function(0, False))
cg.add_cache(var)
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
elif var.source is None:
var.source = LocalCellSource(var.local_name)
elif isinstance(var.mutation_type, AttributeMutationNew):
if isinstance(var, variables.AutogradFunctionContextVariable):
unimplemented("AutogradFunctionContextVariable escaped")
@ -652,14 +654,13 @@ class SideEffects:
cg.call_function(1, False)
cg.append_output(create_instruction("POP_TOP"))
elif isinstance(var, variables.CellVariable) and var.is_root_frame_cell():
elif isinstance(var, variables.CellVariable) and var.local_name is not None:
# Emit more readable and performant bytecode.
# TODO generalize this for cells created during inlining.
if var in self.store_attr_mutations:
contents_var = self.load_cell(var)
cg(contents_var)
cell_name = var.source.local_name # type: ignore[attr-defined]
suffixes.append([cg.create_store_deref(cell_name)])
suffixes.append([cg.create_store_deref(var.local_name)])
elif self.is_attribute_mutation(var):
# Applying mutations involves two steps: 1) Push all

View file

@ -107,15 +107,14 @@ class LocalSource(Source):
# Whether this local is an input to the root frame.
is_input: bool = False
# Whether the item at this source is a that is native to the root frame,
# i.e., a part of its `co_cellvars` or `co_freevars`.
is_root_frame_cell: bool = False
# Whether the item at this source is the _content_ of a cell that is
# dereferenced from the root frame, i.e., it's a part of the `co_cellvars`
# or `co_freevars`.
is_derefed_cell_contents: bool = False
def reconstruct(self, codegen):
if self.is_root_frame_cell:
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
# Dynamo's bytecode transformation differentiates them slightly.
codegen.append_output(codegen.create_load_closure(self.local_name))
if self.is_derefed_cell_contents:
codegen.load_deref(self.local_name)
else:
codegen.append_output(codegen.create_load(self.local_name))
@ -234,48 +233,22 @@ class AttrSource(ChainedSource):
@dataclasses.dataclass(frozen=True)
class AutoDerefLocalSource(ChainedSource):
class LocalCellSource(Source):
"""
In Python, reads and writes to local cell objects (variables captured by a
frame, or created in a frame by captured by a nested frame) are
automatically dereferenced.
At the language level, this means accessing the `cell_contents` attribute of
the cell object, rather than the object itself.
At the bytecode level, this means turning LOAD_FAST into LOAD_DEREF, and
STORE_FAST into STORE_DEREF.
This class represents the source to the _contents_ of such a cell object,
encapsulating the python and bytecode level idiosyncracies of what would
otherwise have been a simple `AttrSource(cell_source, "cell_contents")`
Conceptually, this class is `LocalSource` for cell objects implicitly
generated by Python (e.g., captured variables).
"""
def __post_init__(self):
assert type(self.base) is LocalSource
assert self.base.is_root_frame_cell
local_name: str
def reconstruct(self, codegen):
# Emit more readable and performant bytecode.
assert isinstance(self.base, LocalSource) # tame mypy
codegen.load_deref(self.base.local_name)
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
# Dynamo's bytecode transformation differentiates them slightly, so we
# always emit `LOAD_CLOSURE` here.
codegen.append_output(codegen.create_load_closure(self.local_name))
def guard_source(self):
return self.base.guard_source()
def name(self):
# The requirements for `Source.name` are
# 1. with appropriate scope, `eval()` will turn it into the target
# python value.
# 2. can be used for caching guard managers.
#
# (1) requires us to return `self.base.name()` here, in the scope given
# to `eval()`, cells are already dereferenced.
#
# What about name collision that can affect (2)? Well, auto-deferenced
# cells should never have any guards on them (only guards on the
# contents), so this name collision shouldn't matter.
return self.base.name()
# All the other methods are intentionally unimplemented because e.g., a
# local cell object should never be used for guards.
# Represents tensor.grad source. It could be represented by AttrSource as well.

View file

@ -55,10 +55,10 @@ from .replay_record import DummyModule, ExecutionRecorder
from .resume_execution import ContinueExecutionCache, ReenterWith
from .source import (
AttrSource,
AutoDerefLocalSource,
GetItemSource,
GlobalSource,
GlobalWeakRefSource,
LocalCellSource,
LocalSource,
Source,
)
@ -1133,8 +1133,8 @@ class InstructionTranslatorBase(
self.output.side_effects.store_cell(cell, val)
assert isinstance(cell, CellVariable) # tame mypy
if cell.is_root_frame_cell():
val.set_name_hint(cell.source.local_name) # type: ignore[attr-defined]
if cell.local_name is not None:
val.set_name_hint(cell.local_name) # type: ignore[attr-defined]
LOAD_CLOSURE = LOAD_FAST
@ -2779,30 +2779,26 @@ class InstructionTranslator(InstructionTranslatorBase):
# after these cell objects are created. Thus they cannot be
# captured by any pre-existig function.
dummy_cell = types.CellType(value)
cell_source = LocalSource(
name, is_input=True, is_root_frame_cell=True
cell_source = LocalCellSource(name)
contents_source = LocalSource(
name, is_input=True, is_derefed_cell_contents=True
)
contents_source = AutoDerefLocalSource(cell_source)
contents_var: VariableTracker = LazyVariableTracker.create(
value, contents_source
)
cell_var = side_effects.track_cell_existing(
cell_source, dummy_cell, contents_var
)
self.symbolic_locals[name] = cell_var
else:
cell_var = side_effects.track_cell_new()
self.symbolic_locals[name] = cell_var
# We conveniently piggyback on `LocalSource.reconstruct` so
# we don't have to plumb extra stuff down to simplify
# codegen of `cell_var` and its side effects.
cell_var.source = LocalSource(name, is_root_frame_cell=True)
cell_var.local_name = name
self.symbolic_locals[name] = cell_var
# Populate `symbolic_locals` with cells captured by this frame,
# effectively implementing the `COPY_FREE_VARS` instruction.
for name, cell in zip(self.freevars(), closure):
cell_source = LocalSource(name, is_root_frame_cell=True)
contents_source = AutoDerefLocalSource(cell_source)
cell_source = LocalCellSource(name)
contents_source = LocalSource(name, is_derefed_cell_contents=True)
try:
contents_var = LazyVariableTracker.create(
cell.cell_contents, contents_source
@ -2813,6 +2809,7 @@ class InstructionTranslator(InstructionTranslatorBase):
cell_var = side_effects.track_cell_existing(
cell_source, cell, contents_var
)
cell_var.local_name = name
self.symbolic_locals[name] = cell_var
self.symbolic_torch_function_state = SymbolicTorchFunctionState(

View file

@ -25,7 +25,6 @@ from ..source import (
AttrSource,
DefaultsSource,
GetItemSource,
LocalSource,
ODictGetItemSource,
TypeSource,
WeakRefCallSource,
@ -339,24 +338,16 @@ class CellVariable(VariableTracker):
# `CellVariable` as a special case for `UserDefinedObjectVariable`.
pre_existing_contents: Optional[VariableTracker]
# This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the
# root frame via this name (e.g., the name is in `co_cellvars/co_freevars`).
local_name: Optional[str] = None
def __init__(
self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs
) -> None:
super().__init__(**kwargs)
self.pre_existing_contents = pre_existing_contents
def is_root_frame_cell(self):
"""
Return true if this variable models a cell that is native to the root
frame, i.e., a part of its `co_cellvars` or `co_freevars`.
"""
source = self.source
return (
source is not None
and isinstance(source, LocalSource)
and source.is_root_frame_cell
)
class NewGlobalVariable(VariableTracker):
def __init__(self, **kwargs) -> None: