mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Disable AOTAutogradCache for triton version < 3.2 (#145937)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/145937 Approved by: https://github.com/bdhirsh
This commit is contained in:
parent
1185b81c51
commit
d0aa1386b8
1 changed file with 14 additions and 0 deletions
|
|
@@ -37,6 +37,7 @@ from torch._inductor.utils import should_use_remote_fx_graph_cache
|
|||
from torch._logging import LazyString
|
||||
from torch._utils_internal import log_cache_bypass
|
||||
from torch.compiler._cache import CacheArtifactManager, CacheArtifactType
|
||||
from torch.utils._triton import has_triton_package
|
||||
from torchgen.utils import dataclass_repr
|
||||
|
||||
from .runtime_wrappers import (
|
||||
|
|
@@ -318,6 +319,19 @@ def autograd_cache_key(
|
|||
Generate a unique hash of the FX graph for caching.
|
||||
"""
|
||||
check_cacheable(gm)
|
||||
if has_triton_package():
|
||||
# Due to https://github.com/triton-lang/triton/issues/3729,
|
||||
# if triton is < 3.2.0, AOTAutogradCache may cause us to
|
||||
# attempt to load a cache entry without initializing
|
||||
# the CUDA context on the autograd thread.
|
||||
|
||||
# Without caching, we naturally do this initialization when
|
||||
# tracing through the graph with the autograd engine.
|
||||
import triton
|
||||
|
||||
if triton.__version__ < "3.2.0":
|
||||
raise BypassAOTAutogradCache("AOTAutogradCache requires triton 3.2.0")
|
||||
|
||||
details = AOTAutogradCacheDetails(gm, example_inputs, config, fx_config)
|
||||
pickler = AOTAutogradCachePickler(gm)
|
||||
# The prefix distinguishes among the other kinds of objects we cache
|
||||
|
|
|
|||
Loading…
Reference in a new issue