Disable AOTAutogradCache for triton version < 3.2 (#145937)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145937
Approved by: https://github.com/bdhirsh
Author: James Wu
Date: 2025-01-29 08:38:52 -08:00
Committed by: PyTorch MergeBot
parent 1185b81c51
commit d0aa1386b8


@@ -37,6 +37,7 @@ from torch._inductor.utils import should_use_remote_fx_graph_cache
 from torch._logging import LazyString
 from torch._utils_internal import log_cache_bypass
 from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
+from torch.utils._triton import has_triton_package
 from torchgen.utils import dataclass_repr

 from .runtime_wrappers import (
@@ -318,6 +319,19 @@ def autograd_cache_key(
     Generate a unique hash of the FX graph for caching.
     """
     check_cacheable(gm)
+    if has_triton_package():
+        # Due to https://github.com/triton-lang/triton/issues/3729,
+        # if triton is < 3.2.0, AOTAutogradCache may cause us to
+        # attempt to load a cache entry without initializing
+        # the CUDA context on the autograd thread.
+        # Without caching, we naturally do this initialization when
+        # tracing through the graph with the autograd engine.
+        import triton
+
+        if triton.__version__ < "3.2.0":
+            raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0")
     details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
     pickler = AOTAutogradCachePickler(gm)
     # The prefix distinguishes among the other kinds of objects we cache
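For context, the gate above compares triton.__version__ lexicographically as a string, which orders the releases in question correctly (e.g. "3.1.0" < "3.2.0") but would misorder a hypothetical "3.10.0". Below is a minimal sketch of the same gate using packaging.version for a numeric comparison; the helper name _triton_supports_aot_autograd_cache and the standalone framing are assumptions for illustration, not part of this PR.

    # Sketch only: an equivalent version gate using packaging.version so that,
    # e.g., "3.10.0" orders correctly after "3.2.0". The helper name is
    # hypothetical; the PR itself inlines the check in autograd_cache_key and
    # compares the version strings directly.
    from packaging import version

    from torch.utils._triton import has_triton_package


    def _triton_supports_aot_autograd_cache() -> bool:
        if not has_triton_package():
            # Without triton installed, the linked issue cannot occur,
            # so there is no reason to bypass the cache.
            return True
        import triton

        # https://github.com/triton-lang/triton/issues/3729 is fixed in 3.2.0.
        return version.parse(triton.__version__) >= version.parse("3.2.0")

When such a check fails, raising BypassAOTAutogradCache (as the diff does) is what actually disables the cache: per the PR title and the imported log_cache_bypass helper, the exception is handled as a cache bypass rather than surfaced as an error, so compilation simply proceeds without AOTAutogradCache on older triton versions.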