[inductor] Fix TORCHINDUCTOR_FORCE_DISABLE_CACHES (#129257)

Summary: See https://github.com/pytorch/pytorch/issues/129159; this option wasn't doing its job for a few reasons. In this PR:
* Fix the with_fresh_cache_if_config() decorator: functools.wraps was applied bare and the wrapped function was returned without being called, so the config check ran once at import time rather than on every call
* Reset the TORCHINDUCTOR_CACHE_DIR and TRITON_CACHE_DIR env vars in the compile sub-processes, so that changes made in the parent process after the process pool starts still take effect
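As a quick illustration of the intended behavior of the option fixed here (a sketch for context, not code from this PR; the function f is arbitrary):

    import os

    # Read by torch._inductor.config as force_disable_caches.
    os.environ["TORCHINDUCTOR_FORCE_DISABLE_CACHES"] = "1"

    import torch

    @torch.compile
    def f(x):
        return x.sin() + x.cos()

    # With caches force-disabled, this compile should land in a fresh temp
    # dir under the default cache root rather than reusing earlier output.
    f(torch.randn(8))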

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129257
Approved by: https://github.com/oulgen
Author: Sam Larsen
Date: 2024-06-24 12:52:07 -07:00
Committed by: PyTorch MergeBot
Parent: 61bf1452a3
Commit: 87d14ad419
5 changed files with 27 additions and 15 deletions


@@ -173,11 +173,16 @@ class AsyncCompile:
         kernel = TritonCodeCache.load(kernel_name, source_code)
         if config.compile_threads > 1:
+            # We want to support changing these env vars after (and while) the
+            # process pool is running, so pass them to the subprocess to reset.
+            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
+            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
             return TritonFuture(
                 kernel,
                 self.process_pool().submit(
                     _worker_compile_triton,
                     kernel._reload_in_subproc,
+                    extra_env,
                 ),
             )
         else:
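The extra_env plumbing is needed because pool workers capture the parent's environment when the pool starts, not when submit() is called. A standalone stdlib sketch of the failure mode and the fix (not PyTorch code; assumes TRITON_CACHE_DIR is unset beforehand):

    import os
    from concurrent.futures import ProcessPoolExecutor

    def worker(extra_env):
        os.environ.update(extra_env)  # same reset _worker_compile_triton now does
        return os.environ.get("TRITON_CACHE_DIR")

    if __name__ == "__main__":
        with ProcessPoolExecutor(max_workers=1) as pool:
            pool.submit(worker, {}).result()  # force the lone worker to start now
            os.environ["TRITON_CACHE_DIR"] = "/tmp/fresh"  # parent changes it later
            # The already-running worker still has the stale environment:
            print(pool.submit(worker, {}).result())  # None
            # Passing the current value explicitly, as this hunk does, fixes it:
            print(pool.submit(worker, {"TRITON_CACHE_DIR": "/tmp/fresh"}).result())  # /tmp/fresh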


@@ -43,8 +43,8 @@ from torch._inductor.cudagraph_utils import (
     get_placeholders,
     log_cudagraph_skip_and_bump_counter,
 )
 from torch._inductor.debug import save_args_for_compile_fx_inner
+from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
     BoxedBool,
     count_tangents,
@@ -420,13 +420,19 @@ def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
     return config.get_config_copy()


-@functools.wraps
-def with_fresh_cache_if_config(f):
-    if config.force_disable_caches:
-        with fresh_inductor_cache():
-            return f
-    else:
-        return f
+def with_fresh_cache_if_config(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if config.force_disable_caches:
+            # Don't delete the cache dir because it has to survive beyond the
+            # compile_fx call. Let's put the temp dirs under the default cache
+            # dir so they're easier to locate.
+            with fresh_inductor_cache(dir=cache_dir(), delete=False):
+                return fn(*args, **kwargs)
+        else:
+            return fn(*args, **kwargs)
+
+    return wrapper


 @DebugContext.wrap
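Two bugs are fixed in this hunk: functools.wraps was applied bare (it must be called with the function it wraps), and the old body returned f uncalled, so config.force_disable_caches was consulted only at import time. The rewrite defers the check to call time. A minimal standalone sketch of the pattern (illustrative names; a module-level flag stands in for the config):

    import functools

    force_disable_caches = False  # stands in for config.force_disable_caches

    def with_fresh_cache_if_config(fn):
        @functools.wraps(fn)  # note: wraps(fn), not a bare @functools.wraps
        def wrapper(*args, **kwargs):
            if force_disable_caches:  # checked on every call, not at import
                print("using a fresh cache dir")
            return fn(*args, **kwargs)
        return wrapper

    @with_fresh_cache_if_config
    def compile_step():
        return "compiled"

    compile_step()               # no message
    force_disable_caches = True  # flipping the flag later now takes effect
    compile_step()               # prints "using a fresh cache dir"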


@@ -7,7 +7,7 @@ import sys
 import warnings
 from pathlib import Path
 from types import ModuleType
-from typing import Any, Callable
+from typing import Any, Callable, Dict


 def _reload_triton_kernel_in_subproc(reload_module, kernel_name):
@@ -61,8 +61,7 @@ def _set_triton_ptxas_path() -> None:
         warnings.warn(f"{ptxas} exists but is not an executable")


-def _worker_compile_triton(
-    load_kernel: Callable[[], Any],
-):
+def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
     _set_triton_ptxas_path()
+    os.environ.update(extra_env)
     load_kernel().precompile(warm_cache_only=True)


@@ -210,6 +210,7 @@ class CachingAutotuner(KernelInterface):
                 "triton",
                 str(self.triton_meta.get("device", 0)),
             )
+        log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])

         self.size_hints = size_hints
         self.coordesc_tuner = CoordescTuner(
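The new log line is only visible with DEBUG-level inductor logging enabled; one way to turn that on (using the standard torch._logging knobs, or equivalently TORCH_LOGS="+inductor" in the environment):

    import logging

    import torch._logging

    # Raise the torch._inductor loggers (including triton_heuristics) to
    # DEBUG so the "Triton cache dir: ..." line is emitted.
    torch._logging.set_logs(inductor=logging.DEBUG)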


@@ -739,7 +739,7 @@ def clear_inductor_caches():
 @contextlib.contextmanager
-def fresh_inductor_cache(cache_entries=None):
+def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
     """
     Contextmanager that provides a clean tmp cachedir for inductor.

@@ -748,7 +748,7 @@ def fresh_inductor_cache(cache_entries=None):
     """
     clear_inductor_caches()
-    inductor_cache_dir = tempfile.mkdtemp()
+    inductor_cache_dir = tempfile.mkdtemp(dir=dir)
     try:
         with mock.patch.dict(
             os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}

@@ -768,7 +768,8 @@ def fresh_inductor_cache(cache_entries=None):
                 if ".lock" not in f
             }
         )
-        shutil.rmtree(inductor_cache_dir)
+        if delete:
+            shutil.rmtree(inductor_cache_dir)
     except Exception:
         log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
         raise
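Taken together with the compile_fx change above, the new parameters are used like this (a usage sketch; the import paths are the ones appearing in this commit's hunks):

    from torch._inductor.runtime.runtime_utils import cache_dir
    from torch._inductor.utils import fresh_inductor_cache

    # Put the throwaway cache under the default cache root and keep it after
    # the context exits, so artifacts written there remain loadable later.
    with fresh_inductor_cache(dir=cache_dir(), delete=False):
        ...  # compile here; TORCHINDUCTOR_CACHE_DIR points at the temp dir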