mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Only call triton in worker process, ahead of time compile
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146417
### Big idea
This PR extends https://github.com/pytorch/pytorch/pull/144288 by combining calling triton in worker processes with the future cache: we kick off triton compilation in the worker processes earlier, during inductor codegen. Basically, instead of calling async_compile.triton for the first time only after the entire code has been generated, we start compiling as soon as we know we'll need to compile the kernel. Then, when loading the generated inductor code, we can simply read from our in-memory future cache, considerably increasing the parallelism.
### Implementation Overview
In total, the diff does the following:
- Converts TritonFuture to LambdaFuture, only calling triton.compile on worker processes
- Now that triton.compile() isn't called on the main process, we call TritonBundler on all compiled kernels when we get them back from workers
- Extend @eellison's future cache to a class, mostly as a refactor
- Finally, call async_compile.triton ahead of time in Scheduler.codegen if workers are warmed up. This causes the subsequent
async_compile.triton call that occurs after codegen to cache hit on cold start.
In the diffs after this, I will add more to CompiledTritonKernels so that TritonBundler, on a warm start, automatically populates the in-memory cache with the existing triton kernels, avoiding calling triton altogether on warm starts.
Because LambdaFutures are much faster to kick off than TritonFutures, due to not needing to load from TritonCodeCache at all, the time spent kicking off these worker jobs is pretty minimal for inductor codegen.
Differential Revision: [D69123174](https://our.internmc.facebook.com/intern/diff/D69123174/)
ghstack-source-id: 22f0e5a3bd
This commit is contained in:
parent
7c8ec84dab
commit
b8c5c4443c
9 changed files with 253 additions and 79 deletions
|
|
@ -1494,6 +1494,9 @@ class TestAutotuneCache(TestCase):
|
|||
@config.patch({"autotune_remote_cache": True})
|
||||
@config.patch({"bundled_autotune_remote_cache": False})
|
||||
@config.patch({"max_autotune": True})
|
||||
@config.patch(
|
||||
{"compile_threads": 1}
|
||||
) # Worker processes do not register PatchCaches() properly
|
||||
def test_autotune_cache(self):
|
||||
class Model(torch.nn.Module):
|
||||
def forward(self, x, y, a, b):
|
||||
|
|
@ -1531,6 +1534,7 @@ class TestAutotuneCache(TestCase):
|
|||
@config.patch({"autotune_local_cache": True})
|
||||
@config.patch({"autotune_remote_cache": False})
|
||||
@config.patch({"bundled_autotune_remote_cache": True})
|
||||
@config.patch({"compile_threads": 1})
|
||||
@config.patch({"max_autotune": True})
|
||||
def test_bundled_autotune_remote_cache(self):
|
||||
class Model(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -1031,6 +1031,9 @@ class TestMaxAutotuneRemoteCache(TestCase):
|
|||
PatchCaches.tearDown()
|
||||
|
||||
@parametrize("dynamic", (False, True))
|
||||
@config.patch(
|
||||
{"compile_threads": 1}
|
||||
) # Worker processes do not register PatchCaches() properly
|
||||
def test_max_autotune_remote_caching(self, dynamic: bool):
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
|
|||
|
|
@ -11,13 +11,15 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
|
|||
from concurrent.futures.process import BrokenProcessPool
|
||||
from functools import partial
|
||||
from time import time
|
||||
from typing import Any, Callable, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, Dict, Optional, TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
from torch._dynamo.device_interface import get_registered_device_interfaces
|
||||
from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
|
||||
from torch._inductor import config
|
||||
from torch._inductor.codecache import (
|
||||
_load_triton_kernel_from_source,
|
||||
code_hash,
|
||||
CodeCacheFuture,
|
||||
CppCodeCache,
|
||||
CppPythonBindingsCodeCache,
|
||||
|
|
@ -25,8 +27,7 @@ from torch._inductor.codecache import (
|
|||
HalideCodeCache,
|
||||
LambdaFuture,
|
||||
ROCmCodeCache,
|
||||
TritonCodeCache,
|
||||
TritonFuture,
|
||||
torch_key,
|
||||
)
|
||||
from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
|
||||
from torch._inductor.compile_worker.watchdog import _async_compile_initializer
|
||||
|
|
@ -42,6 +43,7 @@ from torch.utils._triton import has_triton_package
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from torch._inductor.runtime.hints import HalideMeta
|
||||
from torch._inductor.runtime.triton_heuristics import CachingAutotuner
|
||||
|
||||
# timing metrics for time spent in the compilation
|
||||
_cumulative_compile_time = 0.0
|
||||
|
|
@ -130,9 +132,49 @@ def get_compile_threads() -> int:
|
|||
|
||||
|
||||
@clear_on_fresh_inductor_cache
|
||||
@functools.lru_cache(None)
|
||||
def get_future_cache():
|
||||
return {}
|
||||
class CompiledTritonKernels:
    """
    In-memory cache for storing compiled triton kernels.

    Each triton kernel is keyed by the hash of its source code. Each value stored
    in the cache is a return value of AsyncCompile.triton().

    Currently, the cache stores Future objects, but it should be generalizable for any kernels.
    """

    # Maps CompiledTritonKernels.key(kernel_src) to the future saved for that source.
    _cache: Dict[str, LambdaFuture] = {}

    @staticmethod
    def key(kernel_src: str) -> str:
        """
        Generates a cache key given a triton kernel's full source code.
        This source includes the inductor meta, compilation metadata, the kernel itself, etc.
        `kernel_src` should be the exact string passed to async_compile.triton()'s first argument.
        """
        # Hashes the kernel source with torch_key into a single hash key
        return code_hash(kernel_src, extra=torch_key())

    @staticmethod
    def save(kernel_src: str, future: LambdaFuture) -> None:
        """
        Saves a compiled triton kernel to the cache.
        TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
        but the real type we want to return here is actually an abstract triton kernel.

        TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc.
        so it could be less strict.
        """
        key = CompiledTritonKernels.key(kernel_src)
        CompiledTritonKernels._cache[key] = future

    @staticmethod
    def get(
        kernel_src: str, default: Optional[LambdaFuture] = None
    ) -> Optional[LambdaFuture]:
        """
        Looks up the future saved for `kernel_src`.

        Returns `default` (None unless the caller supplies something else) when
        nothing has been saved under the key derived from `kernel_src`.
        """
        key = CompiledTritonKernels.key(kernel_src)
        return CompiledTritonKernels._cache.get(key, default)

    @staticmethod
    def cache_clear() -> None:
        """
        Drops every saved entry by rebinding the class-level cache to a fresh dict.
        """
        CompiledTritonKernels._cache = {}
|
||||
|
||||
|
||||
class AsyncCompile:
|
||||
|
|
@ -208,51 +250,84 @@ class AsyncCompile:
|
|||
)
|
||||
|
||||
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
|
||||
"""
|
||||
Async_compile.triton is more complicated than the other backends because
|
||||
we're trying to optimize compile time as much as possible for this hot callsite.
|
||||
|
||||
First of all, the function is cached by CompiledTritonKernels; if there's a kernel
|
||||
already compiled, we grab it directly from the cache and return.
|
||||
|
||||
Otherwise, if we have multiple compile threads, we kick off triton compilations on each
|
||||
worker process by giving it a kernel and source code to compile. The worker initializes
|
||||
a CachingAutotuner, runs triton compilation, and pickles the kernel back to us.
|
||||
We use TritonCompileResult to represent the objects being pickled back to us by each
|
||||
worker.
|
||||
|
||||
Some maybe not obvious things that are pickled back to us:
|
||||
- Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata
|
||||
and do not have to pay the cost of loading the triton kernel on the parent. But certain
|
||||
cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function
|
||||
in the parent lazily when we require it.
|
||||
- The AutotuneCache, if enabled, is constructed on each worker per triton config
|
||||
and pickled back to us via `CachingAutotuner.save_cache_hook`.
|
||||
"""
|
||||
if future := CompiledTritonKernels.get(source_code, None):
|
||||
counters["inductor"]["async_compile_cache_hit"] += 1
|
||||
return future
|
||||
|
||||
counters["inductor"]["async_compile_cache_miss"] += 1
|
||||
|
||||
kernel_code_log.info("Triton Kernel:\n%s", source_code)
|
||||
_compile_start()
|
||||
_set_triton_ptxas_path()
|
||||
|
||||
if os.environ.get("TRITON_INTERPRET", "0") == "1":
|
||||
return getattr(
|
||||
torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
|
||||
)
|
||||
|
||||
kernel = TritonCodeCache.load(kernel_name, source_code)
|
||||
if self.use_process_pool():
|
||||
set_feature_use("parallel_compile_post_warmup", True)
|
||||
load_kernel = functools.partial(
|
||||
_load_triton_kernel_from_source, kernel_name, source_code
|
||||
)
|
||||
is_parallel = self.use_process_pool()
|
||||
set_feature_use("parallel_compile_post_warmup", is_parallel)
|
||||
if is_parallel:
|
||||
# We want to support changing these env vars after (and while) the
|
||||
# process pool is running, so pass them to the subprocess to reset.
|
||||
env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
|
||||
extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}
|
||||
|
||||
future_cache = get_future_cache()
|
||||
|
||||
if future := future_cache.get(source_code, None):
|
||||
counters["inductor"]["async_compile_cache_hit"] += 1
|
||||
return future
|
||||
|
||||
counters["inductor"]["async_compile_cache_miss"] += 1
|
||||
future = TritonFuture(
|
||||
kernel,
|
||||
self.process_pool().submit(
|
||||
_worker_compile_triton,
|
||||
kernel._reload_in_subproc,
|
||||
extra_env,
|
||||
),
|
||||
task = self.process_pool().submit(
|
||||
_worker_compile_triton,
|
||||
load_kernel,
|
||||
extra_env,
|
||||
)
|
||||
future_cache[source_code] = future
|
||||
return future
|
||||
|
||||
def reload_kernel_in_parent():
|
||||
# Benchmark how often this happens
|
||||
with dynamo_timed("reload_kernel_in_parent"):
|
||||
return load_kernel()
|
||||
|
||||
def get_result() -> CachingAutotuner:
|
||||
kernel = task.result()
|
||||
kernel.precompile(
|
||||
warm_cache_only=False, reload_kernel=reload_kernel_in_parent
|
||||
)
|
||||
return kernel
|
||||
|
||||
future = LambdaFuture(get_result, future=task)
|
||||
CompiledTritonKernels.save(source_code, future)
|
||||
return future
|
||||
else:
|
||||
set_feature_use("parallel_compile_post_warmup", False)
|
||||
with dynamo_timed(
|
||||
"async_compile.precompile",
|
||||
log_pt2_compile_event=True,
|
||||
dynamo_compile_column_us="triton_compile_time_us",
|
||||
log_waitcounter=True,
|
||||
):
|
||||
kernel.precompile()
|
||||
return kernel
|
||||
_set_triton_ptxas_path()
|
||||
kernel = load_kernel()
|
||||
kernel.precompile(warm_cache_only=False)
|
||||
return kernel
|
||||
|
||||
def multi_kernel(self, *args, **kwargs) -> Any:
|
||||
from torch._inductor.codegen.multi_kernel import MultiKernelCall
|
||||
|
|
|
|||
|
|
@ -68,7 +68,6 @@ from torch._inductor.cpu_vec_isa import pick_vec_isa
|
|||
from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
|
||||
from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
|
||||
from torch._inductor.runtime.compile_tasks import (
|
||||
_module_to_triton_kernel,
|
||||
_reload_python_module,
|
||||
_reload_python_module_in_subproc,
|
||||
)
|
||||
|
|
@ -358,10 +357,11 @@ def sha256_hash(data: bytes) -> str:
|
|||
return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()
|
||||
|
||||
|
||||
def code_hash(code: Union[str, bytes], extra: str = "") -> str:
|
||||
def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
|
||||
hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
|
||||
if extra != "":
|
||||
hashing_str = hashing_str + b"||" + extra.encode("utf-8")
|
||||
if extra:
|
||||
extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
|
||||
hashing_str = hashing_str + b"||" + extra_b
|
||||
return "c" + sha256_hash(hashing_str)
|
||||
|
||||
|
||||
|
|
@ -2815,10 +2815,10 @@ class PyCodeCache:
|
|||
return parse_stack_trace(entry)
|
||||
|
||||
|
||||
class TritonCodeCache:
|
||||
@classmethod
|
||||
def load(cls, kernel_name: str, source_code: str) -> ModuleType:
|
||||
return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
|
||||
def _load_triton_kernel_from_source(
    kernel_name: str, source_code: str
) -> CachingAutotuner:
    """
    Load `source_code` through PyCodeCache and return the attribute named
    `kernel_name` from the loaded result.
    """
    loaded_module = PyCodeCache.load(source_code)
    return getattr(loaded_module, kernel_name)
|
||||
|
||||
|
||||
def _cuda_compiler() -> Optional[str]:
|
||||
|
|
@ -3222,30 +3222,12 @@ class CodeCacheFuture:
|
|||
raise NotImplementedError
|
||||
|
||||
|
||||
class TritonFuture(CodeCacheFuture):
|
||||
kernel: CachingAutotuner
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
kernel: Any,
|
||||
future: Optional[Future[Any]],
|
||||
) -> None:
|
||||
self.kernel = kernel
|
||||
self.future = future
|
||||
|
||||
def result(self) -> Callable[..., Any]:
|
||||
if self.future is not None:
|
||||
# If the worker failed this will throw an exception.
|
||||
result = self.future.result()
|
||||
assert result is None
|
||||
self.future = None
|
||||
self.kernel.precompile()
|
||||
return self.kernel
|
||||
|
||||
|
||||
class LambdaFuture(CodeCacheFuture):
|
||||
def __init__(self, result_fn: Callable[..., Any]) -> None:
|
||||
def __init__(
|
||||
self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None
|
||||
) -> None:
|
||||
self.result_fn = result_fn
|
||||
self.future = future
|
||||
|
||||
def result(self) -> Callable[..., Any]: # type: ignore[override]
|
||||
return self.result_fn()
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ from torch.utils._triton import has_triton_package
|
|||
from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
|
||||
from ...utils._sympy.value_ranges import ValueRanges
|
||||
from .. import config, ir, metrics
|
||||
from ..async_compile import AsyncCompile
|
||||
from ..codecache import code_hash, get_path, PyCodeCache
|
||||
from ..runtime.benchmarking import benchmarker
|
||||
from ..runtime.hints import (
|
||||
|
|
@ -110,6 +111,7 @@ log = logging.getLogger(__name__)
|
|||
perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
|
||||
schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
|
||||
fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
|
||||
async_compile = AsyncCompile()
|
||||
|
||||
|
||||
class OpDtypeSupport:
|
||||
|
|
@ -3939,9 +3941,16 @@ class TritonScheduling(SIMDScheduling):
|
|||
src_code = src_code.replace("#pragma CMT", "#")
|
||||
|
||||
_basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")
|
||||
|
||||
compile_wrapper = IndentedBuffer()
|
||||
|
||||
if async_compile.use_process_pool():
|
||||
# The process pool is warm, we can shell out to workers right away. This
|
||||
# allows us to save the result in async_compile.CompiledTritonKernels,
|
||||
# so that the second time we call async_compile.triton, we do no work.
|
||||
async_compile.triton(subs_name, src_code)
|
||||
|
||||
compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
|
||||
|
||||
compile_wrapper.splice(src_code, strip=True)
|
||||
current_device = V.graph.get_current_device_or_throw()
|
||||
compile_wrapper.writeline(f"''', device_str='{current_device.type}')")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import logging
|
|||
import os
|
||||
import os.path
|
||||
import re
|
||||
from typing import Optional, TYPE_CHECKING
|
||||
from typing import Any, Optional, TYPE_CHECKING
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
|
|
@ -154,8 +154,47 @@ class AutotuneCache:
|
|||
if not remote_cache:
|
||||
return
|
||||
|
||||
# Save the args passed to create_cache
|
||||
# in case AutotuneCache needs to be pickled
|
||||
self.remote_cache_full_key = key
|
||||
self.is_fbcode = is_fbcode
|
||||
self.remote_cache = (remote_cache, cache_key)
|
||||
|
||||
# The AutotuneCache may be serialized/deserialized if we're using
|
||||
# AsyncCompile worker processes to run triton compilation.
|
||||
# This is because AutotuneCache instances are created on the worker
|
||||
# process, but we need to run AutotuneCache.save on the parent process
|
||||
# when actually doing autotuning.
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
# The remote cache handles themselves may not be serializable
|
||||
# So clear it and reconstruct it on setstate
|
||||
remote_cache = getattr(self, "remote_cache", None)
|
||||
return {
|
||||
**self.__dict__,
|
||||
# Save the cache_key portion
|
||||
"remote_cache": remote_cache and remote_cache[1],
|
||||
}
|
||||
|
||||
def __setstate__(self, state: dict[str, Any]) -> None:
|
||||
# Reconstruct the remote cache on the parent class
|
||||
self.__dict__.update(state)
|
||||
if self.remote_cache is not None:
|
||||
assert isinstance(self.remote_cache, str)
|
||||
assert hasattr(self, "remote_cache_full_key")
|
||||
assert hasattr(self, "is_fbcode")
|
||||
cache_key = self.remote_cache
|
||||
remote_cache = create_cache(
|
||||
self.remote_cache_full_key,
|
||||
self.is_fbcode,
|
||||
"FbRemoteAutotuneCache",
|
||||
"RemoteAutotuneCache",
|
||||
)
|
||||
if remote_cache is not None:
|
||||
self.remote_cache = (remote_cache, cache_key)
|
||||
else:
|
||||
log.warning("Warning, failed to recreate remote cache after pickling")
|
||||
self.remote_cache = None
|
||||
|
||||
# Save the config in the caches
|
||||
def save(
|
||||
self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False
|
||||
|
|
|
|||
|
|
@ -68,7 +68,10 @@ def _set_triton_ptxas_path() -> None:
|
|||
|
||||
def _worker_compile_triton(
|
||||
load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
|
||||
) -> None:
|
||||
) -> CachingAutotuner:
|
||||
_set_triton_ptxas_path()
|
||||
os.environ.update(extra_env)
|
||||
load_kernel().precompile(warm_cache_only=True)
|
||||
kernel = load_kernel()
|
||||
kernel.precompile(warm_cache_only=True)
|
||||
kernel.prepare_for_pickle()
|
||||
return kernel
|
||||
|
|
|
|||
|
|
@ -256,18 +256,29 @@ class CachingAutotuner(KernelInterface):
|
|||
def precompile(
|
||||
self,
|
||||
warm_cache_only=False,
|
||||
reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None,
|
||||
reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
|
||||
):
|
||||
if warm_cache_only:
|
||||
self._precompile_worker()
|
||||
return
|
||||
with self.lock:
|
||||
# Helper function for reloading a kernel generated in a worker
|
||||
# in the parent process. Normally we don't need to reload the kernel
|
||||
# in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock),
|
||||
# we need to actually run compilation on the parent process
|
||||
if reload_kernel is not None:
|
||||
self._reload_kernel = reload_kernel
|
||||
self._precompile_worker()
|
||||
self._make_launchers()
|
||||
self._dynamic_scale_rblock(reload_in_parent)
|
||||
self._dynamic_scale_rblock()
|
||||
|
||||
def _precompile_worker(self):
|
||||
if self.compile_results:
|
||||
for result in self.compile_results:
|
||||
TritonBundler.put(
|
||||
triton_hash_to_path_key(result.kernel.hash),
|
||||
self.triton_meta.get("device", 0),
|
||||
)
|
||||
return
|
||||
assert not self.launchers
|
||||
if not self.configs:
|
||||
|
|
@ -285,9 +296,7 @@ class CachingAutotuner(KernelInterface):
|
|||
self.compile_results = compile_results
|
||||
self.configs = None
|
||||
|
||||
def _dynamic_scale_rblock(
|
||||
self, reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None
|
||||
):
|
||||
def _dynamic_scale_rblock(self):
|
||||
# TODO(jansel): we should find a way to move this extra compile into the worker process
|
||||
# Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
|
||||
device_prop = self.device_props
|
||||
|
|
@ -392,8 +401,9 @@ class CachingAutotuner(KernelInterface):
|
|||
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
||||
containing the real fn yet.
|
||||
"""
|
||||
assert reload_in_parent
|
||||
self.fn = reload_in_parent().fn
|
||||
assert hasattr(self, "_reload_kernel")
|
||||
assert callable(self._reload_kernel)
|
||||
self.fn = self._reload_kernel().fn
|
||||
self.compile_results.append(self._precompile_config(new_config))
|
||||
|
||||
self._make_launchers()
|
||||
|
|
@ -415,6 +425,7 @@ class CachingAutotuner(KernelInterface):
|
|||
for result in self.compile_results:
|
||||
try:
|
||||
launchers.append(result.make_launcher())
|
||||
|
||||
except (OutOfResources, PTXASError) as e:
|
||||
exc = e
|
||||
if len(launchers) == 0:
|
||||
|
|
@ -519,7 +530,6 @@ class CachingAutotuner(KernelInterface):
|
|||
compile_meta,
|
||||
)
|
||||
raise
|
||||
|
||||
TritonBundler.put(
|
||||
triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
|
||||
)
|
||||
|
|
@ -819,6 +829,17 @@ class CachingAutotuner(KernelInterface):
|
|||
|
||||
config2launcher = {launcher.config: launcher}
|
||||
|
||||
# TODO: should we just load the kernels ahead of time if we know we're going to call this?
|
||||
if self.fn.fn is None:
|
||||
"""
|
||||
We are in the parent process, while this program was compiled in a worker
|
||||
and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
|
||||
containing the real fn yet.
|
||||
"""
|
||||
assert hasattr(self, "_reload_kernel")
|
||||
assert callable(self._reload_kernel)
|
||||
self.fn = self._reload_kernel().fn
|
||||
|
||||
def benchmark_one_config(config):
|
||||
with self.lock:
|
||||
launcher = self._precompile_config(config).make_launcher()
|
||||
|
|
@ -957,12 +978,50 @@ class TritonCompileResult:
|
|||
self.compile_meta = compile_meta
|
||||
self.inductor_meta = inductor_meta
|
||||
|
||||
@staticmethod
|
||||
def _serialize_metadata(metadata):
|
||||
"""
|
||||
Triton uses a nested class called KernelMetadata to store metadata information.
|
||||
Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear
|
||||
in the toplevel namespace of the module. So these serialization/deser functions
|
||||
are used to convert the namedtuples to a dict and back.
|
||||
|
||||
As for packed_metadata, depending on the triton backend, KernelMetadata can be
|
||||
a namedtuple, or a regular tuple! So the serialization function branches on whether
|
||||
the metadata to be serialized is a namedtuple or regular, serializable one.
|
||||
"""
|
||||
|
||||
def is_namedtuple(obj) -> bool:
|
||||
return (
|
||||
isinstance(obj, tuple)
|
||||
and hasattr(obj, "_asdict")
|
||||
and hasattr(obj, "_fields")
|
||||
)
|
||||
|
||||
if is_namedtuple(metadata):
|
||||
return metadata._asdict()
|
||||
else:
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _deserialize_metadata(metadata):
|
||||
if isinstance(metadata, dict):
|
||||
return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))(
|
||||
**metadata
|
||||
)
|
||||
else:
|
||||
return metadata
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
kernel = self.kernel
|
||||
# replace the fields that don't pickle nicely
|
||||
kernel_state = {
|
||||
**kernel.__dict__,
|
||||
"metadata": kernel.metadata._asdict(),
|
||||
# See doc about serializing metadata above
|
||||
"metadata": self._serialize_metadata(kernel.metadata),
|
||||
"packed_metadata": self._serialize_metadata(
|
||||
getattr(kernel, "packed_metadata", None)
|
||||
),
|
||||
"module": None, # regenerated by kernel._init_handles()
|
||||
"function": None, # regenerated by kernel._init_handles()
|
||||
"run": None, # regenerated by kernel._init_handles()
|
||||
|
|
@ -975,13 +1034,13 @@ class TritonCompileResult:
|
|||
# TODO(jansel): need to fixup src.fn which is now None
|
||||
kernel = CompiledKernel.__new__(CompiledKernel)
|
||||
metadata = state["kernel"]["metadata"]
|
||||
packed_metadata = state["kernel"]["packed_metadata"]
|
||||
kernel.__dict__.update(
|
||||
{
|
||||
**state["kernel"],
|
||||
# "src": src,
|
||||
"metadata": self._kernel_metadata_cls(tuple(metadata.keys()))(
|
||||
**metadata
|
||||
),
|
||||
"metadata": self._deserialize_metadata(metadata),
|
||||
"packed_metadata": self._deserialize_metadata(packed_metadata),
|
||||
}
|
||||
)
|
||||
self.__dict__.update(state)
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ import sympy
|
|||
import torch
|
||||
import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools
|
||||
from torch._dynamo.utils import counters, dynamo_timed
|
||||
from torch._inductor.codecache import PyCodeCache, TritonFuture
|
||||
from torch._inductor.codecache import LambdaFuture, PyCodeCache
|
||||
from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
|
||||
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
|
@ -2734,7 +2734,7 @@ class Scheduler:
|
|||
|
||||
def compile_kernel(
|
||||
nodes: Sequence[BaseSchedulerNode],
|
||||
) -> tuple[Optional[TritonFuture], ModuleType]:
|
||||
) -> tuple[Optional[LambdaFuture], ModuleType]:
|
||||
src_code = self.generate_kernel_code_from_nodes(
|
||||
nodes, benchmark_kernel=True
|
||||
)
|
||||
|
|
@ -2743,7 +2743,7 @@ class Scheduler:
|
|||
fut = None
|
||||
else:
|
||||
fut = async_compile.triton(kernel_name="triton_", source_code=src_code)
|
||||
assert isinstance(fut, TritonFuture)
|
||||
assert isinstance(fut, LambdaFuture)
|
||||
|
||||
return (fut, mod)
|
||||
|
||||
|
|
@ -2772,7 +2772,7 @@ class Scheduler:
|
|||
)
|
||||
|
||||
# Start compiling choices in parallel
|
||||
future_choices: List[tuple[Any, Optional[TritonFuture], ModuleType]] = []
|
||||
future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
|
||||
triton_choices = 0
|
||||
for choice, unfused_time in sorted(
|
||||
choice_timings.items(), key=lambda x: x[1]
|
||||
|
|
|
|||
Loading…
Reference in a new issue