diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 4d02a06af69..b1552909d45 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -256,7 +256,7 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, f"Unexpected number of `inductor.resize_storage_bytes_` ops (expected {resize_count}, got {actual_resize_count}) in graph: {graph}", # noqa: B950 ) - if not torch._dynamo.compiled_autograd.in_compiled_autograd_region(): + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: _check_count(fwd_copy_count, fwd_resize_count) # fwd graph else: _check_count(bwd_copy_count, bwd_resize_count) # bwd graph diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index afdc7bcadf6..e82ad38d1fb 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -86,7 +86,7 @@ def count_ops( def collect_fwd_graph_outputs(graph: torch.fx.Graph, *, fwd_outputs: Set[str]): - if not torch._dynamo.compiled_autograd.in_compiled_autograd_region(): # fwd graph + if not torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fwd graph return_node = list(graph.nodes)[-1] assert return_node.target == "output" for x in return_node.args[0]: diff --git a/test/inductor/test_compiled_autograd.py b/test/inductor/test_compiled_autograd.py index f37b71ca59a..5ec5715c38a 100644 --- a/test/inductor/test_compiled_autograd.py +++ b/test/inductor/test_compiled_autograd.py @@ -6,11 +6,9 @@ import io import itertools import logging import os -import queue import re import subprocess import sys -import threading import unittest from importlib.machinery import SourceFileLoader from pathlib import Path @@ -2415,39 +2413,6 @@ TORCH_LIBRARY(test_cudagraphs_cpu_scalar_used_in_cpp_custom_op, m) { not in logs.getvalue() ) - def test_multithreading_tls(self): - def train(errors, model, x): - try: - out = model(x) - with compiled_autograd.enable(compiler_fn): - self.assertEqual(compiled_autograd.enabled(), True) - self.assertEqual(compiled_autograd.local.get("next_ctx_id"), 1) - except Exception as e: - print(f"Found error: {e}") - errors.put(1) - raise - - model = torch.nn.Sequential( - torch.nn.Linear(4, 4), - torch.nn.ReLU(), - torch.nn.Linear(4, 4), - torch.nn.ReLU(), - ) - x = torch.randn([2, 4]) - - threads = [] - errors = queue.Queue() - with compiled_autograd.enable(compiler_fn): - for i in range(4): - thread = threading.Thread(target=train, args=(errors, model, x)) - threads.append(thread) - thread.start() - - for thread in threads: - thread.join() - - assert errors.empty() - def test_verbose_logs_graph(self): def fn(): model = torch.nn.Sequential( diff --git a/torch/_C/_dynamo/compiled_autograd.pyi b/torch/_C/_dynamo/compiled_autograd.pyi index b308f63844e..80144e3a779 100644 --- a/torch/_C/_dynamo/compiled_autograd.pyi +++ b/torch/_C/_dynamo/compiled_autograd.pyi @@ -1,6 +1,10 @@ from typing import Callable -def notify_autograd_engine() -> None: ... +from torch._dynamo.compiled_autograd import AutogradCompilerInstance + +def set_autograd_compiler( + autograd_compiler: Callable[[], AutogradCompilerInstance] | None, +) -> Callable[[], AutogradCompilerInstance] | None: ... def clear_cache() -> None: ... def is_cache_empty() -> bool: ... def set_verbose_logger(fn: Callable[[str], None] | None) -> bool: ... diff --git a/torch/_dynamo/compiled_autograd.py b/torch/_dynamo/compiled_autograd.py index ada3aa4eee0..950737d4bcb 100644 --- a/torch/_dynamo/compiled_autograd.py +++ b/torch/_dynamo/compiled_autograd.py @@ -1,10 +1,7 @@ # mypy: allow-untyped-defs import contextlib import functools -import threading -from dataclasses import dataclass -from logging import Logger -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union +from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union import torch from torch._dynamo.external_utils import ( @@ -41,95 +38,14 @@ compiled_autograd_log = getArtifactLogger(__name__, "compiled_autograd") verbose_log = getArtifactLogger(__name__, "compiled_autograd_verbose") -@dataclass -class CompiledAutogradTLS: - next_ctx_id: int = 0 - in_compiled_autograd_region: bool = False - compiler: Optional["AutogradCompilerInstance"] = None - vlogger: Optional[Logger] = None +def snapshot_verbose_logging_enabled(): + return torch._logging._internal.log_state.is_artifact_enabled( + "compiled_autograd_verbose" + ) -class TLSWrapper: - tls_key = "compiled_autograd_state" - - def __init__(self): - self._local = threading.local() - - def _get_tls(self) -> CompiledAutogradTLS: - if hasattr(self._local, self.tls_key): - # first look in python - state = getattr(self._local, self.tls_key) - if torch._C._is_key_in_tls(self.tls_key): - # then look in cpp - state = torch._C._get_obj_in_tls(self.tls_key) - else: - # init new thread created outside of autograd - # TODO: what if context manager wrapped outside of thread? - setattr(self._local, self.tls_key, CompiledAutogradTLS()) - state = getattr(self._local, self.tls_key) - torch._C._stash_obj_in_tls(self.tls_key, state) - return state - - # queries on the object stored in TLS - def get(self, name): - return getattr(self._get_tls(), name) - - def set_tls(self, **kwargs) -> Callable[[], None]: - priors: Dict[str, Any] = {} - for k, v in kwargs.items(): - state = self._get_tls() - priors[k] = getattr(state, k) - setattr(state, k, v) - - torch._C._dynamo.compiled_autograd.notify_autograd_engine() - - def revert(): - self.set_tls(**priors) - - return revert - - def enter_ctx(self) -> Callable[[], None]: - state = self._get_tls() - state.next_ctx_id += 1 - id = state.next_ctx_id - - def exit(): - assert ( - state is self._get_tls() - ), "Runtime must begin and end on the same thread" - assert state.next_ctx_id == id, ( - "Error nesting compiled autograd context managers: " - "inner context managers must have shorter lifetime than the outer context manager" - ) - state.next_ctx_id -= 1 - - return exit - - def enter_compiled_region(self) -> Callable[[], None]: - state = self._get_tls() - prior = state.in_compiled_autograd_region - state.in_compiled_autograd_region = True - assert prior is False, "Nested compiled autograd regions are not supported" - - def exit(): - assert ( - state is self._get_tls() - ), "Runtime must begin and end on the same thread" - assert state.in_compiled_autograd_region is True - state.in_compiled_autograd_region = prior - - return exit - - -local = TLSWrapper() - - -def enabled() -> bool: - return local.get("compiler") is not None - - -def in_compiled_autograd_region() -> bool: - return local.get("in_compiled_autograd_region") +def snapshot_cudagraph_enabled(): + return torch._inductor.config.triton.cudagraphs def maybe_clone(x): @@ -391,7 +307,7 @@ class AutogradCompilerInstance: self.rename_aot_dispatcher_nodes() self.reorder_accumulate_grad_nodes() runtime_inputs_to_move: List[int] = [] - if torch._inductor.config.triton.cudagraphs: + if snapshot_cudagraph_enabled(): runtime_inputs_to_move = self.move_graph_nodes_to_cuda(self.fx_tracer.graph) graph = GraphModule( @@ -413,15 +329,16 @@ class AutogradCompilerInstance: ) def runtime_wrapper(compiled_fn, inputs, sizes, scalars, hooks): + global in_compiled_autograd_region try: - exit_compiled_region = local.enter_compiled_region() + in_compiled_autograd_region = True for i in runtime_inputs_to_move: inputs[i] = inputs[i].pin_memory().cuda(non_blocking=True) with disable(): return compiled_fn(inputs, sizes, scalars, hooks) finally: - exit_compiled_region() + in_compiled_autograd_region = False return runtime_wrapper, self.compiler_fn(graph) @@ -593,9 +510,15 @@ class AutogradCompilerInstance: set_stack_trace(new_stack_trace) +# state of the autograd engine dispatch, kept in sync by enable/disable context managers +compiled_autograd_enabled = False + # global flag to check if compiled autograd is enabled but Dynamo stance is "force_eager" compiled_autograd_enabled_force_eager = False +# global flag to check if we are processing graphs produced from a compiled autograd graph +in_compiled_autograd_region = False + @contextlib.contextmanager def enable(compiler_fn): @@ -615,42 +538,39 @@ def enable(compiler_fn): # we need to lazily import it, because of circular dependencies import torch._inductor.cudagraph_trees - exit_ctx = local.enter_ctx() - revert_tls = local.set_tls( - compiler=functools.partial(AutogradCompilerInstance, compiler_fn), - vlogger=verbose_log - if torch._logging._internal.log_state.is_artifact_enabled( - "compiled_autograd_verbose" - ) - else None, + prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler( + functools.partial(AutogradCompilerInstance, compiler_fn) ) + if snapshot_verbose_logging_enabled(): + torch._C._dynamo.compiled_autograd.set_verbose_logger(verbose_log) + global compiled_autograd_enabled + compiled_autograd_enabled = True try: with torch.autograd.set_multithreading_enabled(False): yield finally: - revert_tls() - exit_ctx() + if not prior: + compiled_autograd_enabled = False + torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) @contextlib.contextmanager def disable(): - exit_ctx = local.enter_ctx() - revert_tls = local.set_tls( - compiler=None, - vlogger=None, - ) + prior = torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + global compiled_autograd_enabled + compiled_autograd_enabled = False try: yield finally: - revert_tls() - exit_ctx() + if prior: + compiled_autograd_enabled = True + torch._C._dynamo.compiled_autograd.set_autograd_compiler(prior) # return to starting state of a new process def reset() -> None: - assert local.get("next_ctx_id") == 0 - assert local.get("in_compiled_autograd_region") is False - local.set_tls( - compiler=None, - vlogger=None, - ) + global compiled_autograd_enabled + compiled_autograd_enabled = False + assert not in_compiled_autograd_region + torch._C._dynamo.compiled_autograd.set_autograd_compiler(None) + torch._C._dynamo.compiled_autograd.set_verbose_logger(None) diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index e974e4ccb85..cbe6411cc4a 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -468,11 +468,6 @@ compiled_autograd = False # Overrides torch.compile() kwargs for Compiled Autograd: compiled_autograd_kwargs_override: Dict[str, Any] = {} -# Compiled Autograd will attempt to automatically wrap C++ autograd functions found in the autograd graph, -# and make them opaque to the compiler. This does not work when the C++ backward implementation involves -# other dispatcher subsystems e.g. custom subclasses, autocast, vmap. -compiled_autograd_opaque_cpp_node = False - # Enables use of collectives *during* compilation to synchronize behavior # across ranks. Today, this is used solely to modify automatic_dynamic_shapes # behavior, making it so that we infer that if an input is dynamic by diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py index b991eb6456c..f3c35d263d5 100644 --- a/torch/_dynamo/utils.py +++ b/torch/_dynamo/utils.py @@ -3175,7 +3175,7 @@ def flatten_graph_inputs(gm: torch.fx.GraphModule, inputs, compile_gm): if node.op == "placeholder" and node.meta.get("steal_arg", False) ] - if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: # fast path, avoid pytree overhead # compiled autograd inputs are always a list of tensors, maybe followed by symints assert inputs_idx_to_clear == [0] diff --git a/torch/_dynamo/variables/distributed.py b/torch/_dynamo/variables/distributed.py index 6afffc15ad1..c14b8794cba 100644 --- a/torch/_dynamo/variables/distributed.py +++ b/torch/_dynamo/variables/distributed.py @@ -313,7 +313,7 @@ class BackwardHookVariable(VariableTracker): user_hooks: VariableTracker, user_pre_hooks: VariableTracker, ): - if not compiled_autograd.enabled(): + if not compiled_autograd.compiled_autograd_enabled: unimplemented("module-level backwards hooks require compiled autograd") def _in_graph_bw_hooks(bw_state: BackwardState): diff --git a/torch/_dynamo/variables/misc.py b/torch/_dynamo/variables/misc.py index 32a09df5961..7f4ad96601a 100644 --- a/torch/_dynamo/variables/misc.py +++ b/torch/_dynamo/variables/misc.py @@ -929,7 +929,7 @@ class AutogradEngineVariable(UserDefinedObjectVariable): kwargs: "Dict[str, VariableTracker]", ) -> "VariableTracker": if name == "queue_callback": - if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: assert ( tx.one_graph ), "queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True" diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py index 8c55e7d5a6c..4bbefc1c1df 100644 --- a/torch/_dynamo/variables/tensor.py +++ b/torch/_dynamo/variables/tensor.py @@ -1015,7 +1015,7 @@ class TensorVariable(VariableTracker): tx = InstructionTranslator.current_tx() if not self.source: - if not compiled_autograd.enabled(): + if not compiled_autograd.compiled_autograd_enabled: # TODO(voz): # We can relax this by speculating the callable and ensuring that it doesn't modify arbitrary # python state. diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py index fe309bb96b2..fc908dfb5f6 100644 --- a/torch/_functorch/_aot_autograd/autograd_cache.py +++ b/torch/_functorch/_aot_autograd/autograd_cache.py @@ -177,7 +177,7 @@ def check_cacheable(gm: torch.fx.GraphModule): Checks that the graph module only uses supported operators """ nodes = gm.graph.nodes - if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: raise BypassAOTAutogradCache( "Cannot cache a graph with compiled autograd enabled" ) diff --git a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py index 01d6bca5a10..dfef0b3a19a 100644 --- a/torch/_functorch/_aot_autograd/collect_metadata_analysis.py +++ b/torch/_functorch/_aot_autograd/collect_metadata_analysis.py @@ -707,7 +707,7 @@ from a multi-output view call" traced_tangent_memory_formats = [t[1] for t in tangents_and_memory_formats] nonlocal static_input_indices static_input_indices = static_input_indices or [] - if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: passed_indices = set(static_input_indices) static_input_indices = [ i diff --git a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py index fda2a22d202..4e6263e153b 100644 --- a/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py +++ b/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py @@ -758,7 +758,7 @@ def aot_dispatch_autograd( # becomes the lazy version again. One example is when dynamic shape is enabled # upfront, the bw_compiler will be called above which can cause extra # graph module recompilation on bw_module. - if torch._dynamo.compiled_autograd.in_compiled_autograd_region(): + if torch._dynamo.compiled_autograd.in_compiled_autograd_region: from torch.fx._lazy_graph_module import _LazyGraphModule _LazyGraphModule.force_recompile(bw_module) diff --git a/torch/csrc/dynamo/compiled_autograd.h b/torch/csrc/dynamo/compiled_autograd.h index 38530f8e8ae..40c2db639ec 100644 --- a/torch/csrc/dynamo/compiled_autograd.h +++ b/torch/csrc/dynamo/compiled_autograd.h @@ -224,29 +224,11 @@ struct LiftedIValueArgs { const std::optional& active_node_call_idx; }; -// Hold GIL while using -struct PyTLSWrapper { - PyTLSWrapper(PyObject* state) : state(state) {} - PyTLSWrapper(const PyTLSWrapper&) = delete; - PyTLSWrapper& operator=(const PyTLSWrapper&) = delete; - PyTLSWrapper(PyTLSWrapper&&) = default; - PyTLSWrapper& operator=(PyTLSWrapper&&) = default; - - static PyTLSWrapper create(); - - PyObject* get(std::string_view key) const; - - private: - PyObject* state; -}; - struct AutogradCompilerCall { - AutogradCompilerCall() = delete; - AutogradCompilerCall(PyTLSWrapper&& state) + AutogradCompilerCall() : active_node_call_idx(std::nullopt), tensor_args(active_node_call_idx), - lifted_ivalue_args(active_node_call_idx), - state(std::move(state)) {} + lifted_ivalue_args(active_node_call_idx) {} void add_size_input(const c10::SymInt& s) { all_size_inputs.emplace_back( default_dyn_type, s.guard_int(__FILE__, __LINE__)); @@ -272,11 +254,8 @@ struct AutogradCompilerCall { std::vector hooks; NodeCalls node_calls; SizeInput::DynType default_dyn_type = SizeInput::STATIC; - // NodeCall id of each size, only when verbose logging is enabled std::vector size_input_origins; - - const PyTLSWrapper state; }; class CompiledNodeArgs { diff --git a/torch/csrc/dynamo/python_compiled_autograd.cpp b/torch/csrc/dynamo/python_compiled_autograd.cpp index eb11bd1b65a..33dac77d743 100644 --- a/torch/csrc/dynamo/python_compiled_autograd.cpp +++ b/torch/csrc/dynamo/python_compiled_autograd.cpp @@ -88,6 +88,10 @@ static void check(bool result) { if (C10_UNLIKELY(!result)) check(nullptr); } + +// snapshot of python verbose logging toggle +static PyObject* python_verbose_logger = nullptr; + struct PythonLogger { PythonLogger() = delete; explicit PythonLogger(PyObject* logger) : logger_(logger) { @@ -131,15 +135,15 @@ struct PythonLogger { }; struct VerboseLogger : public PythonLogger { - VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {} - - static std::optional maybe_create(PyObject* vlogger) { - if (vlogger == Py_None) { + static std::optional maybe_create() { + if (python_verbose_logger == nullptr) { return std::nullopt; } - return VerboseLogger(vlogger); + return VerboseLogger(python_verbose_logger); } + VerboseLogger(PyObject* vlogger) : PythonLogger(vlogger) {} + void log_node_check( const Node& fn, size_t size_inputs_num, @@ -364,22 +368,8 @@ struct InputBuffers : public std::unordered_map { } }; -/* static */ PyTLSWrapper PyTLSWrapper::create() { - TORCH_INTERNAL_ASSERT( - at::impl::ThreadLocalPythonObjects::contains("compiled_autograd_state")); - PyObject* compiled_autograd_state = - check(at::impl::ThreadLocalPythonObjects::get("compiled_autograd_state") - ->ptr(getPyInterpreter())); - return PyTLSWrapper(compiled_autograd_state); -} - -// Refer to fields in python class CompiledAutogradTLS -// May return Py_None -PyObject* PyTLSWrapper::get(std::string_view key) const { - return check(PyObject_GetAttrString(state, key.data())); -} - -static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args); +static PyObject* the_autograd_compiler = nullptr; +static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args); static PyObject* clear_cache(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; @@ -397,11 +387,28 @@ static PyObject* is_cache_empty(PyObject* dummy, PyObject* args) { END_HANDLE_TH_ERRORS; } +static PyObject* set_verbose_logger(PyObject* dummy, PyObject* args) { + HANDLE_TH_ERRORS; + PyObject* logger = nullptr; + if (!PyArg_ParseTuple(args, "O", &logger)) { + throw_python_error(); + } + + if (logger == Py_None) { + python_verbose_logger = nullptr; + } else { + python_verbose_logger = logger; + } + Py_RETURN_TRUE; + END_HANDLE_TH_ERRORS; +} + // NOLINTNEXTLINE(*array*) static PyMethodDef _methods[] = { - {"notify_autograd_engine", notify_autograd_engine, METH_NOARGS, nullptr}, + {"set_autograd_compiler", set_autograd_compiler, METH_VARARGS, nullptr}, {"clear_cache", clear_cache, METH_NOARGS, nullptr}, {"is_cache_empty", is_cache_empty, METH_NOARGS, nullptr}, + {"set_verbose_logger", set_verbose_logger, METH_VARARGS, nullptr}, {nullptr, nullptr, 0, nullptr}}; static struct PyModuleDef _module = { @@ -561,7 +568,7 @@ CacheNode* _compiled_autograd_impl( THPObjectPtr* graph_arg_hooks) { std::unordered_map& dependencies = graph_task.dependencies_; std::vector> worklist{graph_root}; - AutogradCompilerCall compiler_call(PyTLSWrapper::create()); + AutogradCompilerCall compiler_call; for (const auto i : c10::irange(output_edges.size())) { compiler_call.node_calls @@ -576,8 +583,7 @@ CacheNode* _compiled_autograd_impl( check_exec_info ? graph_task.exec_info_.size() : dependencies.size() + 1); int i = 0; - std::optional vlogger = - VerboseLogger::maybe_create(compiler_call.state.get("vlogger")); + std::optional vlogger = VerboseLogger::maybe_create(); while (!worklist.empty()) { std::shared_ptr fn = std::move(worklist.back()); worklist.pop_back(); @@ -636,8 +642,6 @@ CacheNode* _compiled_autograd_impl( // TODO(jansel): some dynamic sizes seem to be ints not symints if (!cache->check_dynamic_sizes(compiler_call, vlogger)) { // cache miss, need to capture FX graph - PyObject* the_autograd_compiler = compiler_call.state.get("compiler"); - TORCH_INTERNAL_ASSERT(the_autograd_compiler != Py_None); ClosingTHPObjectPtr py_compiler( check(PyObject_CallNoArgs((the_autograd_compiler)))); @@ -835,16 +839,28 @@ variable_list compiled_autograd( return outputs; } -static PyObject* notify_autograd_engine(PyObject* dummy, PyObject* args) { +static PyObject* set_autograd_compiler(PyObject* dummy, PyObject* args) { HANDLE_TH_ERRORS; - PyTLSWrapper state = PyTLSWrapper::create(); - PyObject* compiler = state.get("compiler"); - if (compiler == Py_None) { // disable + PyObject* obj = nullptr; + if (!PyArg_ParseTuple(args, "O", &obj)) { + return nullptr; + } + + PyObject* prior = the_autograd_compiler; + if (obj == Py_None) { // disable + the_autograd_compiler = nullptr; // decref not needed due to `prior` Engine::set_compiled_autograd(nullptr); } else { // enable + Py_INCREF(obj); + the_autograd_compiler = obj; Engine::set_compiled_autograd(&compiled_autograd); } - Py_RETURN_NONE; + + if (prior == nullptr) { + Py_RETURN_NONE; + } else { + return prior; + } END_HANDLE_TH_ERRORS; } diff --git a/torch/distributed/_composable/fsdp/_fsdp_common.py b/torch/distributed/_composable/fsdp/_fsdp_common.py index d967a55d254..74c6f4fdfea 100644 --- a/torch/distributed/_composable/fsdp/_fsdp_common.py +++ b/torch/distributed/_composable/fsdp/_fsdp_common.py @@ -33,9 +33,9 @@ else: import torch._dynamo.compiled_autograd as ca _compiled_autograd_enabled = ( - ca.enabled() + ca.compiled_autograd_enabled or ca.compiled_autograd_enabled_force_eager - or ca.in_compiled_autograd_region() + or ca.in_compiled_autograd_region ) def compiled_autograd_enabled():