mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Fix TORCHINDUCTOR_FORCE_DISABLE_CACHES (#129257)
Summary: See https://github.com/pytorch/pytorch/issues/129159; this option wasn't doing its job for a few reasons. In this PR: * Fix the with_fresh_cache_if_config() decorator * Reset the "TORCHINDUCTOR_CACHE_DIR" & "TRITON_CACHE_DIR" env vars in sub-process to support them changing in the parent process Pull Request resolved: https://github.com/pytorch/pytorch/pull/129257 Approved by: https://github.com/oulgen
This commit is contained in:
parent
61bf1452a3
commit
87d14ad419
5 changed files with 27 additions and 15 deletions
|
|
@@ -173,11 +173,16 @@ class AsyncCompile:
|
|||
|
||||
kernel = TritonCodeCache.load(kernel_name, source_code)
|
||||
if config.compile_threads > 1:
|
||||
# We want to support changing these env vars after (and while) the
|
||||
# process pool is running, so pass them to the subprocess to reset.
|
||||
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
|
||||
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
|
||||
return TritonFuture(
|
||||
kernel,
|
||||
self.process_pool().submit(
|
||||
_worker_compile_triton,
|
||||
kernel._reload_in_subproc,
|
||||
extra_env,
|
||||
),
|
||||
)
|
||||
else:
|
||||
|
|
|
|||
|
|
@@ -43,8 +43,8 @@ from torch._inductor.cudagraph_utils import (
|
|||
get_placeholders,
|
||||
log_cudagraph_skip_and_bump_counter,
|
||||
)
|
||||
|
||||
from torch._inductor.debug import save_args_for_compile_fx_inner
|
||||
from torch._inductor.runtime.runtime_utils import cache_dir
|
||||
from torch._inductor.utils import (
|
||||
BoxedBool,
|
||||
count_tangents,
|
||||
|
|
@@ -420,13 +420,19 @@ def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
|
|||
return config.get_config_copy()
|
||||
|
||||
|
||||
@functools.wraps
|
||||
def with_fresh_cache_if_config(f):
|
||||
if config.force_disable_caches:
|
||||
with fresh_inductor_cache():
|
||||
return f
|
||||
else:
|
||||
return f
|
||||
def with_fresh_cache_if_config(fn):
|
||||
@functools.wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if config.force_disable_caches:
|
||||
# Don't delete the cache dir because it has to survive beyond the
|
||||
# compile_fx call. Let's put the temp dirs under the default cache
|
||||
# dir so they're easier to locate.
|
||||
with fresh_inductor_cache(dir=cache_dir(), delete=False):
|
||||
return fn(*args, **kwargs)
|
||||
else:
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@DebugContext.wrap
|
||||
|
|
|
|||
|
|
@@ -7,7 +7,7 @@ import sys
|
|||
import warnings
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any, Callable
|
||||
from typing import Any, Callable, Dict
|
||||
|
||||
|
||||
def _reload_triton_kernel_in_subproc(reload_module, kernel_name):
|
||||
|
|
@@ -61,8 +61,7 @@ def _set_triton_ptxas_path() -> None:
|
|||
warnings.warn(f"{ptxas} exists but is not an executable")
|
||||
|
||||
|
||||
def _worker_compile_triton(
|
||||
load_kernel: Callable[[], Any],
|
||||
):
|
||||
def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
|
||||
_set_triton_ptxas_path()
|
||||
os.environ.update(extra_env)
|
||||
load_kernel().precompile(warm_cache_only=True)
|
||||
|
|
|
|||
|
|
@@ -210,6 +210,7 @@ class CachingAutotuner(KernelInterface):
|
|||
"triton",
|
||||
str(self.triton_meta.get("device", 0)),
|
||||
)
|
||||
log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
|
||||
|
||||
self.size_hints = size_hints
|
||||
self.coordesc_tuner = CoordescTuner(
|
||||
|
|
|
|||
|
|
@@ -739,7 +739,7 @@ def clear_inductor_caches():
|
|||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def fresh_inductor_cache(cache_entries=None):
|
||||
def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
|
||||
"""
|
||||
Contextmanager that provides a clean tmp cachedir for inductor.
|
||||
|
||||
|
|
@@ -748,7 +748,7 @@ def fresh_inductor_cache(cache_entries=None):
|
|||
"""
|
||||
clear_inductor_caches()
|
||||
|
||||
inductor_cache_dir = tempfile.mkdtemp()
|
||||
inductor_cache_dir = tempfile.mkdtemp(dir=dir)
|
||||
try:
|
||||
with mock.patch.dict(
|
||||
os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
|
||||
|
|
@@ -768,7 +768,8 @@ def fresh_inductor_cache(cache_entries=None):
|
|||
if ".lock" not in f
|
||||
}
|
||||
)
|
||||
shutil.rmtree(inductor_cache_dir)
|
||||
if delete:
|
||||
shutil.rmtree(inductor_cache_dir)
|
||||
except Exception:
|
||||
log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
|
||||
raise
|
||||
|
|
|
|||
Loading…
Reference in a new issue