diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py
index 1f98c236306..b23a34b50d7 100644
--- a/torch/_inductor/async_compile.py
+++ b/torch/_inductor/async_compile.py
@@ -173,11 +173,16 @@ class AsyncCompile:
 
         kernel = TritonCodeCache.load(kernel_name, source_code)
         if config.compile_threads > 1:
+            # We want to support changing these env vars after (and while) the
+            # process pool is running, so pass them to the subprocess to reset.
+            env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
+            extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
             return TritonFuture(
                 kernel,
                 self.process_pool().submit(
                     _worker_compile_triton,
                     kernel._reload_in_subproc,
+                    extra_env,
                 ),
             )
         else:
diff --git a/torch/_inductor/compile_fx.py b/torch/_inductor/compile_fx.py
index f9d870f7b2f..f93c0772ac0 100644
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -43,8 +43,8 @@ from torch._inductor.cudagraph_utils import (
     get_placeholders,
     log_cudagraph_skip_and_bump_counter,
 )
-
 from torch._inductor.debug import save_args_for_compile_fx_inner
+from torch._inductor.runtime.runtime_utils import cache_dir
 from torch._inductor.utils import (
     BoxedBool,
     count_tangents,
@@ -420,13 +420,19 @@ def get_patched_config_dict(config_patches=None) -> Dict[str, Any]:
     return config.get_config_copy()
 
 
-@functools.wraps
-def with_fresh_cache_if_config(f):
-    if config.force_disable_caches:
-        with fresh_inductor_cache():
-            return f
-    else:
-        return f
+def with_fresh_cache_if_config(fn):
+    @functools.wraps(fn)
+    def wrapper(*args, **kwargs):
+        if config.force_disable_caches:
+            # Don't delete the cache dir because it has to survive beyond the
+            # compile_fx call. Let's put the temp dirs under the default cache
+            # dir so they're easier to locate.
+            with fresh_inductor_cache(dir=cache_dir(), delete=False):
+                return fn(*args, **kwargs)
+        else:
+            return fn(*args, **kwargs)
+
+    return wrapper
 
 
 @DebugContext.wrap
diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py
index fc70640c07d..3a962687571 100644
--- a/torch/_inductor/runtime/compile_tasks.py
+++ b/torch/_inductor/runtime/compile_tasks.py
@@ -7,7 +7,7 @@ import sys
 import warnings
 from pathlib import Path
 from types import ModuleType
-from typing import Any, Callable
+from typing import Any, Callable, Dict
 
 
 def _reload_triton_kernel_in_subproc(reload_module, kernel_name):
@@ -61,8 +61,7 @@ def _set_triton_ptxas_path() -> None:
         warnings.warn(f"{ptxas} exists but is not an executable")
 
 
-def _worker_compile_triton(
-    load_kernel: Callable[[], Any],
-):
+def _worker_compile_triton(load_kernel: Callable[[], Any], extra_env: Dict[str, str]):
     _set_triton_ptxas_path()
+    os.environ.update(extra_env)
     load_kernel().precompile(warm_cache_only=True)
diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py
index 036e87b9630..1c96aa75b89 100644
--- a/torch/_inductor/runtime/triton_heuristics.py
+++ b/torch/_inductor/runtime/triton_heuristics.py
@@ -210,6 +210,7 @@ class CachingAutotuner(KernelInterface):
                 "triton",
                 str(self.triton_meta.get("device", 0)),
             )
+        log.debug("Triton cache dir: %s", os.environ["TRITON_CACHE_DIR"])
 
         self.size_hints = size_hints
         self.coordesc_tuner = CoordescTuner(
diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py
index 28a1a1451df..55d425f8cb5 100644
--- a/torch/_inductor/utils.py
+++ b/torch/_inductor/utils.py
@@ -739,7 +739,7 @@ def clear_inductor_caches():
 
 
 @contextlib.contextmanager
-def fresh_inductor_cache(cache_entries=None):
+def fresh_inductor_cache(cache_entries=None, dir=None, delete=True):
     """
     Contextmanager that provides a clean tmp cachedir for inductor.
 
@@ -748,7 +748,7 @@
     """
     clear_inductor_caches()
 
-    inductor_cache_dir = tempfile.mkdtemp()
+    inductor_cache_dir = tempfile.mkdtemp(dir=dir)
     try:
         with mock.patch.dict(
             os.environ, {"TORCHINDUCTOR_CACHE_DIR": inductor_cache_dir}
@@ -768,7 +768,8 @@
                 if ".lock" not in f
             }
         )
-        shutil.rmtree(inductor_cache_dir)
+        if delete:
+            shutil.rmtree(inductor_cache_dir)
     except Exception:
         log.warning("on error, temporary cache dir kept at %s", inductor_cache_dir)
        raise
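Usage note (illustrative only, not part of the patch): the `dir` and `delete` parameters added to `fresh_inductor_cache` above let a compile run inside a throwaway cache directory nested under the default inductor cache location and keep that directory afterwards, which is what `with_fresh_cache_if_config` now does when `config.force_disable_caches` is set. A minimal sketch, assuming a working inductor setup at this revision:

    import os

    import torch
    from torch._inductor.runtime.runtime_utils import cache_dir
    from torch._inductor.utils import fresh_inductor_cache


    @torch.compile
    def f(x):
        return x.sin() + x.cos()


    # Compile inside a fresh cache dir created under the default inductor
    # cache dir; delete=False keeps the compiled artifacts on disk after
    # the context exits so they can be inspected.
    with fresh_inductor_cache(dir=cache_dir(), delete=False):
        f(torch.randn(8))
        print("fresh cache dir:", os.environ["TORCHINDUCTOR_CACHE_DIR"])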