diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index adb8319ba6d..4309ad9cfa9 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1494,6 +1494,9 @@ class TestAutotuneCache(TestCase): @config.patch({"autotune_remote_cache": True}) @config.patch({"bundled_autotune_remote_cache": False}) @config.patch({"max_autotune": True}) + @config.patch( + {"compile_threads": 1} + ) # Worker processes do not register PatchCaches() properly def test_autotune_cache(self): class Model(torch.nn.Module): def forward(self, x, y, a, b): @@ -1531,6 +1534,7 @@ class TestAutotuneCache(TestCase): @config.patch({"autotune_local_cache": True}) @config.patch({"autotune_remote_cache": False}) @config.patch({"bundled_autotune_remote_cache": True}) + @config.patch({"compile_threads": 1}) @config.patch({"max_autotune": True}) def test_bundled_autotune_remote_cache(self): class Model(torch.nn.Module): diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 29a8d4d3004..399e6b47819 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1031,6 +1031,9 @@ class TestMaxAutotuneRemoteCache(TestCase): PatchCaches.tearDown() @parametrize("dynamic", (False, True)) + @config.patch( + {"compile_threads": 1} + ) # Worker processes do not register PatchCaches() properly def test_max_autotune_remote_caching(self, dynamic: bool): from unittest.mock import patch diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 4815f6f568a..d44bcd9452a 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -11,13 +11,15 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool from functools import partial from time import time -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING import torch from torch._dynamo.device_interface import get_registered_device_interfaces from torch._dynamo.utils import counters, dynamo_timed, set_feature_use from torch._inductor import config from torch._inductor.codecache import ( + _load_triton_kernel_from_source, + code_hash, CodeCacheFuture, CppCodeCache, CppPythonBindingsCodeCache, @@ -25,8 +27,7 @@ from torch._inductor.codecache import ( HalideCodeCache, LambdaFuture, ROCmCodeCache, - TritonCodeCache, - TritonFuture, + torch_key, ) from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool from torch._inductor.compile_worker.watchdog import _async_compile_initializer @@ -42,6 +43,7 @@ from torch.utils._triton import has_triton_package if TYPE_CHECKING: from torch._inductor.runtime.hints import HalideMeta + from torch._inductor.runtime.triton_heuristics import CachingAutotuner # timing metrics for time spent in the compilation _cumulative_compile_time = 0.0 @@ -130,9 +132,49 @@ def get_compile_threads() -> int: @clear_on_fresh_inductor_cache -@functools.lru_cache(None) -def get_future_cache(): - return {} +class CompiledTritonKernels: + """ + In memory cache for storing compiled triton kernels. + + Each triton kernel is keyed by the hash of its source code. Each value stored + in the cache is a return value of AsyncCompile.triton(). + + Currently, the cache stores Future objects, but it should be generalizable for any kernels. 
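+
+    Illustrative usage (a sketch; `src` and `async_compile` stand for the full kernel
+    source string and an AsyncCompile instance, not real names in this module):
+
+        future = CompiledTritonKernels.get(src, None)
+        if future is None:
+            # On a miss, AsyncCompile.triton compiles `src` and stores the resulting
+            # LambdaFuture under CompiledTritonKernels.key(src).
+            future = async_compile.triton("triton_", src)
+        kernel = future.result()  # a CachingAutotuner when the worker pool is in use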
+ """ + + _cache: Dict[str, LambdaFuture] = {} + + @staticmethod + def key(kernel_src: str): + """ + Generates a cache key given a triton kernel's full source code. + This source includes the inductor meta, compilation metadata, the kernel itself, etc. + `kernel_src` should be the exact string passed to async_compile.triton()'s first argument. + """ + # Hashes the kernel source with torch_key into a single hash key + return code_hash(kernel_src, extra=torch_key()) + + @staticmethod + def save(kernel_src: str, future: LambdaFuture): + """ + Saves a compiled triton kernel to the cache. + TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton, + but the real type we want to return here is actually an abstract triton kernel. + + TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc. + so it could be less strict. + """ + key = CompiledTritonKernels.key(kernel_src) + CompiledTritonKernels._cache[key] = future + + @staticmethod + def get(kernel_src: str, default: Any) -> LambdaFuture: + key = CompiledTritonKernels.key(kernel_src) + return CompiledTritonKernels._cache.get(key, default) + + @staticmethod + def cache_clear(): + CompiledTritonKernels._cache = {} class AsyncCompile: @@ -208,51 +250,84 @@ class AsyncCompile: ) def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + """ + Async_compile.triton is more complicated than the other backends because + we're trying to optimize compile time as much as possible for this hot callsite. + + First of all, the function is cached by CompiledTritonKernels; if there's a kernel + already compiled, we grab it directly from the cache and return. + + Otherwise, if we have multiple compile threads, we kick off triton compilations on each + worker process by giving it a kernel and source code to compile. The worker initializes + a CachingAutotuner, runs triton compilation, and pickles the kernel back to us. + We use TritonCompileResult to represent the objects being pickled back to us by each + worker. + + Some maybe not obvious things that are pickled back to us: + - Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata + and do not have to pay the cost of loading the triton kernel on the parent. But certain + cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function + in the parent lazily when we require it. + - The AutotuneCache, if enabled, is constructed on each worker per triton config + and pickled by to us via `CachingAutotuner.save_cache_hook`. + """ + if future := CompiledTritonKernels.get(source_code, None): + counters["inductor"]["async_compile_cache_hit"] += 1 + return future + + counters["inductor"]["async_compile_cache_miss"] += 1 + kernel_code_log.info("Triton Kernel:\n%s", source_code) _compile_start() - _set_triton_ptxas_path() if os.environ.get("TRITON_INTERPRET", "0") == "1": return getattr( torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name ) - kernel = TritonCodeCache.load(kernel_name, source_code) - if self.use_process_pool(): - set_feature_use("parallel_compile_post_warmup", True) + load_kernel = functools.partial( + _load_triton_kernel_from_source, kernel_name, source_code + ) + is_parallel = self.use_process_pool() + set_feature_use("parallel_compile_post_warmup", is_parallel) + if is_parallel: # We want to support changing these env vars after (and while) the # process pool is running, so pass them to the subprocess to reset. 
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} - future_cache = get_future_cache() - - if future := future_cache.get(source_code, None): - counters["inductor"]["async_compile_cache_hit"] += 1 - return future - - counters["inductor"]["async_compile_cache_miss"] += 1 - future = TritonFuture( - kernel, - self.process_pool().submit( - _worker_compile_triton, - kernel._reload_in_subproc, - extra_env, - ), + task = self.process_pool().submit( + _worker_compile_triton, + load_kernel, + extra_env, ) - future_cache[source_code] = future - return future + def reload_kernel_in_parent(): + # Benchmark how often this happens + with dynamo_timed("reload_kernel_in_parent"): + return load_kernel() + + def get_result() -> CachingAutotuner: + kernel = task.result() + kernel.precompile( + warm_cache_only=False, reload_kernel=reload_kernel_in_parent + ) + return kernel + + future = LambdaFuture(get_result, future=task) + CompiledTritonKernels.save(source_code, future) + return future else: - set_feature_use("parallel_compile_post_warmup", False) with dynamo_timed( "async_compile.precompile", log_pt2_compile_event=True, dynamo_compile_column_us="triton_compile_time_us", log_waitcounter=True, ): - kernel.precompile() - return kernel + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.precompile(warm_cache_only=False) + return kernel def multi_kernel(self, *args, **kwargs) -> Any: from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6a3474987d3..c9eb117f856 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -68,7 +68,6 @@ from torch._inductor.cpu_vec_isa import pick_vec_isa from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param from torch._inductor.runtime.compile_tasks import ( - _module_to_triton_kernel, _reload_python_module, _reload_python_module_in_subproc, ) @@ -358,10 +357,11 @@ def sha256_hash(data: bytes) -> str: return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() -def code_hash(code: Union[str, bytes], extra: str = "") -> str: +def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str: hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") - if extra != "": - hashing_str = hashing_str + b"||" + extra.encode("utf-8") + if extra: + extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8") + hashing_str = hashing_str + b"||" + extra_b return "c" + sha256_hash(hashing_str) @@ -2815,10 +2815,10 @@ class PyCodeCache: return parse_stack_trace(entry) -class TritonCodeCache: - @classmethod - def load(cls, kernel_name: str, source_code: str) -> ModuleType: - return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name) +def _load_triton_kernel_from_source( + kernel_name: str, source_code: str +) -> CachingAutotuner: + return getattr(PyCodeCache.load(source_code), kernel_name) def _cuda_compiler() -> Optional[str]: @@ -3222,30 +3222,12 @@ class CodeCacheFuture: raise NotImplementedError -class TritonFuture(CodeCacheFuture): - kernel: CachingAutotuner - - def __init__( - self, - kernel: Any, - future: Optional[Future[Any]], - ) -> None: - self.kernel = kernel - self.future = future - - def result(self) -> Callable[..., Any]: - if self.future is not None: - # If the worker failed this will throw an exception. 
- result = self.future.result() - assert result is None - self.future = None - self.kernel.precompile() - return self.kernel - - class LambdaFuture(CodeCacheFuture): - def __init__(self, result_fn: Callable[..., Any]) -> None: + def __init__( + self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None + ) -> None: self.result_fn = result_fn + self.future = future def result(self) -> Callable[..., Any]: # type: ignore[override] return self.result_fn() diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 3b30cd40a09..64244a32df1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -29,6 +29,7 @@ from torch.utils._triton import has_triton_package from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges from .. import config, ir, metrics +from ..async_compile import AsyncCompile from ..codecache import code_hash, get_path, PyCodeCache from ..runtime.benchmarking import benchmarker from ..runtime.hints import ( @@ -110,6 +111,7 @@ log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +async_compile = AsyncCompile() class OpDtypeSupport: @@ -3939,9 +3941,16 @@ class TritonScheduling(SIMDScheduling): src_code = src_code.replace("#pragma CMT", "#") _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") - compile_wrapper = IndentedBuffer() + + if async_compile.use_process_pool(): + # The process pool is warm, we can shell out to workers right away. This + # allows us to save the result in async_compile.CompiledTritonKernels, + # so that the second time we call async_compile.triton, we do no work. + async_compile.triton(subs_name, src_code) + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) current_device = V.graph.get_current_device_or_throw() compile_wrapper.writeline(f"''', device_str='{current_device.type}')") diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index dabe299fa2c..0c098f6afa4 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -6,7 +6,7 @@ import logging import os import os.path import re -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from typing_extensions import override import torch @@ -154,8 +154,47 @@ class AutotuneCache: if not remote_cache: return + # Save the args passed to create_cache + # in case AutotuneCache needs to be pickled + self.remote_cache_full_key = key + self.is_fbcode = is_fbcode self.remote_cache = (remote_cache, cache_key) + # The AutotuneCache may be serialized/deserialized if we're using + # AsyncCompile worker processes to run triton compilation. + # This is because AutotuneCache instances are created on the worker + # process, but we need to run AutotuneCache.save on the parent process + # when actually doing autotuning. 
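+    #
+    # Illustrative round-trip (a sketch; `cache` stands for an AutotuneCache built on a
+    # worker process, not a real variable in this module):
+    #
+    #     payload = pickle.dumps(cache)     # __getstate__ keeps only the cache_key string
+    #                                       # in place of the live remote-cache handle
+    #     restored = pickle.loads(payload)  # __setstate__ rebuilds the handle on the parent
+    #                                       # via create_cache(remote_cache_full_key, ...)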
+ def __getstate__(self) -> dict[str, Any]: + # The remote cache handles themselves may not be serializable + # So clear it and reconstruct it on setstate + remote_cache = getattr(self, "remote_cache", None) + return { + **self.__dict__, + # Save the cache_key portion + "remote_cache": remote_cache and remote_cache[1], + } + + def __setstate__(self, state: dict[str, Any]) -> None: + # Reconstruct the remote cache on the parent class + self.__dict__.update(state) + if self.remote_cache is not None: + assert isinstance(self.remote_cache, str) + assert hasattr(self, "remote_cache_full_key") + assert hasattr(self, "is_fbcode") + cache_key = self.remote_cache + remote_cache = create_cache( + self.remote_cache_full_key, + self.is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if remote_cache is not None: + self.remote_cache = (remote_cache, cache_key) + else: + log.warning("Warning, failed to recreate remote cache after pickling") + self.remote_cache = None + # Save the config in the caches def save( self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 8e1f8659adf..b15fc568e0b 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -68,7 +68,10 @@ def _set_triton_ptxas_path() -> None: def _worker_compile_triton( load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str] -) -> None: +) -> CachingAutotuner: _set_triton_ptxas_path() os.environ.update(extra_env) - load_kernel().precompile(warm_cache_only=True) + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + kernel.prepare_for_pickle() + return kernel diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c920be4813b..487610b81f5 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -256,18 +256,29 @@ class CachingAutotuner(KernelInterface): def precompile( self, warm_cache_only=False, - reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None, + reload_kernel: Optional[Callable[[], CachingAutotuner]] = None, ): if warm_cache_only: self._precompile_worker() return with self.lock: + # Helper function for reloading a kernel generated in a worker + # in the parent class. Normally we don't need to reload the kernel + # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock), + # we need to actually run compilation on the parent process + if reload_kernel is not None: + self._reload_kernel = reload_kernel self._precompile_worker() self._make_launchers() - self._dynamic_scale_rblock(reload_in_parent) + self._dynamic_scale_rblock() def _precompile_worker(self): if self.compile_results: + for result in self.compile_results: + TritonBundler.put( + triton_hash_to_path_key(result.kernel.hash), + self.triton_meta.get("device", 0), + ) return assert not self.launchers if not self.configs: @@ -285,9 +296,7 @@ class CachingAutotuner(KernelInterface): self.compile_results = compile_results self.configs = None - def _dynamic_scale_rblock( - self, reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None - ): + def _dynamic_scale_rblock(self): # TODO(jansel): we should find a way to move this extra compile into the worker process # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg. 
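+        # Note: if this autotuner was compiled in a worker and pickled back to the parent,
+        # self.fn.fn is None at this point; the extra compile below restores the real fn
+        # lazily via self._reload_kernel (saved in precompile()) before recompiling.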
device_prop = self.device_props @@ -392,8 +401,9 @@ class CachingAutotuner(KernelInterface): and the fn was dropped in prepare_for_pickle(). We haven't loaded the module containing the real fn yet. """ - assert reload_in_parent - self.fn = reload_in_parent().fn + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn self.compile_results.append(self._precompile_config(new_config)) self._make_launchers() @@ -415,6 +425,7 @@ class CachingAutotuner(KernelInterface): for result in self.compile_results: try: launchers.append(result.make_launcher()) + except (OutOfResources, PTXASError) as e: exc = e if len(launchers) == 0: @@ -519,7 +530,6 @@ class CachingAutotuner(KernelInterface): compile_meta, ) raise - TritonBundler.put( triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0) ) @@ -819,6 +829,17 @@ class CachingAutotuner(KernelInterface): config2launcher = {launcher.config: launcher} + # TODO: should we just load the kernels ahead of time if we know we're going to call this? + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + def benchmark_one_config(config): with self.lock: launcher = self._precompile_config(config).make_launcher() @@ -957,12 +978,50 @@ class TritonCompileResult: self.compile_meta = compile_meta self.inductor_meta = inductor_meta + @staticmethod + def _serialize_metadata(metadata): + """ + Triton uses a nested class called KernelMetadata to store metadata information. + Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear + in the toplevel namespace of the module. So these serialization/deser functions + are used to convert the namedtuples to a dict and back. + + As for packed_metadata, depending on the triton backend, KernelMetadata can be + a namedtuple, or a regular tuple! So the serialization function branches on whether + the metadata to be serialized is a namedtuple or regular, serializable one. 
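+
+        Illustrative behaviour (a sketch; `Meta` is a stand-in for Triton's KernelMetadata,
+        not a real class in this module):
+
+            Meta = namedtuple("Meta", ["name", "num_warps"])
+            _serialize_metadata(Meta("k", 4))   # -> {"name": "k", "num_warps": 4}
+            _serialize_metadata(("k", 4))       # -> ("k", 4); plain tuples pass through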
+ """ + + def is_namedtuple(obj) -> bool: + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) + + if is_namedtuple(metadata): + return metadata._asdict() + else: + return metadata + + @staticmethod + def _deserialize_metadata(metadata): + if isinstance(metadata, dict): + return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))( + **metadata + ) + else: + return metadata + def __getstate__(self) -> dict[str, Any]: kernel = self.kernel # replace the fields that don't pickle nicely kernel_state = { **kernel.__dict__, - "metadata": kernel.metadata._asdict(), + # See doc about serializing metadata above + "metadata": self._serialize_metadata(kernel.metadata), + "packed_metadata": self._serialize_metadata( + getattr(kernel, "packed_metadata", None) + ), "module": None, # regenerated by kernel._init_handles() "function": None, # regenerated by kernel._init_handles() "run": None, # regenerated by kernel._init_handles() @@ -975,13 +1034,13 @@ class TritonCompileResult: # TODO(jansel): need to fixup src.fn which is now None kernel = CompiledKernel.__new__(CompiledKernel) metadata = state["kernel"]["metadata"] + packed_metadata = state["kernel"]["packed_metadata"] kernel.__dict__.update( { **state["kernel"], # "src": src, - "metadata": self._kernel_metadata_cls(tuple(metadata.keys()))( - **metadata - ), + "metadata": self._deserialize_metadata(metadata), + "packed_metadata": self._deserialize_metadata(packed_metadata), } ) self.__dict__.update(state) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index ee272adf557..fdbfd58fc5a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -36,7 +36,7 @@ import sympy import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.utils import counters, dynamo_timed -from torch._inductor.codecache import PyCodeCache, TritonFuture +from torch._inductor.codecache import LambdaFuture, PyCodeCache from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet @@ -2734,7 +2734,7 @@ class Scheduler: def compile_kernel( nodes: Sequence[BaseSchedulerNode], - ) -> tuple[Optional[TritonFuture], ModuleType]: + ) -> tuple[Optional[LambdaFuture], ModuleType]: src_code = self.generate_kernel_code_from_nodes( nodes, benchmark_kernel=True ) @@ -2743,7 +2743,7 @@ class Scheduler: fut = None else: fut = async_compile.triton(kernel_name="triton_", source_code=src_code) - assert isinstance(fut, TritonFuture) + assert isinstance(fut, LambdaFuture) return (fut, mod) @@ -2772,7 +2772,7 @@ class Scheduler: ) # Start compiling choices in parallel - future_choices: List[tuple[Any, Optional[TritonFuture], ModuleType]] = [] + future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] triton_choices = 0 for choice, unfused_time in sorted( choice_timings.items(), key=lambda x: x[1]