[compiled autograd] Compiled autograd configs in TLS (#137821)
Multithreading doesn't work yet; this adds Python-side TLS only for the Python-side state.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/137821
Approved by: https://github.com/jansel, https://github.com/yf225
ghstack dependencies: #137953
This commit is contained in:
parent 75259145ec
commit 49fa437097

16 changed files with 221 additions and 103 deletions
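
The change replaces process-global flags (compiled_autograd_enabled, in_compiled_autograd_region, and the C++ the_autograd_compiler pointer) with a per-thread CompiledAutogradTLS object: Python keeps it in a threading.local() and stashes the same object into the C++ thread-local map so the engine can read it. Below is a minimal standalone sketch of that lookup order in plain Python; the cpp_* helpers are hypothetical stand-ins for torch._C._stash_obj_in_tls / _get_obj_in_tls / _is_key_in_tls, not the real bindings.

import threading
from dataclasses import dataclass
from typing import Any, Optional

# hypothetical stand-in for the C++ side's per-thread object map
_cpp_tls = threading.local()

def cpp_contains(key: str) -> bool:
    return hasattr(_cpp_tls, key)

def cpp_get(key: str) -> Any:
    return getattr(_cpp_tls, key)

def cpp_stash(key: str, obj: Any) -> None:
    setattr(_cpp_tls, key, obj)

@dataclass
class State:
    next_ctx_id: int = 0
    in_compiled_autograd_region: bool = False
    compiler: Optional[Any] = None

KEY = "compiled_autograd_state"

def get_tls() -> State:
    # prefer the object already stashed for this thread; otherwise create a
    # fresh State and stash it so the "C++" side sees the same object
    if cpp_contains(KEY):
        return cpp_get(KEY)
    state = State()
    cpp_stash(KEY, state)
    return state

# each thread sees its own State instance
def worker() -> None:
    print(threading.current_thread().name, id(get_tls()))

threads = [threading.Thread(target=worker) for _ in range(3)]
for t in threads:
    t.start()
for t in threads:
    t.join()
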
@@ -256,7 +256,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]},
                 f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}",  # noqa: B950
             )
 
-        if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+        if not torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
             _check_count(fwd_copy_count, fwd_resize_count)  # fwd graph
         else:
             _check_count(bwd_copy_count, bwd_resize_count)  # bwd graph

@@ -86,7 +86,9 @@ def count_ops(
 
 
 def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]):
-    if not torch._dynamo.compiled_autograd.in_compiled_autograd_region:  # fwd graph
+    if not torch._dynamo.compiled_autograd.local.get(
+        "in_compiled_autograd_region"
+    ):  # fwd graph
         return_node = list(graph.nodes)[-1]
         assert return_node.target == "output"
         for x in return_node.args[0]:

@@ -6,9 +6,11 @@ import io
 import itertools
 import logging
 import os
+import queue
 import re
 import subprocess
 import sys
+import threading
 import unittest
 from importlib.machinery import SourceFileLoader
 from pathlib import Path

@@ -2405,6 +2407,39 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) {
             not in logs.getvalue()
         )
 
+    def test_multithreading_tls(self):
+        def train(errors, model, x):
+            try:
+                out = model(x)
+                with compiled_autograd.enable(compiler_fn):
+                    self.assertEqual(compiled_autograd.local.enabled(), True)
+                    self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1)
+            except Exception as e:
+                print(f"Found error: {e}")
+                errors.put(1)
+                raise
+
+        model = torch.nn.Sequential(
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+        )
+        x = torch.randn([2, 4])
+
+        threads = []
+        errors = queue.Queue()
+        with compiled_autograd.enable(compiler_fn):
+            for i in range(4):
+                thread = threading.Thread(target=train, args=(errors, model, x))
+                threads.append(thread)
+                thread.start()
+
+            for thread in threads:
+                thread.join()
+
+        assert errors.empty()
+
     def test_verbose_logs_graph(self):
         def fn():
             model = torch.nn.Sequential(

@@ -1,10 +1,6 @@
 from typing import Callable
 
-from torch._dynamo.compiled_autograd import AutogradCompilerInstance
-
-def set_autograd_compiler(
-    autograd_compiler: Callable[[], AutogradCompilerInstance] | None,
-) -> Callable[[], AutogradCompilerInstance] | None: ...
+def notify_autograd_engine() -> None: ...
 def clear_cache() -> None: ...
 def is_cache_empty() -> bool: ...
 def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ...

@@ -1,7 +1,10 @@
 # mypy: allow-untyped-defs
 import contextlib
 import functools
-from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
+import threading
+from dataclasses import dataclass
+from logging import Logger
+from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
 
 import torch
 from torch._dynamo.external_utils import (

@@ -38,14 +41,90 @@ compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd")
 verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose")
 
 
-def snapshot_verbose_logging_enabled():
-    return torch._logging._internal.log_state.is_artifact_enabled(
-        "compiled_autograd_verbose"
-    )
+@dataclass
+class CompiledAutogradTLS:
+    next_ctx_id: int = 0
+    in_compiled_autograd_region: bool = False
+    compiler: Optional["AutogradCompilerInstance"] = None
+    vlogger: Optional[Logger] = None
 
 
-def snapshot_cudagraph_enabled():
-    return torch._inductor.config.triton.cudagraphs
+class TLSWrapper:
+    tls_key = "compiled_autograd_state"
+
+    def __init__(self):
+        self._local = threading.local()
+
+    def _get_tls(self) -> CompiledAutogradTLS:
+        if hasattr(self._local, self.tls_key):
+            # first look in python
+            state = getattr(self._local, self.tls_key)
+        if torch._C._is_key_in_tls(self.tls_key):
+            # then look in cpp
+            state = torch._C._get_obj_in_tls(self.tls_key)
+        else:
+            # init new thread created outside of autograd
+            # TODO: what if context manager wrapped outside of thread?
+            setattr(self._local, self.tls_key, CompiledAutogradTLS())
+            state = getattr(self._local, self.tls_key)
+            torch._C._stash_obj_in_tls(self.tls_key, state)
+        return state
+
+    # queries on the object stored in TLS
+    def get(self, name):
+        return getattr(self._get_tls(), name)
+
+    def set_tls(self, **kwargs) -> Callable[[], None]:
+        priors: Dict[str, Any] = {}
+        for k, v in kwargs.items():
+            state = self._get_tls()
+            priors[k] = getattr(state, k)
+            setattr(state, k, v)
+
+        torch._C._dynamo.compiled_autograd.notify_autograd_engine()
+
+        def revert():
+            self.set_tls(**priors)
+
+        return revert
+
+    def enabled(self) -> bool:
+        return self.get("compiler") is not None
+
+    def enter_ctx(self) -> Callable[[], None]:
+        state = self._get_tls()
+        state.next_ctx_id += 1
+        id = state.next_ctx_id
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.next_ctx_id == id, (
+                "Error nesting compiled autograd context managers: "
+                "inner context managers must have shorter lifetime than the outer context manager"
+            )
+            state.next_ctx_id -= 1
+
+        return exit
+
+    def enter_compiled_region(self) -> Callable[[], None]:
+        state = self._get_tls()
+        prior = state.in_compiled_autograd_region
+        state.in_compiled_autograd_region = True
+        assert prior is False, "Nested compiled autograd regions are not supported"
+
+        def exit():
+            assert (
+                state is self._get_tls()
+            ), "Runtime must begin and end on the same thread"
+            assert state.in_compiled_autograd_region is True
+            state.in_compiled_autograd_region = prior
+
+        return exit
+
+
+local = TLSWrapper()
 
 
 def maybe_clone(x):
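
Note how set_tls above returns a revert callable that captures the prior field values, and enter_ctx / enter_compiled_region return exit callables; enable() and disable() further down pair each of these with try/finally. A minimal standalone sketch of that save-and-revert closure pattern follows (Cfg and set_cfg are illustrative names, not PyTorch API):

from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional

@dataclass
class Cfg:
    compiler: Optional[Any] = None
    vlogger: Optional[Any] = None

cfg = Cfg()

def set_cfg(**kwargs: Any) -> Callable[[], None]:
    # record priors, apply new values, hand back an undo closure
    priors: Dict[str, Any] = {k: getattr(cfg, k) for k in kwargs}
    for k, v in kwargs.items():
        setattr(cfg, k, v)

    def revert() -> None:
        set_cfg(**priors)  # restoring goes through the same code path

    return revert

revert = set_cfg(compiler=print)
assert cfg.compiler is print
revert()
assert cfg.compiler is None and cfg.vlogger is None
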
@@ -307,7 +386,7 @@ class AutogradCompilerInstance:
         self.rename_aot_dispatcher_nodes()
         self.reorder_accumulate_grad_nodes()
         runtime_inputs_to_move: List[int] = []
-        if snapshot_cudagraph_enabled():
+        if torch._inductor.config.triton.cudagraphs:
             runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph)
 
         graph = GraphModule(

@@ -329,16 +408,15 @@ class AutogradCompilerInstance:
         )
 
         def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks):
-            global in_compiled_autograd_region
             try:
-                in_compiled_autograd_region = True
+                exit_compiled_region = local.enter_compiled_region()
                 for i in runtime_inputs_to_move:
                     inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True)
 
                 with disable():
                     return compiled_fn(inputs, sizes, scalars, hooks)
             finally:
-                in_compiled_autograd_region = False
+                exit_compiled_region()
 
         return runtime_wrapper, self.compiler_fn(graph)

@@ -510,15 +588,9 @@ class AutogradCompilerInstance:
         set_stack_trace(new_stack_trace)
 
 
-# state of the autograd engine dispatch, kept in sync by enable/disable context managers
-compiled_autograd_enabled = False
-
 # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager"
 compiled_autograd_enabled_force_eager = False
 
-# global flag to check if we are processing graphs produced from a compiled autograd graph
-in_compiled_autograd_region = False
-
 
 @contextlib.contextmanager
 def enable(compiler_fn):

@@ -538,39 +610,42 @@ def enable(compiler_fn):
     # we need to lazily import it, because of circular dependencies
     import torch._inductor.cudagraph_trees
 
-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(
-        functools.partial(AutogradCompilerInstance, compiler_fn)
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=functools.partial(AutogradCompilerInstance, compiler_fn),
+        vlogger=verbose_log
+        if torch._logging._internal.log_state.is_artifact_enabled(
+            "compiled_autograd_verbose"
+        )
+        else None,
     )
-    if snapshot_verbose_logging_enabled():
-        torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = True
     try:
         with torch.autograd.set_multithreading_enabled(False):
            yield
     finally:
-        if not prior:
-            compiled_autograd_enabled = False
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()
 
 
 @contextlib.contextmanager
 def disable():
-    prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
+    exit_ctx = local.enter_ctx()
+    revert_tls = local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )
     try:
         yield
     finally:
-        if prior:
-            compiled_autograd_enabled = True
-        torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior)
+        revert_tls()
+        exit_ctx()
 
 
 # return to starting state of a new process
 def reset() -> None:
-    global compiled_autograd_enabled
-    compiled_autograd_enabled = False
-    assert not in_compiled_autograd_region
-    torch._C._dynamo.compiled_autograd.set_autograd_compiler(None)
-    torch._C._dynamo.compiled_autograd.set_verbose_logger(None)
+    assert local.get("next_ctx_id") == 0
+    assert local.get("in_compiled_autograd_region") is False
+    local.set_tls(
+        compiler=None,
+        vlogger=None,
+    )

@@ -472,6 +472,11 @@ compiled_autograd = False
 # Overrides torch.compile() kwargs for Compiled Autograd:
 compiled_autograd_kwargs_override: Dict[str, Any] = {}
 
+# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph,
+# and make them opaque to the compiler. This does not work when the C++ backward implementation involves
+# other dispatcher subsystems e.g. custom subclasses, autocast, vmap.
+compiled_autograd_opaque_cpp_node = False
+
 # Enables use of collectives *during* compilation to synchronize behavior
 # across ranks. Today, this is used solely to modify automatic_dynamic_shapes
 # behavior, making it so that we infer that if an input is dynamic by

@@ -3051,7 +3051,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm):
         if node.op == "placeholder" and node.meta.get("steal_arg", False)
     ]
 
-    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
         # fast path, avoid pytree overhead
         # compiled autograd inputs are always a list of tensors, maybe followed by symints
         assert inputs_idx_to_clear == [0]

@@ -313,7 +313,7 @@ class BackwardHookVariable(VariableTracker):
         user_hooks: VariableTracker,
         user_pre_hooks: VariableTracker,
     ):
-        if not compiled_autograd.compiled_autograd_enabled:
+        if not compiled_autograd.local.enabled():
             unimplemented("module-level backwards hooks require compiled autograd")
 
         def _in_graph_bw_hooks(bw_state: BackwardState):

@@ -929,7 +929,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
         kwargs: "Dict[str, VariableTracker]",
     ) -> "VariableTracker":
         if name == "queue_callback":
-            if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+            if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
                 assert (
                     tx.one_graph
                 ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"

@@ -1007,7 +1007,7 @@ class TensorVariable(VariableTracker):
         tx = InstructionTranslator.current_tx()
 
         if not self.source:
-            if not compiled_autograd.compiled_autograd_enabled:
+            if not compiled_autograd.local.enabled():
                 # TODO(voz):
                 # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary
                 # python state.

@@ -176,7 +176,7 @@ def check_cacheable(gm: torch.fx.GraphModule):
     Checks that the graph module only uses supported operators
     """
     nodes = gm.graph.nodes
-    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
        raise BypassAOTAutogradCache(
            "Cannot cache a graph with compiled autograd enabled"
        )

@@ -704,7 +704,7 @@ from a multi-output view call"
         traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats]
         nonlocal static_input_indices
         static_input_indices = static_input_indices or []
-        if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+        if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
             passed_indices = set(static_input_indices)
             static_input_indices = [
                 i

@@ -760,7 +760,7 @@ def aot_dispatch_autograd(
     # becomes the lazy version again. One example is when dynamic shape is enabled
     # upfront, the bw_compiler will be called above which can cause extra
     # graph module recompilation on bw_module.
-    if torch._dynamo.compiled_autograd.in_compiled_autograd_region:
+    if torch._dynamo.compiled_autograd.local.get("in_compiled_autograd_region"):
         from torch.fx._lazy_graph_module import _LazyGraphModule
 
         _LazyGraphModule.force_recompile(bw_module)

@@ -224,11 +224,29 @@ struct LiftedIValueArgs {
   const std::optional<size_t>& active_node_call_idx;
 };
 
+// Hold GIL while using
+struct PyTLSWrapper {
+  PyTLSWrapper(PyObject* state) : state(state) {}
+  PyTLSWrapper(const PyTLSWrapper&) = delete;
+  PyTLSWrapper& operator=(const PyTLSWrapper&) = delete;
+  PyTLSWrapper(PyTLSWrapper&&) = default;
+  PyTLSWrapper& operator=(PyTLSWrapper&&) = default;
+
+  static PyTLSWrapper create();
+
+  PyObject* get(std::string_view key) const;
+
+ private:
+  PyObject* state;
+};
+
 struct AutogradCompilerCall {
-  AutogradCompilerCall()
+  AutogradCompilerCall() = delete;
+  AutogradCompilerCall(PyTLSWrapper&& state)
       : active_node_call_idx(std::nullopt),
         tensor_args(active_node_call_idx),
-        lifted_ivalue_args(active_node_call_idx) {}
+        lifted_ivalue_args(active_node_call_idx),
+        state(std::move(state)) {}
   void add_size_input(const c10::SymInt& s) {
     all_size_inputs.emplace_back(
         default_dyn_type, s.guard_int(__FILE__, __LINE__));

@@ -254,8 +272,11 @@ struct AutogradCompilerCall {
   std::vector<c10::SafePyObject> hooks;
   NodeCalls node_calls;
   SizeInput::DynType default_dyn_type = SizeInput::STATIC;
 
   // NodeCall id of each size, only when verbose logging is enabled
   std::vector<uint32_t> size_input_origins;
+
+  const PyTLSWrapper state;
 };
 
 class CompiledNodeArgs {

@@ -88,10 +88,6 @@ static void check(bool result) {
   if (C10_UNLIKELY(!result))
     check(nullptr);
 }
 
-// snapshot of python verbose logging toggle
-static PyObject* python_verbose_logger = nullptr;
-
 struct PythonLogger {
   PythonLogger() = delete;
   explicit PythonLogger(PyObject* logger) : logger_(logger) {

@@ -135,15 +131,15 @@
 };
 
 struct VerboseLogger : public PythonLogger {
-  static std::optional<VerboseLogger> maybe_create() {
-    if (python_verbose_logger == nullptr) {
+  VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
+
+  static std::optional<VerboseLogger> maybe_create(PyObject* vlogger) {
+    if (vlogger == Py_None) {
       return std::nullopt;
     }
-    return VerboseLogger(python_verbose_logger);
+    return VerboseLogger(vlogger);
   }
 
-  VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {}
-
   void log_node_check(
       const Node& fn,
       size_t size_inputs_num,

@@ -368,8 +364,22 @@ struct InputBuffers : public std::unordered_map<Node*, InputBuffer> {
   }
 };
 
-static PyObject* the_autograd_compiler = nullptr;
-static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args);
+/* static */ PyTLSWrapper PyTLSWrapper::create() {
+  TORCH_INTERNAL_ASSERT(
+      at::impl::ThreadLocalPythonObjects::contains("compiled_autograd_state"));
+  PyObject* compiled_autograd_state =
+      check(at::impl::ThreadLocalPythonObjects::get("compiled_autograd_state")
+                ->ptr(getPyInterpreter()));
+  return PyTLSWrapper(compiled_autograd_state);
+}
+
+// Refer to fields in python class CompiledAutogradTLS
+// May return Py_None
+PyObject* PyTLSWrapper::get(std::string_view key) const {
+  return check(PyObject_GetAttrString(state, key.data()));
+}
+
+static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args);
 
 static PyObject* clear_cache(PyObject* dummy, PyObject* args) {
   HANDLE_TH_ERRORS;

@@ -387,28 +397,11 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) {
   END_HANDLE_TH_ERRORS;
 }
 
-static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) {
-  HANDLE_TH_ERRORS;
-  PyObject* logger = nullptr;
-  if (!PyArg_ParseTuple(args, "O", &logger)) {
-    throw_python_error();
-  }
-
-  if (logger == Py_None) {
-    python_verbose_logger = nullptr;
-  } else {
-    python_verbose_logger = logger;
-  }
-  Py_RETURN_TRUE;
-  END_HANDLE_TH_ERRORS;
-}
-
 // NOLINTNEXTLINE(*array*)
 static PyMethodDef _methods[] = {
-    {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr},
+    {"notify_autograd_engine", notify_autograd_engine, METH_NOARGS, nullptr},
     {"clear_cache", clear_cache, METH_NOARGS, nullptr},
     {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr},
-    {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr},
     {nullptr, nullptr, 0, nullptr}};
 
 static struct PyModuleDef _module = {

@@ -568,7 +561,7 @@ CacheNode* _compiled_autograd_impl(
     THPObjectPtr* graph_arg_hooks) {
   std::unordered_map<Node*, int>& dependencies = graph_task.dependencies_;
   std::vector<std::shared_ptr<Node>> worklist{graph_root};
-  AutogradCompilerCall compiler_call;
+  AutogradCompilerCall compiler_call(PyTLSWrapper::create());
 
   for (const auto i : c10::irange(output_edges.size())) {
     compiler_call.node_calls

@@ -583,7 +576,8 @@ CacheNode* _compiled_autograd_impl(
       check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1);
 
   int i = 0;
-  std::optional<VerboseLogger> vlogger = VerboseLogger::maybe_create();
+  std::optional<VerboseLogger> vlogger =
+      VerboseLogger::maybe_create(compiler_call.state.get("vlogger"));
   while (!worklist.empty()) {
     std::shared_ptr<Node> fn = std::move(worklist.back());
     worklist.pop_back();

@@ -642,6 +636,8 @@ CacheNode* _compiled_autograd_impl(
       // TODO(jansel): some dynamic sizes seem to be ints not symints
       if (!cache->check_dynamic_sizes(compiler_call, vlogger)) {
         // cache miss, need to capture FX graph
+        PyObject* the_autograd_compiler = compiler_call.state.get("compiler");
+        TORCH_INTERNAL_ASSERT(the_autograd_compiler != Py_None);
         ClosingTHPObjectPtr py_compiler(
             check(PyObject_CallNoArgs((the_autograd_compiler))));
 

@@ -839,28 +835,16 @@ variable_list compiled_autograd(
   return outputs;
 }
 
-static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) {
+static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args) {
   HANDLE_TH_ERRORS;
-  PyObject* obj = nullptr;
-  if (!PyArg_ParseTuple(args, "O", &obj)) {
-    return nullptr;
-  }
-
-  PyObject* prior = the_autograd_compiler;
-  if (obj == Py_None) { // disable
-    the_autograd_compiler = nullptr; // decref not needed due to `prior`
+  PyTLSWrapper state = PyTLSWrapper::create();
+  PyObject* compiler = state.get("compiler");
+  if (compiler == Py_None) { // disable
     Engine::set_compiled_autograd(nullptr);
   } else { // enable
-    Py_INCREF(obj);
-    the_autograd_compiler = obj;
     Engine::set_compiled_autograd(&compiled_autograd);
   }
-
-  if (prior == nullptr) {
-    Py_RETURN_NONE;
-  } else {
-    return prior;
-  }
+  Py_RETURN_NONE;
   END_HANDLE_TH_ERRORS;
 }
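
notify_autograd_engine above inverts the old push model: instead of Python handing the compiler object to C++ via set_autograd_compiler, Python updates the stashed per-thread state and pings the engine, which pulls fields like "compiler" and "vlogger" off that object by attribute name (PyObject_GetAttrString in PyTLSWrapper::get). A rough Python model of that inversion, with illustrative Engine/state names that are not the real autograd engine API:

from types import SimpleNamespace

state = SimpleNamespace(compiler=None)  # stand-in for the stashed TLS object

class Engine:
    callback = None

    @classmethod
    def set_compiled_autograd(cls, cb):
        cls.callback = cb

def compiled_autograd(graph):
    # on a cache miss the engine pulls the compiler out of TLS state
    assert state.compiler is not None
    return state.compiler(graph)

def notify_autograd_engine():
    # pull model: look at the state, then flip the engine on or off
    if state.compiler is None:
        Engine.set_compiled_autograd(None)
    else:
        Engine.set_compiled_autograd(compiled_autograd)

state.compiler = lambda graph: f"compiled {graph}"
notify_autograd_engine()
print(Engine.callback("bwd_graph"))  # -> compiled bwd_graph
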
@@ -33,9 +33,9 @@ else:
     import torch._dynamo.compiled_autograd as ca
 
     _compiled_autograd_enabled = (
-        ca.compiled_autograd_enabled
+        ca.local.enabled()
         or ca.compiled_autograd_enabled_force_eager
-        or ca.in_compiled_autograd_region
+        or ca.local.get("in_compiled_autograd_region")
     )
 
 def compiled_autograd_enabled():