[compiled autograd] Compiled autograd configs in TLS (#137821)

Multithreaded doesn't work yet, this adds python side TLS only for the python side state

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
This commit is contained in:
Simon Fan 2024-10-21 17:32:41 -07:00 committed by PyTorch MergeBot
parent 75259145ec
commit 49fa437097
16 changed files with 221 additions and 103 deletions

View file

@ -256,7 +256,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950
)
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
_check_count(fwd_copy_count, fwd_resize_count) # fwd graph
else:
_check_count(bwd_copy_count, bwd_resize_count) # bwd graph

View file

@ -86,7 +86,9 @@ def count_ops(
def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph
if not torch._dynamo.compiled_autograd.local.get(
"in_compiled_autograd_region"
): # fwd graph
return_node = list(graph.nodes)[-1]
assert return_node.target == "output"
for x in return_node.args[0]:

View file

@ -6,9 +6,11 @@ import io
import itertools
import logging
import os
import queue
import re
import subprocess
import sys
import threading
import unittest
from importlib.machinery import SourceFileLoader
from pathlib import Path
@ -2405,6 +2407,39 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
not in logs.getvalue()
)
def test_multithreading_tls(self):
def train(errors, model, x):
try:
out = model(x)
with compiled_autograd.enable(compiler_fn):
self.assertEqual(compiled_autograd.local.enabled(), True)
self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
except Exception as e:
print(f"Found error: {e}")
errors.put(1)
raise
model = torch.nn.Sequential(
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
torch.nn.Linear(4, 4),
torch.nn.ReLU(),
)
x = torch.randn([2, 4])
threads = []
errors = queue.Queue()
with compiled_autograd.enable(compiler_fn):
for i in range(4):
thread = threading.Thread(target=train, args=(errors, model, x))
threads.append(thread)
thread.start()
for thread in threads:
thread.join()
assert errors.empty()
def test_verbose_logs_graph(self):
def fn():
model = torch.nn.Sequential(

View file

@ -1,10 +1,6 @@
from typing import Callable
from torch._dynamo.compiled_autograd import AutogradCompilerInstance
def set_autograd_compiler(
autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
) -> Callable[[], AutogradCompilerInstance] | None: ...
def notify_autograd_engine() -> None: ...
def clear_cache() -> None: ...
def is_cache_empty() -> bool: ...
def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...

View file

@ -1,7 +1,10 @@
# mypy: allow-untyped-defs
import contextlib
import functools
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import threading
from dataclasses import dataclass
from logging import Logger
from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import torch
from torch._dynamo.external_utils import (
@ -38,14 +41,90 @@ compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
def snapshot_verbose_logging_enabled():
return torch._logging._internal.log_state.is_artifact_enabled(
"compiled_autograd_verbose"
)
@dataclass
class CompiledAutogradTLS:
next_ctx_id: int = 0
in_compiled_autograd_region: bool = False
compiler: Optional["AutogradCompilerInstance"] = None
vlogger: Optional[Logger] = None
def snapshot_cudagraph_enabled():
return torch._inductor.config.triton.cudagraphs
class TLSWrapper:
tls_key = "compiled_autograd_state"
def __init__(self):
self._local = threading.local()
def _get_tls(self) -> CompiledAutogradTLS:
if hasattr(self._local, self.tls_key):
# first look in python
state = getattr(self._local, self.tls_key)
if torch._C._is_key_in_tls(self.tls_key):
# then look in cpp
state = torch._C._get_obj_in_tls(self.tls_key)
else:
# init new thread created outside of autograd
# TODO: what if context manager wrapped outside of thread?
setattr(self._local, self.tls_key, CompiledAutogradTLS())
state = getattr(self._local, self.tls_key)
torch._C._stash_obj_in_tls(self.tls_key, state)
return state
# queries on the object stored in TLS
def get(self, name):
return getattr(self._get_tls(), name)
def set_tls(self, **kwargs) -> Callable[[], None]:
priors: Dict[str, Any] = {}
for k, v in kwargs.items():
state = self._get_tls()
priors[k] = getattr(state, k)
setattr(state, k, v)
torch._C._dynamo.compiled_autograd.notify_autograd_engine()
def revert():
self.set_tls(**priors)
return revert
def enabled(self) -> bool:
return self.get("compiler") is not None
def enter_ctx(self) -> Callable[[], None]:
state = self._get_tls()
state.next_ctx_id += 1
id = state.next_ctx_id
def exit():
assert (
state is self._get_tls()
), "Runtime must begin and end on the same thread"
assert state.next_ctx_id == id, (
"Error nesting compiled autograd context managers: "
"inner context managers must have shorter lifetime than the outer context manager"
)
state.next_ctx_id -= 1
return exit
def enter_compiled_region(self) -> Callable[[], None]:
state = self._get_tls()
prior = state.in_compiled_autograd_region
state.in_compiled_autograd_region = True
assert prior is False, "Nested compiled autograd regions are not supported"
def exit():
assert (
state is self._get_tls()
), "Runtime must begin and end on the same thread"
assert state.in_compiled_autograd_region is True
state.in_compiled_autograd_region = prior
return exit
local = TLSWrapper()
def maybe_clone(x):
@ -307,7 +386,7 @@ class AutogradCompilerInstance:
self.rename_aot_dispatcher_nodes()
self.reorder_accumulate_grad_nodes()
runtime_inputs_to_move: List[int] = []
if snapshot_cudagraph_enabled():
if torch._inductor.config.triton.cudagraphs:
runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
graph = GraphModule(
@ -329,16 +408,15 @@ class AutogradCompilerInstance:
)
def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
global in_compiled_autograd_region
try:
in_compiled_autograd_region = True
exit_compiled_region = local.enter_compiled_region()
for i in runtime_inputs_to_move:
inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
with disable():
return compiled_fn(inputs, sizes, scalars, hooks)
finally:
in_compiled_autograd_region = False
exit_compiled_region()
return runtime_wrapper, self.compiler_fn(graph)
@ -510,15 +588,9 @@ class AutogradCompilerInstance:
set_stack_trace(new_stack_trace)
# state of the autograd engine dispatch, kept in sync by enable/disable context managers
compiled_autograd_enabled = False
# global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
compiled_autograd_enabled_force_eager = False
# global flag to check if we are processing graphs produced from a compiled autograd graph
in_compiled_autograd_region = False
@contextlib.contextmanager
def enable(compiler_fn):
@ -538,39 +610,42 @@ def enable(compiler_fn):
# we need to lazily import it, because of circular dependencies
import torch._inductor.cudagraph_trees
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
functools.partial(AutogradCompilerInstance, compiler_fn)
exit_ctx = local.enter_ctx()
revert_tls = local.set_tls(
compiler=functools.partial(AutogradCompilerInstance, compiler_fn),
vlogger=verbose_log
if torch._logging._internal.log_state.is_artifact_enabled(
"compiled_autograd_verbose"
)
else None,
)
if snapshot_verbose_logging_enabled():
torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
global compiled_autograd_enabled
compiled_autograd_enabled = True
try:
with torch.autograd.set_multithreading_enabled(False):
yield
finally:
if not prior:
compiled_autograd_enabled = False
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
revert_tls()
exit_ctx()
@contextlib.contextmanager
def disable():
prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
global compiled_autograd_enabled
compiled_autograd_enabled = False
exit_ctx = local.enter_ctx()
revert_tls = local.set_tls(
compiler=None,
vlogger=None,
)
try:
yield
finally:
if prior:
compiled_autograd_enabled = True
torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
revert_tls()
exit_ctx()
# return to starting state of a new process
def reset() -> None:
global compiled_autograd_enabled
compiled_autograd_enabled = False
assert not in_compiled_autograd_region
torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
assert local.get("next_ctx_id") == 0
assert local.get("in_compiled_autograd_region") is False
local.set_tls(
compiler=None,
vlogger=None,
)

View file

@ -472,6 +472,11 @@ compiled_autograd = False
# Overrides torch.compile() kwargs for Compiled Autograd:
compiled_autograd_kwargs_override: Dict[str, Any] = {}
# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph,
# and make them opaque to the compiler. This does not work when the C++ backward implementation involves
# other dispatcher subsystems e.g. custom subclasses, autocast, vmap.
compiled_autograd_opaque_cpp_node = False
# Enables use of collectives *during* compilation to synchronize behavior
# across ranks. Today, this is used solely to modify automatic_dynamic_shapes
# behavior, making it so that we infer that if an input is dynamic by

View file

@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
if node.op == "placeholder" and node.meta.get("steal_arg", False)
]
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
# fast path, avoid pytree overhead
# compiled autograd inputs are always a list of tensors, maybe followed by symints
assert inputs_idx_to_clear == [0]

View file

@ -313,7 +313,7 @@ class BackwardHookVariable(VariableTracker):
user_hooks: VariableTracker,
user_pre_hooks: VariableTracker,
):
if not compiled_autograd.compiled_autograd_enabled:
if not compiled_autograd.local.enabled():
unimplemented("module-level backwards hooks require compiled autograd")
def _in_graph_bw_hooks(bw_state: BackwardState):

View file

@ -929,7 +929,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
kwargs: "Dict[str, VariableTracker]",
) -> "VariableTracker":
if name == "queue_callback":
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
assert (
tx.one_graph
), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"

View file

@ -1007,7 +1007,7 @@ class TensorVariable(VariableTracker):
tx = InstructionTranslator.current_tx()
if not self.source:
if not compiled_autograd.compiled_autograd_enabled:
if not compiled_autograd.local.enabled():
# TODO(voz):
# We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
# python state.

View file

@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
Checks that the graph module only uses supported operators
"""
nodes = gm.graph.nodes
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
raise BypassAOTAutogradCache(
"Cannot cache a graph with compiled autograd enabled"
)

View file

@ -704,7 +704,7 @@ from a multi-output view call"
traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
nonlocal static_input_indices
static_input_indices = static_input_indices or []
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
passed_indices = set(static_input_indices)
static_input_indices = [
i

View file

@ -760,7 +760,7 @@ def aot_dispatch_autograd(
# becomes the lazy version again. One example is when dynamic shape is enabled
# upfront, the bw_compiler will be called above which can cause extra
# graph module recompilation on bw_module.
if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
from torch.fx._lazy_graph_module import _LazyGraphModule
_LazyGraphModule.force_recompile(bw_module)

View file

@ -224,11 +224,29 @@ struct LiftedIValueArgs {
const std::optional<size_t>& active_node_call_idx;
};
// Hold GIL while using
struct PyTLSWrapper {
PyTLSWrapper(PyObject* state) : state(state) {}
PyTLSWrapper(const PyTLSWrapper&) = delete;
PyTLSWrapper& operator=(const PyTLSWrapper&) = delete;
PyTLSWrapper(PyTLSWrapper&&) = default;
PyTLSWrapper& operator=(PyTLSWrapper&&) = default;
static PyTLSWrapper create();
PyObject* get(std::string_view key) const;
private:
PyObject* state;
};
struct AutogradCompilerCall {
AutogradCompilerCall()
AutogradCompilerCall() = delete;
AutogradCompilerCall(PyTLSWrapper&& state)
: active_node_call_idx(std::nullopt),
tensor_args(active_node_call_idx),
lifted_ivalue_args(active_node_call_idx) {}
lifted_ivalue_args(active_node_call_idx),
state(std::move(state)) {}
void add_size_input(const c10::SymInt& s) {
all_size_inputs.emplace_back(
default_dyn_type, s.guard_int(__FILE__, __LINE__));
@ -254,8 +272,11 @@ struct AutogradCompilerCall {
std::vector<c10::SafePyObject> hooks;
NodeCalls node_calls;
SizeInput::DynType default_dyn_type = SizeInput::STATIC;
// NodeCall id of each size, only when verbose logging is enabled
std::vector<uint32_t> size_input_origins;
const PyTLSWrapper state;
};
class CompiledNodeArgs {

View file

@ -88,10 +88,6 @@ static void check(bool result) {
if (C10_UNLIKELY(!result))
check(nullptr);
}
// snapshot of python verbose logging toggle
static PyObject* python_verbose_logger = nullptr;
struct PythonLogger {
PythonLogger() = delete;
explicit PythonLogger(PyObject* logger) : logger_(logger) {
@ -135,15 +131,15 @@ struct PythonLogger {
};
struct VerboseLogger : public PythonLogger {
static std::optional<VerboseLogger> maybe_create() {
if (python_verbose_logger == nullptr) {
VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
static std::optional<VerboseLogger> maybe_create(PyObject* vlogger) {
if (vlogger == Py_None) {
return std::nullopt;
}
return VerboseLogger(python_verbose_logger);
return VerboseLogger(vlogger);
}
VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
void log_node_check(
const Node& fn,
size_t size_inputs_num,
@ -368,8 +364,22 @@ struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
}
};
static PyObject* the_autograd_compiler = nullptr;
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);
/* static */ PyTLSWrapper PyTLSWrapper::create() {
TORCH_INTERNAL_ASSERT(
at::impl::ThreadLocalPythonObjects::contains("compiled_autograd_state"));
PyObject* compiled_autograd_state =
check(at::impl::ThreadLocalPythonObjects::get("compiled_autograd_state")
->ptr(getPyInterpreter()));
return PyTLSWrapper(compiled_autograd_state);
}
// Refer to fields in python class CompiledAutogradTLS
// May return Py_None
PyObject* PyTLSWrapper::get(std::string_view key) const {
return check(PyObject_GetAttrString(state, key.data()));
}
static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args);
static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
HANDLE_TH_ERRORS;
@ -387,28 +397,11 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
END_HANDLE_TH_ERRORS;
}
static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
HANDLE_TH_ERRORS;
PyObject* logger = nullptr;
if (!PyArg_ParseTuple(args, "O", &logger)) {
throw_python_error();
}
if (logger == Py_None) {
python_verbose_logger = nullptr;
} else {
python_verbose_logger = logger;
}
Py_RETURN_TRUE;
END_HANDLE_TH_ERRORS;
}
// NOLINTNEXTLINE(*array*)
static PyMethodDef _methods[] = {
{"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr},
{"notify_autograd_engine", notify_autograd_engine, METH_NOARGS, nullptr},
{"clear_cache", clear_cache, METH_NOARGS, nullptr},
{"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
{"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
{nullptr, nullptr, 0, nullptr}};
static struct PyModuleDef _module = {
@ -568,7 +561,7 @@ CacheNode* _compiled_autograd_impl(
THPObjectPtr* graph_arg_hooks) {
std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
std::vector<std::shared_ptr<Node>> worklist{graph_root};
AutogradCompilerCall compiler_call;
AutogradCompilerCall compiler_call(PyTLSWrapper::create());
for (const auto i : c10::irange(output_edges.size())) {
compiler_call.node_calls
@ -583,7 +576,8 @@ CacheNode* _compiled_autograd_impl(
check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);
int i = 0;
std::optional<VerboseLogger> vlogger = VerboseLogger::maybe_create();
std::optional<VerboseLogger> vlogger =
VerboseLogger::maybe_create(compiler_call.state.get("vlogger"));
while (!worklist.empty()) {
std::shared_ptr<Node> fn = std::move(worklist.back());
worklist.pop_back();
@ -642,6 +636,8 @@ CacheNode* _compiled_autograd_impl(
// TODO(jansel): some dynamic sizes seem to be ints not symints
if (!cache->check_dynamic_sizes(compiler_call, vlogger)) {
// cache miss, need to capture FX graph
PyObject* the_autograd_compiler = compiler_call.state.get("compiler");
TORCH_INTERNAL_ASSERT(the_autograd_compiler != Py_None);
ClosingTHPObjectPtr py_compiler(
check(PyObject_CallNoArgs((the_autograd_compiler))));
@ -839,28 +835,16 @@ variable_list compiled_autograd(
return outputs;
}
static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args) {
HANDLE_TH_ERRORS;
PyObject* obj = nullptr;
if (!PyArg_ParseTuple(args, "O", &obj)) {
return nullptr;
}
PyObject* prior = the_autograd_compiler;
if (obj == Py_None) { // disable
the_autograd_compiler = nullptr; // decref not needed due to `prior`
PyTLSWrapper state = PyTLSWrapper::create();
PyObject* compiler = state.get("compiler");
if (compiler == Py_None) { // disable
Engine::set_compiled_autograd(nullptr);
} else { // enable
Py_INCREF(obj);
the_autograd_compiler = obj;
Engine::set_compiled_autograd(&compiled_autograd);
}
if (prior == nullptr) {
Py_RETURN_NONE;
} else {
return prior;
}
Py_RETURN_NONE;
END_HANDLE_TH_ERRORS;
}

View file

@ -33,9 +33,9 @@ else:
import torch._dynamo.compiled_autograd as ca
_compiled_autograd_enabled = (
ca.compiled_autograd_enabled
ca.local.enabled()
or ca.compiled_autograd_enabled_force_eager
or ca.in_compiled_autograd_region
or ca.local.get("in_compiled_autograd_region")
)
def compiled_autograd_enabled():