Update (base update)

[ghstack-poisoned]
James Wu 2025-02-08 20:15:28 -08:00
parent 7c8ec84dab
commit 99af07f2f3
9 changed files with 253 additions and 79 deletions

View file

@@ -1494,6 +1494,9 @@ class TestAutotuneCache(TestCase):
@config.patch({"autotune_remote_cache": True})
@config.patch({"bundled_autotune_remote_cache": False})
@config.patch({"max_autotune": True})
@config.patch(
{"compile_threads": 1}
) # Worker processes do not register PatchCaches() properly
def test_autotune_cache(self):
class Model(torch.nn.Module):
def forward(self, x, y, a, b):
@@ -1531,6 +1534,7 @@ class TestAutotuneCache(TestCase):
@config.patch({"autotune_local_cache": True})
@config.patch({"autotune_remote_cache": False})
@config.patch({"bundled_autotune_remote_cache": True})
@config.patch({"compile_threads": 1})
@config.patch({"max_autotune": True})
def test_bundled_autotune_remote_cache(self):
class Model(torch.nn.Module):

View file

@@ -1031,6 +1031,9 @@ class TestMaxAutotuneRemoteCache(TestCase):
PatchCaches.tearDown()
@parametrize("dynamic", (False, True))
@config.patch(
{"compile_threads": 1}
) # Worker processes do not register PatchCaches() properly
def test_max_autotune_remote_caching(self, dynamic: bool):
from unittest.mock import patch

View file

@@ -11,13 +11,15 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
from concurrent.futures.process import BrokenProcessPool
from functools import partial
from time import time
from typing import Any, Callable, Optional, TYPE_CHECKING
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
import torch
from torch._dynamo.device_interface import get_registered_device_interfaces
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
from torch._inductor import config
from torch._inductor.codecache import (
_load_triton_kernel_from_source,
code_hash,
CodeCacheFuture,
CppCodeCache,
CppPythonBindingsCodeCache,
@@ -25,8 +27,7 @@ from torch._inductor.codecache import (
HalideCodeCache,
LambdaFuture,
ROCmCodeCache,
TritonCodeCache,
TritonFuture,
torch_key,
)
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
@@ -42,6 +43,7 @@ from torch.utils._triton import has_triton_package
if TYPE_CHECKING:
from torch._inductor.runtime.hints import HalideMeta
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
# timing metrics for time spent in the compilation
_cumulative_compile_time = 0.0
@@ -130,9 +132,49 @@ def get_compile_threads() -> int:
@clear_on_fresh_inductor_cache
@functools.lru_cache(None)
def get_future_cache():
return {}
class CompiledTritonKernels:
"""
In-memory cache for storing compiled triton kernels.
Each triton kernel is keyed by the hash of its source code. Each value stored
in the cache is a return value of AsyncCompile.triton().
Currently, the cache stores Future objects, but it should generalize to any kernel type.
"""
_cache: Dict[str, LambdaFuture] = {}
@staticmethod
def key(kernel_src: str):
"""
Generates a cache key given a triton kernel's full source code.
This source includes the inductor meta, compilation metadata, the kernel itself, etc.
`kernel_src` should be the exact string passed to async_compile.triton()'s first argument.
"""
# Hashes the kernel source with torch_key into a single hash key
return code_hash(kernel_src, extra=torch_key())
@staticmethod
def save(kernel_src: str, future: LambdaFuture):
"""
Saves a compiled triton kernel to the cache.
TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
but the real type we want to return here is actually an abstract triton kernel.
TODO: The source code here is not just the kernel's source code; it also includes the inductor preamble, etc.,
so the key could be less strict.
"""
key = CompiledTritonKernels.key(kernel_src)
CompiledTritonKernels._cache[key] = future
@staticmethod
def get(kernel_src: str, default: Any) -> LambdaFuture:
key = CompiledTritonKernels.key(kernel_src)
return CompiledTritonKernels._cache.get(key, default)
@staticmethod
def cache_clear():
CompiledTritonKernels._cache = {}
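For orientation, a rough usage sketch of the cache above; the kernel source literal and make_future_somehow are hypothetical stand-ins for what async_compile.triton() actually does on a cache miss.

# Hedged sketch only: illustrates the intended get/save round trip of
# CompiledTritonKernels. `make_future_somehow` is a hypothetical stand-in
# for the LambdaFuture that async_compile.triton() builds on a cache miss.
kernel_src = "...full generated kernel source, including the inductor preamble..."

future = CompiledTritonKernels.get(kernel_src, None)
if future is None:
    future = make_future_somehow(kernel_src)       # returns a LambdaFuture
    CompiledTritonKernels.save(kernel_src, future)
kernel = future.result()                           # the compiled kernel, once ready

CompiledTritonKernels.cache_clear()                # dropped on a fresh inductor cache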
class AsyncCompile:
@@ -208,51 +250,84 @@ class AsyncCompile:
)
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
"""
Async_compile.triton is more complicated than the other backends because
we're trying to optimize compile time as much as possible for this hot callsite.
First of all, the function is cached by CompiledTritonKernels; if there's a kernel
already compiled, we grab it directly from the cache and return.
Otherwise, if we have multiple compile threads, we kick off triton compilations on each
worker process by giving it a kernel and source code to compile. The worker initializes
a CachingAutotuner, runs triton compilation, and pickles the kernel back to us.
We use TritonCompileResult to represent the objects being pickled back to us by each
worker.
Some things pickled back to us that may not be obvious:
- Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata
and do not have to pay the cost of loading the triton kernel on the parent. But certain
cases, like coordesc tuning and dynamic_scale_rblock, require us to lazily reload the
function in the parent when it is needed.
- The AutotuneCache, if enabled, is constructed on each worker per triton config
and pickled back to us via `CachingAutotuner.save_cache_hook`.
"""
if future := CompiledTritonKernels.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future
counters["inductor"]["async_compile_cache_miss"] += 1
kernel_code_log.info("Triton Kernel:\n%s", source_code)
_compile_start()
_set_triton_ptxas_path()
if os.environ.get("TRITON_INTERPRET", "0") == "1":
return getattr(
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
)
kernel = TritonCodeCache.load(kernel_name, source_code)
if self.use_process_pool():
set_feature_use("parallel_compile_post_warmup", True)
load_kernel = functools.partial(
_load_triton_kernel_from_source, kernel_name, source_code
)
is_parallel = self.use_process_pool()
set_feature_use("parallel_compile_post_warmup", is_parallel)
if is_parallel:
# We want to support changing these env vars after (and while) the
# process pool is running, so pass them to the subprocess to reset.
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
future_cache = get_future_cache()
if future := future_cache.get(source_code, None):
counters["inductor"]["async_compile_cache_hit"] += 1
return future
counters["inductor"]["async_compile_cache_miss"] += 1
future = TritonFuture(
kernel,
self.process_pool().submit(
_worker_compile_triton,
kernel._reload_in_subproc,
extra_env,
),
task = self.process_pool().submit(
_worker_compile_triton,
load_kernel,
extra_env,
)
future_cache[source_code] = future
return future
def reload_kernel_in_parent():
# Benchmark how often this happens
with dynamo_timed("reload_kernel_in_parent"):
return load_kernel()
def get_result() -> CachingAutotuner:
kernel = task.result()
kernel.precompile(
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
)
return kernel
future = LambdaFuture(get_result, future=task)
CompiledTritonKernels.save(source_code, future)
return future
else:
set_feature_use("parallel_compile_post_warmup", False)
with dynamo_timed(
"async_compile.precompile",
log_pt2_compile_event=True,
dynamo_compile_column_us="triton_compile_time_us",
log_waitcounter=True,
):
kernel.precompile()
return kernel
_set_triton_ptxas_path()
kernel = load_kernel()
kernel.precompile(warm_cache_only=False)
return kernel
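The parallel branch above can be summarized with a small self-contained sketch: submit the compile to the pool immediately, but defer the parent-side finalization until .result() is first called. FakeLambdaFuture and worker_compile below are illustrative stand-ins, not the real inductor symbols.

# Self-contained, hedged sketch of the parallel path: submit now, finalize
# lazily in the parent. FakeLambdaFuture/worker_compile are illustrative only.
from concurrent.futures import Future, ThreadPoolExecutor
from typing import Any, Callable, Optional

class FakeLambdaFuture:                          # mirrors LambdaFuture in codecache.py
    def __init__(self, result_fn: Callable[[], Any], future: Optional[Future] = None) -> None:
        self.result_fn = result_fn
        self.future = future

    def result(self) -> Any:
        return self.result_fn()

def worker_compile(src: str) -> str:             # stands in for _worker_compile_triton
    return src.upper()                           # pretend this is the pickled kernel

with ThreadPoolExecutor(max_workers=1) as pool:
    task = pool.submit(worker_compile, "triton kernel source")

    def get_result() -> str:
        kernel = task.result()                   # raises here if the worker failed
        # parent-side kernel.precompile(...) would run at this point, lazily
        return kernel

    fut = FakeLambdaFuture(get_result, future=task)
    print(fut.result())                          # finalization only happens now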
def multi_kernel(self, *args, **kwargs) -> Any:
from torch._inductor.codegen.multi_kernel import MultiKernelCall

View file

@@ -68,7 +68,6 @@ from torch._inductor.cpu_vec_isa import pick_vec_isa
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
from torch._inductor.runtime.compile_tasks import (
_module_to_triton_kernel,
_reload_python_module,
_reload_python_module_in_subproc,
)
@@ -358,10 +357,11 @@ def sha256_hash(data: bytes) -> str:
return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
def code_hash(code: Union[str, bytes], extra: str = "") -> str:
def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
if extra != "":
hashing_str = hashing_str + b"||" + extra.encode("utf-8")
if extra:
extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
hashing_str = hashing_str + b"||" + extra_b
return "c" + sha256_hash(hashing_str)
@@ -2815,10 +2815,10 @@ class PyCodeCache:
return parse_stack_trace(entry)
class TritonCodeCache:
@classmethod
def load(cls, kernel_name: str, source_code: str) -> ModuleType:
return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
def _load_triton_kernel_from_source(
kernel_name: str, source_code: str
) -> CachingAutotuner:
return getattr(PyCodeCache.load(source_code), kernel_name)
def _cuda_compiler() -> Optional[str]:
@@ -3222,30 +3222,12 @@ class CodeCacheFuture:
raise NotImplementedError
class TritonFuture(CodeCacheFuture):
kernel: CachingAutotuner
def __init__(
self,
kernel: Any,
future: Optional[Future[Any]],
) -> None:
self.kernel = kernel
self.future = future
def result(self) -> Callable[..., Any]:
if self.future is not None:
# If the worker failed this will throw an exception.
result = self.future.result()
assert result is None
self.future = None
self.kernel.precompile()
return self.kernel
class LambdaFuture(CodeCacheFuture):
def __init__(self, result_fn: Callable[..., Any]) -> None:
def __init__(
self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None
) -> None:
self.result_fn = result_fn
self.future = future
def result(self) -> Callable[..., Any]: # type: ignore[override]
return self.result_fn()

View file

@@ -29,6 +29,7 @@ from torch.utils._triton import has_triton_package
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
from ...utils._sympy.value_ranges import ValueRanges
from .. import config, ir, metrics
from ..async_compile import AsyncCompile
from ..codecache import code_hash, get_path, PyCodeCache
from ..runtime.benchmarking import benchmarker
from ..runtime.hints import (
@@ -110,6 +111,7 @@ log = logging.getLogger(__name__)
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
async_compile = AsyncCompile()
class OpDtypeSupport:
@@ -3939,9 +3941,16 @@ class TritonScheduling(SIMDScheduling):
src_code = src_code.replace("#pragma CMT", "#")
_basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")
compile_wrapper = IndentedBuffer()
if async_compile.use_process_pool():
# The process pool is warm, so we can shell out to workers right away. This
# allows us to save the result in async_compile.CompiledTritonKernels,
# so that the second time we call async_compile.triton, we do no work.
async_compile.triton(subs_name, src_code)
compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
compile_wrapper.splice(src_code, strip=True)
current_device = V.graph.get_current_device_or_throw()
compile_wrapper.writeline(f"''', device_str='{current_device.type}')")

View file

@@ -6,7 +6,7 @@ import logging
import os
import os.path
import re
from typing import Optional, TYPE_CHECKING
from typing import Any, Optional, TYPE_CHECKING
from typing_extensions import override
import torch
@@ -154,8 +154,47 @@ class AutotuneCache:
if not remote_cache:
return
# Save the args passed to create_cache
# in case AutotuneCache needs to be pickled
self.remote_cache_full_key = key
self.is_fbcode = is_fbcode
self.remote_cache = (remote_cache, cache_key)
# The AutotuneCache may be serialized/deserialized if we're using
# AsyncCompile worker processes to run triton compilation.
# This is because AutotuneCache instances are created on the worker
# process, but we need to run AutotuneCache.save on the parent process
# when actually doing autotuning.
def __getstate__(self) -> dict[str, Any]:
# The remote cache handle itself may not be serializable,
# so clear it and reconstruct it in __setstate__
remote_cache = getattr(self, "remote_cache", None)
return {
**self.__dict__,
# Save the cache_key portion
"remote_cache": remote_cache and remote_cache[1],
}
def __setstate__(self, state: dict[str, Any]) -> None:
# Reconstruct the remote cache on the parent process
self.__dict__.update(state)
if self.remote_cache is not None:
assert isinstance(self.remote_cache, str)
assert hasattr(self, "remote_cache_full_key")
assert hasattr(self, "is_fbcode")
cache_key = self.remote_cache
remote_cache = create_cache(
self.remote_cache_full_key,
self.is_fbcode,
"FbRemoteAutotuneCache",
"RemoteAutotuneCache",
)
if remote_cache is not None:
self.remote_cache = (remote_cache, cache_key)
else:
log.warning("Warning, failed to recreate remote cache after pickling")
self.remote_cache = None
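A self-contained sketch of the pickling pattern used here, under the assumption that the live cache handle is the only unpicklable piece; the field and helper names below are illustrative, not the real AutotuneCache attributes.

import pickle
from typing import Any

def open_handle(key: str) -> object:           # stands in for create_cache(...)
    return object()                            # pretend this handle cannot be pickled

class CacheLike:
    def __init__(self, key: str) -> None:
        self.key = key
        self.handle = open_handle(key)         # live handle, dropped before pickling

    def __getstate__(self) -> dict[str, Any]:
        # keep only what is needed to rebuild the handle on the other side
        return {**self.__dict__, "handle": None}

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        self.handle = open_handle(self.key)    # reconstruct after unpickling

restored = pickle.loads(pickle.dumps(CacheLike("autotune/key")))
assert restored.handle is not None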
# Save the config in the caches
def save(
self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False

View file

@@ -68,7 +68,10 @@ def _set_triton_ptxas_path() -> None:
def _worker_compile_triton(
load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
) -> None:
) -> CachingAutotuner:
_set_triton_ptxas_path()
os.environ.update(extra_env)
load_kernel().precompile(warm_cache_only=True)
kernel = load_kernel()
kernel.precompile(warm_cache_only=True)
kernel.prepare_for_pickle()
return kernel

View file

@@ -256,18 +256,29 @@ class CachingAutotuner(KernelInterface):
def precompile(
self,
warm_cache_only=False,
reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None,
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
):
if warm_cache_only:
self._precompile_worker()
return
with self.lock:
# Helper function for reloading a kernel generated in a worker
# in the parent process. Normally we don't need to reload the kernel
# in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock),
# we need to actually run compilation on the parent process
if reload_kernel is not None:
self._reload_kernel = reload_kernel
self._precompile_worker()
self._make_launchers()
self._dynamic_scale_rblock(reload_in_parent)
self._dynamic_scale_rblock()
def _precompile_worker(self):
if self.compile_results:
for result in self.compile_results:
TritonBundler.put(
triton_hash_to_path_key(result.kernel.hash),
self.triton_meta.get("device", 0),
)
return
assert not self.launchers
if not self.configs:
@@ -285,9 +296,7 @@ class CachingAutotuner(KernelInterface):
self.compile_results = compile_results
self.configs = None
def _dynamic_scale_rblock(
self, reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None
):
def _dynamic_scale_rblock(self):
# TODO(jansel): we should find a way to move this extra compile into the worker process
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
device_prop = self.device_props
@@ -392,8 +401,9 @@ class CachingAutotuner(KernelInterface):
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
containing the real fn yet.
"""
assert reload_in_parent
self.fn = reload_in_parent().fn
assert hasattr(self, "_reload_kernel")
assert callable(self._reload_kernel)
self.fn = self._reload_kernel().fn
self.compile_results.append(self._precompile_config(new_config))
self._make_launchers()
@@ -415,6 +425,7 @@ class CachingAutotuner(KernelInterface):
for result in self.compile_results:
try:
launchers.append(result.make_launcher())
except (OutOfResources, PTXASError) as e:
exc = e
if len(launchers) == 0:
@@ -519,7 +530,6 @@ class CachingAutotuner(KernelInterface):
compile_meta,
)
raise
TritonBundler.put(
triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
)
@@ -819,6 +829,17 @@ class CachingAutotuner(KernelInterface):
config2launcher = {launcher.config: launcher}
# TODO: should we just load the kernels ahead of time if we know we're going to call this?
if self.fn.fn is None:
"""
We are in the parent process, while this program was compiled in a worker
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
containing the real fn yet.
"""
assert hasattr(self, "_reload_kernel")
assert callable(self._reload_kernel)
self.fn = self._reload_kernel().fn
def benchmark_one_config(config):
with self.lock:
launcher = self._precompile_config(config).make_launcher()
@@ -957,12 +978,50 @@ class TritonCompileResult:
self.compile_meta = compile_meta
self.inductor_meta = inductor_meta
@staticmethod
def _serialize_metadata(metadata):
"""
Triton uses a nested class called KernelMetadata to store metadata information.
Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear
in the top-level namespace of the module. So these serialization/deserialization functions
are used to convert the namedtuples to a dict and back.
As for packed_metadata, depending on the triton backend, KernelMetadata can be
a namedtuple, or a regular tuple! So the serialization function branches on whether
the metadata to be serialized is a namedtuple or regular, serializable one.
"""
def is_namedtuple(obj) -> bool:
return (
isinstance(obj, tuple)
and hasattr(obj, "_asdict")
and hasattr(obj, "_fields")
)
if is_namedtuple(metadata):
return metadata._asdict()
else:
return metadata
@staticmethod
def _deserialize_metadata(metadata):
if isinstance(metadata, dict):
return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))(
**metadata
)
else:
return metadata
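The underlying pickling limitation and the workaround, in a standalone sketch: a namedtuple class created at runtime is not reachable from its module's top level, so instances cannot be pickled directly, while the _asdict() form round-trips fine. Class and field names below are illustrative.

import pickle
from collections import namedtuple

def make_metadata():
    # created at runtime, so the class is not reachable at module top level
    KernelMetadata = namedtuple("KernelMetadata", ["name", "num_warps"])
    return KernelMetadata(name="k", num_warps=4)

meta = make_metadata()
try:
    pickle.dumps(meta)                           # attribute lookup on the module fails
except pickle.PicklingError:
    pass

# the workaround used above: ship a plain dict, rebuild the namedtuple later
state = meta._asdict()
rebuilt = namedtuple("KernelMetadata", tuple(state.keys()))(**state)
assert rebuilt == meta                           # namedtuples compare as plain tuples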
def __getstate__(self) -> dict[str, Any]:
kernel = self.kernel
# replace the fields that don't pickle nicely
kernel_state = {
**kernel.__dict__,
"metadata": kernel.metadata._asdict(),
# See doc about serializing metadata above
"metadata": self._serialize_metadata(kernel.metadata),
"packed_metadata": self._serialize_metadata(
getattr(kernel, "packed_metadata", None)
),
"module": None, # regenerated by kernel._init_handles()
"function": None, # regenerated by kernel._init_handles()
"run": None, # regenerated by kernel._init_handles()
@@ -975,13 +1034,13 @@ class TritonCompileResult:
# TODO(jansel): need to fixup src.fn which is now None
kernel = CompiledKernel.__new__(CompiledKernel)
metadata = state["kernel"]["metadata"]
packed_metadata = state["kernel"]["packed_metadata"]
kernel.__dict__.update(
{
**state["kernel"],
# "src": src,
"metadata": self._kernel_metadata_cls(tuple(metadata.keys()))(
**metadata
),
"metadata": self._deserialize_metadata(metadata),
"packed_metadata": self._deserialize_metadata(packed_metadata),
}
)
self.__dict__.update(state)

View file

@@ -36,7 +36,7 @@ import sympy
import torch
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
from torch._dynamo.utils import counters, dynamo_timed
from torch._inductor.codecache import PyCodeCache, TritonFuture
from torch._inductor.codecache import LambdaFuture, PyCodeCache
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
from torch.utils._ordered_set import OrderedSet
@@ -2734,7 +2734,7 @@ class Scheduler:
def compile_kernel(
nodes: Sequence[BaseSchedulerNode],
) -> tuple[Optional[TritonFuture], ModuleType]:
) -> tuple[Optional[LambdaFuture], ModuleType]:
src_code = self.generate_kernel_code_from_nodes(
nodes, benchmark_kernel=True
)
@@ -2743,7 +2743,7 @@ class Scheduler:
fut = None
else:
fut = async_compile.triton(kernel_name="triton_", source_code=src_code)
assert isinstance(fut, TritonFuture)
assert isinstance(fut, LambdaFuture)
return (fut, mod)
@@ -2772,7 +2772,7 @@ class Scheduler:
)
# Start compiling choices in parallel
future_choices: List[tuple[Any, Optional[TritonFuture], ModuleType]] = []
future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
triton_choices = 0
for choice, unfused_time in sorted(
choice_timings.items(), key=lambda x: x[1]