mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[dynamo] Remove AutoDerefLocalSource and simplify cell handling (#141629)
This patch (1) removes `AutoDerefLocalSource` in favor of `LocalSource`, thereby removing its special handling in guards, and (2) introduces a `LocalCellSource` for cells from the root frame, with only `reconstruct` implemented, to programmatically enforce that these cells should never be used by other components like guards. Pull Request resolved: https://github.com/pytorch/pytorch/pull/141629 Approved by: https://github.com/jansel ghstack dependencies: #141628
This commit is contained in:
parent
e14d8c980f
commit
2d708752f0
7 changed files with 41 additions and 87 deletions
|
|
@ -453,7 +453,7 @@ class PyCodegen:
|
|||
# Emitting `LOAD_FAST/LOAD_CLOSURE` with names in `co_freevars`
|
||||
# requires that in the generated bytecode, these cells would keep
|
||||
# their original local names, which we ensure via
|
||||
# `LocalSource.is_root_frame_cell`.
|
||||
# `CellVariable.local_name`.
|
||||
for var in freevars:
|
||||
assert var in self.cell_and_freevars()
|
||||
output.append(self.create_load_closure(var))
|
||||
|
|
|
|||
|
|
@ -44,7 +44,6 @@ from torch._C._dynamo.guards import (
|
|||
check_type_id,
|
||||
dict_version,
|
||||
DictGuardManager,
|
||||
GuardManager,
|
||||
install_no_tensor_aliasing_guard,
|
||||
install_object_aliasing_guard,
|
||||
profile_guard_manager,
|
||||
|
|
@ -83,7 +82,6 @@ from .eval_frame import set_guard_error_hook
|
|||
from .source import (
|
||||
AttrProxySource,
|
||||
AttrSource,
|
||||
AutoDerefLocalSource,
|
||||
CallFunctionNoArgsSource,
|
||||
ChainedSource,
|
||||
ConstDictKeySource,
|
||||
|
|
@ -968,14 +966,6 @@ class GuardBuilder(GuardBuilderBase):
|
|||
example_value=example_value,
|
||||
guard_manager_enum=guard_manager_enum,
|
||||
)
|
||||
elif istype(source, AutoDerefLocalSource):
|
||||
# Guard checks run on f_locals, in which the python level
|
||||
# auto-dereferenced cell objects are also dereferenced (e.g., rather
|
||||
# than `f_locals` being `{ 'cell' : <cell object of int> }`, it'll
|
||||
# be `{ 'cell' : <int> }`. So the guard manager is the same as the
|
||||
# base guard manager.
|
||||
assert isinstance(base_guard_manager, GuardManager) # tame mypy
|
||||
out = base_guard_manager
|
||||
elif istype(source, GlobalSource):
|
||||
# Global manager accepts a dict but it is not a DictGuardManager
|
||||
# because globals dict is big and we typically guard on a very
|
||||
|
|
|
|||
|
|
@ -117,7 +117,7 @@ from .variables.builder import (
|
|||
wrap_fx_proxy,
|
||||
)
|
||||
from .variables.lists import BaseListVariable
|
||||
from .variables.misc import NullVariable
|
||||
from .variables.misc import CellVariable, NullVariable
|
||||
from .variables.nn_module import NNModuleVariable
|
||||
from .variables.tensor import (
|
||||
NumpyNdarrayVariable,
|
||||
|
|
@ -1041,6 +1041,8 @@ class OutputGraph:
|
|||
# while running test_subgraphs.py
|
||||
if isinstance(v.source, LocalSource) and v.source.local_name == k:
|
||||
continue # no need to restore initial state
|
||||
if isinstance(v, CellVariable) and v.local_name == k:
|
||||
continue # no need to restore initial state
|
||||
# Do not load variable if it is NULL.
|
||||
if sys.version_info >= (3, 12):
|
||||
# Continuation function will load the NULL for v.
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from .bytecode_transformation import (
|
|||
)
|
||||
from .codegen import PyCodegen
|
||||
from .exc import unimplemented
|
||||
from .source import GlobalSource, LocalSource, Source
|
||||
from .source import GlobalSource, LocalCellSource, LocalSource, Source
|
||||
from .utils import is_frozen_dataclass, nn_module_new, object_new
|
||||
from .variables.base import (
|
||||
AttributeMutation,
|
||||
|
|
@ -431,13 +431,15 @@ class SideEffects:
|
|||
# `MAKE_CELL` or by them being in `co_cellvars`, so we only emit
|
||||
# `make_cell` for the non-root-frame cells here.
|
||||
# TODO generalize this so we never need to call `make_cell`.
|
||||
if not var.is_root_frame_cell():
|
||||
if var.local_name is None:
|
||||
cg.add_push_null(
|
||||
lambda: cg.load_import_from(utils.__name__, "make_cell")
|
||||
)
|
||||
cg.extend_output(create_call_function(0, False))
|
||||
cg.add_cache(var)
|
||||
var.source = LocalSource(cg.tempvars[var]) # type: ignore[attr-defined]
|
||||
elif var.source is None:
|
||||
var.source = LocalCellSource(var.local_name)
|
||||
elif isinstance(var.mutation_type, AttributeMutationNew):
|
||||
if isinstance(var, variables.AutogradFunctionContextVariable):
|
||||
unimplemented("AutogradFunctionContextVariable escaped")
|
||||
|
|
@ -652,14 +654,13 @@ class SideEffects:
|
|||
cg.call_function(1, False)
|
||||
cg.append_output(create_instruction("POP_TOP"))
|
||||
|
||||
elif isinstance(var, variables.CellVariable) and var.is_root_frame_cell():
|
||||
elif isinstance(var, variables.CellVariable) and var.local_name is not None:
|
||||
# Emit more readable and performant bytecode.
|
||||
# TODO generalize this for cells created during inlining.
|
||||
if var in self.store_attr_mutations:
|
||||
contents_var = self.load_cell(var)
|
||||
cg(contents_var)
|
||||
cell_name = var.source.local_name # type: ignore[attr-defined]
|
||||
suffixes.append([cg.create_store_deref(cell_name)])
|
||||
suffixes.append([cg.create_store_deref(var.local_name)])
|
||||
|
||||
elif self.is_attribute_mutation(var):
|
||||
# Applying mutations involves two steps: 1) Push all
|
||||
|
|
|
|||
|
|
@ -107,15 +107,14 @@ class LocalSource(Source):
|
|||
# Whether this local is an input to the root frame.
|
||||
is_input: bool = False
|
||||
|
||||
# Whether the item at this source is a cell that is native to the root frame,
|
||||
# i.e., a part of its `co_cellvars` or `co_freevars`.
|
||||
is_root_frame_cell: bool = False
|
||||
# Whether the item at this source is the _content_ of a cell that is
|
||||
# dereferenced from the root frame, i.e., it's a part of the `co_cellvars`
|
||||
# or `co_freevars`.
|
||||
is_derefed_cell_contents: bool = False
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
if self.is_root_frame_cell:
|
||||
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
|
||||
# Dynamo's bytecode transformation differentiates them slightly.
|
||||
codegen.append_output(codegen.create_load_closure(self.local_name))
|
||||
if self.is_derefed_cell_contents:
|
||||
codegen.load_deref(self.local_name)
|
||||
else:
|
||||
codegen.append_output(codegen.create_load(self.local_name))
|
||||
|
||||
|
|
@ -234,48 +233,22 @@ class AttrSource(ChainedSource):
|
|||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class AutoDerefLocalSource(ChainedSource):
|
||||
class LocalCellSource(Source):
|
||||
"""
|
||||
In Python, reads and writes to local cell objects (variables captured by a
|
||||
frame, or created in a frame but captured by a nested frame) are
|
||||
automatically dereferenced.
|
||||
|
||||
At the language level, this means accessing the `cell_contents` attribute of
|
||||
the cell object, rather than the object itself.
|
||||
|
||||
At the bytecode level, this means turning LOAD_FAST into LOAD_DEREF, and
|
||||
STORE_FAST into STORE_DEREF.
|
||||
|
||||
This class represents the source to the _contents_ of such a cell object,
|
||||
encapsulating the python and bytecode level idiosyncrasies of what would
|
||||
otherwise have been a simple `AttrSource(cell_source, "cell_contents")`
|
||||
Conceptually, this class is `LocalSource` for cell objects implicitly
|
||||
generated by Python (e.g., captured variables).
|
||||
"""
|
||||
|
||||
def __post_init__(self):
|
||||
assert type(self.base) is LocalSource
|
||||
assert self.base.is_root_frame_cell
|
||||
local_name: str
|
||||
|
||||
def reconstruct(self, codegen):
|
||||
# Emit more readable and performant bytecode.
|
||||
assert isinstance(self.base, LocalSource) # tame mypy
|
||||
codegen.load_deref(self.base.local_name)
|
||||
# Although `LOAD_FAST` and `LOAD_CLOSURE` have the same semantics,
|
||||
# Dynamo's bytecode transformation differentiates them slightly, so we
|
||||
# always emit `LOAD_CLOSURE` here.
|
||||
codegen.append_output(codegen.create_load_closure(self.local_name))
|
||||
|
||||
def guard_source(self):
|
||||
return self.base.guard_source()
|
||||
|
||||
def name(self):
|
||||
# The requirements for `Source.name` are
|
||||
# 1. with appropriate scope, `eval()` will turn it into the target
|
||||
# python value.
|
||||
# 2. can be used for caching guard managers.
|
||||
#
|
||||
# (1) requires us to return `self.base.name()` here, in the scope given
|
||||
# to `eval()`, cells are already dereferenced.
|
||||
#
|
||||
# What about name collision that can affect (2)? Well, auto-dereferenced
|
||||
# cells should never have any guards on them (only guards on the
|
||||
# contents), so this name collision shouldn't matter.
|
||||
return self.base.name()
|
||||
# All the other methods are intentionally unimplemented because e.g., a
|
||||
# local cell object should never be used for guards.
|
||||
|
||||
|
||||
# Represents tensor.grad source. It could be represented by AttrSource as well.
|
||||
|
|
|
|||
|
|
@ -55,10 +55,10 @@ from .replay_record import DummyModule, ExecutionRecorder
|
|||
from .resume_execution import ContinueExecutionCache, ReenterWith
|
||||
from .source import (
|
||||
AttrSource,
|
||||
AutoDerefLocalSource,
|
||||
GetItemSource,
|
||||
GlobalSource,
|
||||
GlobalWeakRefSource,
|
||||
LocalCellSource,
|
||||
LocalSource,
|
||||
Source,
|
||||
)
|
||||
|
|
@ -1133,8 +1133,8 @@ class InstructionTranslatorBase(
|
|||
self.output.side_effects.store_cell(cell, val)
|
||||
|
||||
assert isinstance(cell, CellVariable) # tame mypy
|
||||
if cell.is_root_frame_cell():
|
||||
val.set_name_hint(cell.source.local_name) # type: ignore[attr-defined]
|
||||
if cell.local_name is not None:
|
||||
val.set_name_hint(cell.local_name) # type: ignore[attr-defined]
|
||||
|
||||
LOAD_CLOSURE = LOAD_FAST
|
||||
|
||||
|
|
@ -2779,30 +2779,26 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
# after these cell objects are created. Thus they cannot be
|
||||
# captured by any pre-existing function.
|
||||
dummy_cell = types.CellType(value)
|
||||
cell_source = LocalSource(
|
||||
name, is_input=True, is_root_frame_cell=True
|
||||
cell_source = LocalCellSource(name)
|
||||
contents_source = LocalSource(
|
||||
name, is_input=True, is_derefed_cell_contents=True
|
||||
)
|
||||
contents_source = AutoDerefLocalSource(cell_source)
|
||||
contents_var: VariableTracker = LazyVariableTracker.create(
|
||||
value, contents_source
|
||||
)
|
||||
cell_var = side_effects.track_cell_existing(
|
||||
cell_source, dummy_cell, contents_var
|
||||
)
|
||||
self.symbolic_locals[name] = cell_var
|
||||
else:
|
||||
cell_var = side_effects.track_cell_new()
|
||||
self.symbolic_locals[name] = cell_var
|
||||
# We conveniently piggyback on `LocalSource.reconstruct` so
|
||||
# we don't have to plumb extra stuff down to simplify
|
||||
# codegen of `cell_var` and its side effects.
|
||||
cell_var.source = LocalSource(name, is_root_frame_cell=True)
|
||||
cell_var.local_name = name
|
||||
self.symbolic_locals[name] = cell_var
|
||||
|
||||
# Populate `symbolic_locals` with cells captured by this frame,
|
||||
# effectively implementing the `COPY_FREE_VARS` instruction.
|
||||
for name, cell in zip(self.freevars(), closure):
|
||||
cell_source = LocalSource(name, is_root_frame_cell=True)
|
||||
contents_source = AutoDerefLocalSource(cell_source)
|
||||
cell_source = LocalCellSource(name)
|
||||
contents_source = LocalSource(name, is_derefed_cell_contents=True)
|
||||
try:
|
||||
contents_var = LazyVariableTracker.create(
|
||||
cell.cell_contents, contents_source
|
||||
|
|
@ -2813,6 +2809,7 @@ class InstructionTranslator(InstructionTranslatorBase):
|
|||
cell_var = side_effects.track_cell_existing(
|
||||
cell_source, cell, contents_var
|
||||
)
|
||||
cell_var.local_name = name
|
||||
self.symbolic_locals[name] = cell_var
|
||||
|
||||
self.symbolic_torch_function_state = SymbolicTorchFunctionState(
|
||||
|
|
|
|||
|
|
@ -25,7 +25,6 @@ from ..source import (
|
|||
AttrSource,
|
||||
DefaultsSource,
|
||||
GetItemSource,
|
||||
LocalSource,
|
||||
ODictGetItemSource,
|
||||
TypeSource,
|
||||
WeakRefCallSource,
|
||||
|
|
@ -339,24 +338,16 @@ class CellVariable(VariableTracker):
|
|||
# `CellVariable` as a special case for `UserDefinedObjectVariable`.
|
||||
pre_existing_contents: Optional[VariableTracker]
|
||||
|
||||
# This is set when this cell can be referenced via `LOAD/STORE_DEREF` in the
|
||||
# root frame via this name (e.g., the name is in `co_cellvars/co_freevars`).
|
||||
local_name: Optional[str] = None
|
||||
|
||||
def __init__(
|
||||
self, pre_existing_contents: Optional[VariableTracker] = None, **kwargs
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.pre_existing_contents = pre_existing_contents
|
||||
|
||||
def is_root_frame_cell(self):
|
||||
"""
|
||||
Return true if this variable models a cell that is native to the root
|
||||
frame, i.e., a part of its `co_cellvars` or `co_freevars`.
|
||||
"""
|
||||
source = self.source
|
||||
return (
|
||||
source is not None
|
||||
and isinstance(source, LocalSource)
|
||||
and source.is_root_frame_cell
|
||||
)
|
||||
|
||||
|
||||
class NewGlobalVariable(VariableTracker):
|
||||
def __init__(self, **kwargs) -> None:
|
||||
|
|
|
|||
Loading…
Reference in a new issue