From b8c5c4443c1cdd6abe88c5372ed17385bf0b64ea Mon Sep 17 00:00:00 2001 From: James Wu Date: Sat, 8 Feb 2025 20:15:28 -0800 Subject: [PATCH] Only call triton in worker process, ahead of time compile Pull Request resolved: https://github.com/pytorch/pytorch/pull/146417 ### Big idea This PR extends https://github.com/pytorch/pytorch/pull/144288 by combining calling triton in worker processes with the future cache: we kick off triton compilation in the worker processes earlier, during inductor codegen. Basically instead of calling async_compile.triton for the first time only after the entire code has been generated, we start compiling as soon as we know we'll need to compile the kernel. Then, when loading the generated inductor code, we can simply read from our in memory future cache, considerably increasing the parallelism. ### Implementation Overview In total, the diff does the following: - Converts TritonFuture to LambdaFuture, only calling triton.compile on worker processes - Now that triton.compile() isn't called on the main process, we call TritonBundler on all compiled kernels when we get them back from workers - Extend @eellison's future cache to a class, mostly as a refactor - Finally, call async_compile.triton ahead of time in Scheduler.codegen if workers are warmed up. This causes the subsequent async_compile.triton call that occurs after codegen to cache hit on cold start. In the diffs after this, I will add more to CompiledTritonKernels so that TritonBundler, on a warm start, automatically populates the in memory cache on warm start with the existing triton kernels, avoiding calling triton altogether on warm starts. Because LambdaFutures are much faster to kick off than TritonFutures, due to not needing to load from TritonCodeCache at all, the time spent kicking off these worker jobs is pretty minimal for inductor codegen. Differential Revision: [D69123174](https://our.internmc.facebook.com/intern/diff/D69123174/) ghstack-source-id: 22f0e5a3bdc9ef6ef461a1f4afc0b2cb0afd0a8e --- test/inductor/test_codecache.py | 4 + test/inductor/test_max_autotune.py | 3 + torch/_inductor/async_compile.py | 133 +++++++++++++++---- torch/_inductor/codecache.py | 42 ++---- torch/_inductor/codegen/triton.py | 11 +- torch/_inductor/runtime/autotune_cache.py | 41 +++++- torch/_inductor/runtime/compile_tasks.py | 7 +- torch/_inductor/runtime/triton_heuristics.py | 83 ++++++++++-- torch/_inductor/scheduler.py | 8 +- 9 files changed, 253 insertions(+), 79 deletions(-) diff --git a/test/inductor/test_codecache.py b/test/inductor/test_codecache.py index adb8319ba6d..4309ad9cfa9 100644 --- a/test/inductor/test_codecache.py +++ b/test/inductor/test_codecache.py @@ -1494,6 +1494,9 @@ class TestAutotuneCache(TestCase): @config.patch({"autotune_remote_cache": True}) @config.patch({"bundled_autotune_remote_cache": False}) @config.patch({"max_autotune": True}) + @config.patch( + {"compile_threads": 1} + ) # Worker processes do not register PatchCaches() properly def test_autotune_cache(self): class Model(torch.nn.Module): def forward(self, x, y, a, b): @@ -1531,6 +1534,7 @@ class TestAutotuneCache(TestCase): @config.patch({"autotune_local_cache": True}) @config.patch({"autotune_remote_cache": False}) @config.patch({"bundled_autotune_remote_cache": True}) + @config.patch({"compile_threads": 1}) @config.patch({"max_autotune": True}) def test_bundled_autotune_remote_cache(self): class Model(torch.nn.Module): diff --git a/test/inductor/test_max_autotune.py b/test/inductor/test_max_autotune.py index 29a8d4d3004..399e6b47819 100644 --- a/test/inductor/test_max_autotune.py +++ b/test/inductor/test_max_autotune.py @@ -1031,6 +1031,9 @@ class TestMaxAutotuneRemoteCache(TestCase): PatchCaches.tearDown() @parametrize("dynamic", (False, True)) + @config.patch( + {"compile_threads": 1} + ) # Worker processes do not register PatchCaches() properly def test_max_autotune_remote_caching(self, dynamic: bool): from unittest.mock import patch diff --git a/torch/_inductor/async_compile.py b/torch/_inductor/async_compile.py index 4815f6f568a..d44bcd9452a 100644 --- a/torch/_inductor/async_compile.py +++ b/torch/_inductor/async_compile.py @@ -11,13 +11,15 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor from concurrent.futures.process import BrokenProcessPool from functools import partial from time import time -from typing import Any, Callable, Optional, TYPE_CHECKING +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING import torch from torch._dynamo.device_interface import get_registered_device_interfaces from torch._dynamo.utils import counters, dynamo_timed, set_feature_use from torch._inductor import config from torch._inductor.codecache import ( + _load_triton_kernel_from_source, + code_hash, CodeCacheFuture, CppCodeCache, CppPythonBindingsCodeCache, @@ -25,8 +27,7 @@ from torch._inductor.codecache import ( HalideCodeCache, LambdaFuture, ROCmCodeCache, - TritonCodeCache, - TritonFuture, + torch_key, ) from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool from torch._inductor.compile_worker.watchdog import _async_compile_initializer @@ -42,6 +43,7 @@ from torch.utils._triton import has_triton_package if TYPE_CHECKING: from torch._inductor.runtime.hints import HalideMeta + from torch._inductor.runtime.triton_heuristics import CachingAutotuner # timing metrics for time spent in the compilation _cumulative_compile_time = 0.0 @@ -130,9 +132,49 @@ def get_compile_threads() -> int: @clear_on_fresh_inductor_cache -@functools.lru_cache(None) -def get_future_cache(): - return {} +class CompiledTritonKernels: + """ + In memory cache for storing compiled triton kernels. + + Each triton kernel is keyed by the hash of its source code. Each value stored + in the cache is a return value of AsyncCompile.triton(). + + Currently, the cache stores Future objects, but it should be generalizable for any kernels. + """ + + _cache: Dict[str, LambdaFuture] = {} + + @staticmethod + def key(kernel_src: str): + """ + Generates a cache key given a triton kernel's full source code. + This source includes the inductor meta, compilation metadata, the kernel itself, etc. + `kernel_src` should be the exact string passed to async_compile.triton()'s first argument. + """ + # Hashes the kernel source with torch_key into a single hash key + return code_hash(kernel_src, extra=torch_key()) + + @staticmethod + def save(kernel_src: str, future: LambdaFuture): + """ + Saves a compiled triton kernel to the cache. + TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton, + but the real type we want to return here is actually an abstract triton kernel. + + TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc. + so it could be less strict. + """ + key = CompiledTritonKernels.key(kernel_src) + CompiledTritonKernels._cache[key] = future + + @staticmethod + def get(kernel_src: str, default: Any) -> LambdaFuture: + key = CompiledTritonKernels.key(kernel_src) + return CompiledTritonKernels._cache.get(key, default) + + @staticmethod + def cache_clear(): + CompiledTritonKernels._cache = {} class AsyncCompile: @@ -208,51 +250,84 @@ class AsyncCompile: ) def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"): + """ + Async_compile.triton is more complicated than the other backends because + we're trying to optimize compile time as much as possible for this hot callsite. + + First of all, the function is cached by CompiledTritonKernels; if there's a kernel + already compiled, we grab it directly from the cache and return. + + Otherwise, if we have multiple compile threads, we kick off triton compilations on each + worker process by giving it a kernel and source code to compile. The worker initializes + a CachingAutotuner, runs triton compilation, and pickles the kernel back to us. + We use TritonCompileResult to represent the objects being pickled back to us by each + worker. + + Some maybe not obvious things that are pickled back to us: + - Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata + and do not have to pay the cost of loading the triton kernel on the parent. But certain + cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function + in the parent lazily when we require it. + - The AutotuneCache, if enabled, is constructed on each worker per triton config + and pickled by to us via `CachingAutotuner.save_cache_hook`. + """ + if future := CompiledTritonKernels.get(source_code, None): + counters["inductor"]["async_compile_cache_hit"] += 1 + return future + + counters["inductor"]["async_compile_cache_miss"] += 1 + kernel_code_log.info("Triton Kernel:\n%s", source_code) _compile_start() - _set_triton_ptxas_path() if os.environ.get("TRITON_INTERPRET", "0") == "1": return getattr( torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name ) - kernel = TritonCodeCache.load(kernel_name, source_code) - if self.use_process_pool(): - set_feature_use("parallel_compile_post_warmup", True) + load_kernel = functools.partial( + _load_triton_kernel_from_source, kernel_name, source_code + ) + is_parallel = self.use_process_pool() + set_feature_use("parallel_compile_post_warmup", is_parallel) + if is_parallel: # We want to support changing these env vars after (and while) the # process pool is running, so pass them to the subprocess to reset. env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"] extra_env = {v: os.environ[v] for v in env_vars if v in os.environ} - future_cache = get_future_cache() - - if future := future_cache.get(source_code, None): - counters["inductor"]["async_compile_cache_hit"] += 1 - return future - - counters["inductor"]["async_compile_cache_miss"] += 1 - future = TritonFuture( - kernel, - self.process_pool().submit( - _worker_compile_triton, - kernel._reload_in_subproc, - extra_env, - ), + task = self.process_pool().submit( + _worker_compile_triton, + load_kernel, + extra_env, ) - future_cache[source_code] = future - return future + def reload_kernel_in_parent(): + # Benchmark how often this happens + with dynamo_timed("reload_kernel_in_parent"): + return load_kernel() + + def get_result() -> CachingAutotuner: + kernel = task.result() + kernel.precompile( + warm_cache_only=False, reload_kernel=reload_kernel_in_parent + ) + return kernel + + future = LambdaFuture(get_result, future=task) + CompiledTritonKernels.save(source_code, future) + return future else: - set_feature_use("parallel_compile_post_warmup", False) with dynamo_timed( "async_compile.precompile", log_pt2_compile_event=True, dynamo_compile_column_us="triton_compile_time_us", log_waitcounter=True, ): - kernel.precompile() - return kernel + _set_triton_ptxas_path() + kernel = load_kernel() + kernel.precompile(warm_cache_only=False) + return kernel def multi_kernel(self, *args, **kwargs) -> Any: from torch._inductor.codegen.multi_kernel import MultiKernelCall diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index 6a3474987d3..c9eb117f856 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -68,7 +68,6 @@ from torch._inductor.cpu_vec_isa import pick_vec_isa from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param from torch._inductor.runtime.compile_tasks import ( - _module_to_triton_kernel, _reload_python_module, _reload_python_module_in_subproc, ) @@ -358,10 +357,11 @@ def sha256_hash(data: bytes) -> str: return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower() -def code_hash(code: Union[str, bytes], extra: str = "") -> str: +def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str: hashing_str = code if isinstance(code, bytes) else code.encode("utf-8") - if extra != "": - hashing_str = hashing_str + b"||" + extra.encode("utf-8") + if extra: + extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8") + hashing_str = hashing_str + b"||" + extra_b return "c" + sha256_hash(hashing_str) @@ -2815,10 +2815,10 @@ class PyCodeCache: return parse_stack_trace(entry) -class TritonCodeCache: - @classmethod - def load(cls, kernel_name: str, source_code: str) -> ModuleType: - return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name) +def _load_triton_kernel_from_source( + kernel_name: str, source_code: str +) -> CachingAutotuner: + return getattr(PyCodeCache.load(source_code), kernel_name) def _cuda_compiler() -> Optional[str]: @@ -3222,30 +3222,12 @@ class CodeCacheFuture: raise NotImplementedError -class TritonFuture(CodeCacheFuture): - kernel: CachingAutotuner - - def __init__( - self, - kernel: Any, - future: Optional[Future[Any]], - ) -> None: - self.kernel = kernel - self.future = future - - def result(self) -> Callable[..., Any]: - if self.future is not None: - # If the worker failed this will throw an exception. - result = self.future.result() - assert result is None - self.future = None - self.kernel.precompile() - return self.kernel - - class LambdaFuture(CodeCacheFuture): - def __init__(self, result_fn: Callable[..., Any]) -> None: + def __init__( + self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None + ) -> None: self.result_fn = result_fn + self.future = future def result(self) -> Callable[..., Any]: # type: ignore[override] return self.result_fn() diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 3b30cd40a09..64244a32df1 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -29,6 +29,7 @@ from torch.utils._triton import has_triton_package from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT from ...utils._sympy.value_ranges import ValueRanges from .. import config, ir, metrics +from ..async_compile import AsyncCompile from ..codecache import code_hash, get_path, PyCodeCache from ..runtime.benchmarking import benchmarker from ..runtime.hints import ( @@ -110,6 +111,7 @@ log = logging.getLogger(__name__) perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints") schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") fusion_log = torch._logging.getArtifactLogger(__name__, "fusion") +async_compile = AsyncCompile() class OpDtypeSupport: @@ -3939,9 +3941,16 @@ class TritonScheduling(SIMDScheduling): src_code = src_code.replace("#pragma CMT", "#") _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py") - compile_wrapper = IndentedBuffer() + + if async_compile.use_process_pool(): + # The process pool is warm, we can shell out to workers right away. This + # allows us to save the result in async_compile.CompiledTritonKernels, + # so that the second time we call async_compile.triton, we do no work. + async_compile.triton(subs_name, src_code) + compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''") + compile_wrapper.splice(src_code, strip=True) current_device = V.graph.get_current_device_or_throw() compile_wrapper.writeline(f"''', device_str='{current_device.type}')") diff --git a/torch/_inductor/runtime/autotune_cache.py b/torch/_inductor/runtime/autotune_cache.py index dabe299fa2c..0c098f6afa4 100644 --- a/torch/_inductor/runtime/autotune_cache.py +++ b/torch/_inductor/runtime/autotune_cache.py @@ -6,7 +6,7 @@ import logging import os import os.path import re -from typing import Optional, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from typing_extensions import override import torch @@ -154,8 +154,47 @@ class AutotuneCache: if not remote_cache: return + # Save the args passed to create_cache + # in case AutotuneCache needs to be pickled + self.remote_cache_full_key = key + self.is_fbcode = is_fbcode self.remote_cache = (remote_cache, cache_key) + # The AutotuneCache may be serialized/deserialized if we're using + # AsyncCompile worker processes to run triton compilation. + # This is because AutotuneCache instances are created on the worker + # process, but we need to run AutotuneCache.save on the parent process + # when actually doing autotuning. + def __getstate__(self) -> dict[str, Any]: + # The remote cache handles themselves may not be serializable + # So clear it and reconstruct it on setstate + remote_cache = getattr(self, "remote_cache", None) + return { + **self.__dict__, + # Save the cache_key portion + "remote_cache": remote_cache and remote_cache[1], + } + + def __setstate__(self, state: dict[str, Any]) -> None: + # Reconstruct the remote cache on the parent class + self.__dict__.update(state) + if self.remote_cache is not None: + assert isinstance(self.remote_cache, str) + assert hasattr(self, "remote_cache_full_key") + assert hasattr(self, "is_fbcode") + cache_key = self.remote_cache + remote_cache = create_cache( + self.remote_cache_full_key, + self.is_fbcode, + "FbRemoteAutotuneCache", + "RemoteAutotuneCache", + ) + if remote_cache is not None: + self.remote_cache = (remote_cache, cache_key) + else: + log.warning("Warning, failed to recreate remote cache after pickling") + self.remote_cache = None + # Save the config in the caches def save( self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False diff --git a/torch/_inductor/runtime/compile_tasks.py b/torch/_inductor/runtime/compile_tasks.py index 8e1f8659adf..b15fc568e0b 100644 --- a/torch/_inductor/runtime/compile_tasks.py +++ b/torch/_inductor/runtime/compile_tasks.py @@ -68,7 +68,10 @@ def _set_triton_ptxas_path() -> None: def _worker_compile_triton( load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str] -) -> None: +) -> CachingAutotuner: _set_triton_ptxas_path() os.environ.update(extra_env) - load_kernel().precompile(warm_cache_only=True) + kernel = load_kernel() + kernel.precompile(warm_cache_only=True) + kernel.prepare_for_pickle() + return kernel diff --git a/torch/_inductor/runtime/triton_heuristics.py b/torch/_inductor/runtime/triton_heuristics.py index c920be4813b..487610b81f5 100644 --- a/torch/_inductor/runtime/triton_heuristics.py +++ b/torch/_inductor/runtime/triton_heuristics.py @@ -256,18 +256,29 @@ class CachingAutotuner(KernelInterface): def precompile( self, warm_cache_only=False, - reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None, + reload_kernel: Optional[Callable[[], CachingAutotuner]] = None, ): if warm_cache_only: self._precompile_worker() return with self.lock: + # Helper function for reloading a kernel generated in a worker + # in the parent class. Normally we don't need to reload the kernel + # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock), + # we need to actually run compilation on the parent process + if reload_kernel is not None: + self._reload_kernel = reload_kernel self._precompile_worker() self._make_launchers() - self._dynamic_scale_rblock(reload_in_parent) + self._dynamic_scale_rblock() def _precompile_worker(self): if self.compile_results: + for result in self.compile_results: + TritonBundler.put( + triton_hash_to_path_key(result.kernel.hash), + self.triton_meta.get("device", 0), + ) return assert not self.launchers if not self.configs: @@ -285,9 +296,7 @@ class CachingAutotuner(KernelInterface): self.compile_results = compile_results self.configs = None - def _dynamic_scale_rblock( - self, reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None - ): + def _dynamic_scale_rblock(self): # TODO(jansel): we should find a way to move this extra compile into the worker process # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg. device_prop = self.device_props @@ -392,8 +401,9 @@ class CachingAutotuner(KernelInterface): and the fn was dropped in prepare_for_pickle(). We haven't loaded the module containing the real fn yet. """ - assert reload_in_parent - self.fn = reload_in_parent().fn + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn self.compile_results.append(self._precompile_config(new_config)) self._make_launchers() @@ -415,6 +425,7 @@ class CachingAutotuner(KernelInterface): for result in self.compile_results: try: launchers.append(result.make_launcher()) + except (OutOfResources, PTXASError) as e: exc = e if len(launchers) == 0: @@ -519,7 +530,6 @@ class CachingAutotuner(KernelInterface): compile_meta, ) raise - TritonBundler.put( triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0) ) @@ -819,6 +829,17 @@ class CachingAutotuner(KernelInterface): config2launcher = {launcher.config: launcher} + # TODO: should we just load the kernels ahead of time if we know we're going to call this? + if self.fn.fn is None: + """ + We are in the parent process, while this program was compiled in a worker + and the fn was dropped in prepare_for_pickle(). We haven't loaded the module + containing the real fn yet. + """ + assert hasattr(self, "_reload_kernel") + assert callable(self._reload_kernel) + self.fn = self._reload_kernel().fn + def benchmark_one_config(config): with self.lock: launcher = self._precompile_config(config).make_launcher() @@ -957,12 +978,50 @@ class TritonCompileResult: self.compile_meta = compile_meta self.inductor_meta = inductor_meta + @staticmethod + def _serialize_metadata(metadata): + """ + Triton uses a nested class called KernelMetadata to store metadata information. + Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear + in the toplevel namespace of the module. So these serialization/deser functions + are used to convert the namedtuples to a dict and back. + + As for packed_metadata, depending on the triton backend, KernelMetadata can be + a namedtuple, or a regular tuple! So the serialization function branches on whether + the metadata to be serialized is a namedtuple or regular, serializable one. + """ + + def is_namedtuple(obj) -> bool: + return ( + isinstance(obj, tuple) + and hasattr(obj, "_asdict") + and hasattr(obj, "_fields") + ) + + if is_namedtuple(metadata): + return metadata._asdict() + else: + return metadata + + @staticmethod + def _deserialize_metadata(metadata): + if isinstance(metadata, dict): + return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))( + **metadata + ) + else: + return metadata + def __getstate__(self) -> dict[str, Any]: kernel = self.kernel # replace the fields that don't pickle nicely kernel_state = { **kernel.__dict__, - "metadata": kernel.metadata._asdict(), + # See doc about serializing metadata above + "metadata": self._serialize_metadata(kernel.metadata), + "packed_metadata": self._serialize_metadata( + getattr(kernel, "packed_metadata", None) + ), "module": None, # regenerated by kernel._init_handles() "function": None, # regenerated by kernel._init_handles() "run": None, # regenerated by kernel._init_handles() @@ -975,13 +1034,13 @@ class TritonCompileResult: # TODO(jansel): need to fixup src.fn which is now None kernel = CompiledKernel.__new__(CompiledKernel) metadata = state["kernel"]["metadata"] + packed_metadata = state["kernel"]["packed_metadata"] kernel.__dict__.update( { **state["kernel"], # "src": src, - "metadata": self._kernel_metadata_cls(tuple(metadata.keys()))( - **metadata - ), + "metadata": self._deserialize_metadata(metadata), + "packed_metadata": self._deserialize_metadata(packed_metadata), } ) self.__dict__.update(state) diff --git a/torch/_inductor/scheduler.py b/torch/_inductor/scheduler.py index ee272adf557..fdbfd58fc5a 100644 --- a/torch/_inductor/scheduler.py +++ b/torch/_inductor/scheduler.py @@ -36,7 +36,7 @@ import sympy import torch import torch._inductor.async_compile # noqa: F401 required to warm up AsyncCompile pools from torch._dynamo.utils import counters, dynamo_timed -from torch._inductor.codecache import PyCodeCache, TritonFuture +from torch._inductor.codecache import LambdaFuture, PyCodeCache from torch._inductor.metrics import get_metric_table, is_metric_table_enabled from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols from torch.utils._ordered_set import OrderedSet @@ -2734,7 +2734,7 @@ class Scheduler: def compile_kernel( nodes: Sequence[BaseSchedulerNode], - ) -> tuple[Optional[TritonFuture], ModuleType]: + ) -> tuple[Optional[LambdaFuture], ModuleType]: src_code = self.generate_kernel_code_from_nodes( nodes, benchmark_kernel=True ) @@ -2743,7 +2743,7 @@ class Scheduler: fut = None else: fut = async_compile.triton(kernel_name="triton_", source_code=src_code) - assert isinstance(fut, TritonFuture) + assert isinstance(fut, LambdaFuture) return (fut, mod) @@ -2772,7 +2772,7 @@ class Scheduler: ) # Start compiling choices in parallel - future_choices: List[tuple[Any, Optional[TritonFuture], ModuleType]] = [] + future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = [] triton_choices = 0 for choice, unfused_time in sorted( choice_timings.items(), key=lambda x: x[1]