From d0aa1386b8c61b468302bd2d6100a689a3eec7dc Mon Sep 17 00:00:00 2001
From: James Wu
Date: Wed, 29 Jan 2025 08:38:52 -0800
Subject: [PATCH] Disable AOTAutogradCache for triton version < 3.2 (#145937)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145937
Approved by: https://github.com/bdhirsh
---
 torch/_functorch/_aot_autograd/autograd_cache.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/torch/_functorch/_aot_autograd/autograd_cache.py b/torch/_functorch/_aot_autograd/autograd_cache.py
index 7cdf7a14764..0f56826d25a 100644
--- a/torch/_functorch/_aot_autograd/autograd_cache.py
+++ b/torch/_functorch/_aot_autograd/autograd_cache.py
@@ -37,6 +37,7 @@ from torch._inductor.utils import should_use_remote_fx_graph_cache
 from torch._logging import LazyString
 from torch._utils_internal import log_cache_bypass
 from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
+from torch.utils._triton import has_triton_package
 from torchgen.utils import dataclass_repr
 
 from .runtime_wrappers import (
@@ -318,6 +319,19 @@ def autograd_cache_key(
     Generate a unique hash of the FX graph for caching.
     """
     check_cacheable(gm)
+    if has_triton_package():
+        # Due to https://github.com/triton-lang/triton/issues/3729,
+        # if triton is < 3.2.0, AOTAutogradCache may cause us to
+        # attempt to load a cache entry without initializing
+        # the CUDA context on the autograd thread.
+
+        # Without caching, we naturally do this initialization when
+        # tracing through the graph with the autograd engine.
+        import triton
+
+        if triton.__version__ < "3.2.0":
+            raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0")
+
     details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
     pickler = AOTAutogradCachePickler(gm)
     # The prefix distinguishes among the other kinds of objects we cache