Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)

Commit 99af07f2f3 (parent: 7c8ec84dab)

    Update (base update)

    [ghstack-poisoned]

9 changed files with 253 additions and 79 deletions
@@ -1494,6 +1494,9 @@ class TestAutotuneCache(TestCase):
     @config.patch({"autotune_remote_cache": True})
     @config.patch({"bundled_autotune_remote_cache": False})
     @config.patch({"max_autotune": True})
+    @config.patch(
+        {"compile_threads": 1}
+    )  # Worker processes do not register PatchCaches() properly
     def test_autotune_cache(self):
         class Model(torch.nn.Module):
             def forward(self, x, y, a, b):

@@ -1531,6 +1534,7 @@ class TestAutotuneCache(TestCase):
     @config.patch({"autotune_local_cache": True})
     @config.patch({"autotune_remote_cache": False})
     @config.patch({"bundled_autotune_remote_cache": True})
+    @config.patch({"compile_threads": 1})
     @config.patch({"max_autotune": True})
     def test_bundled_autotune_remote_cache(self):
         class Model(torch.nn.Module):

@@ -1031,6 +1031,9 @@ class TestMaxAutotuneRemoteCache(TestCase):
         PatchCaches.tearDown()

     @parametrize("dynamic", (False, True))
+    @config.patch(
+        {"compile_threads": 1}
+    )  # Worker processes do not register PatchCaches() properly
     def test_max_autotune_remote_caching(self, dynamic: bool):
         from unittest.mock import patch

@@ -11,13 +11,15 @@ from concurrent.futures import Future, ProcessPoolExecutor, ThreadPoolExecutor
 from concurrent.futures.process import BrokenProcessPool
 from functools import partial
 from time import time
-from typing import Any, Callable, Optional, TYPE_CHECKING
+from typing import Any, Callable, Dict, Optional, TYPE_CHECKING

 import torch
 from torch._dynamo.device_interface import get_registered_device_interfaces
 from torch._dynamo.utils import counters, dynamo_timed, set_feature_use
 from torch._inductor import config
 from torch._inductor.codecache import (
+    _load_triton_kernel_from_source,
+    code_hash,
     CodeCacheFuture,
     CppCodeCache,
     CppPythonBindingsCodeCache,

@@ -25,8 +27,7 @@ from torch._inductor.codecache import (
     HalideCodeCache,
     LambdaFuture,
     ROCmCodeCache,
-    TritonCodeCache,
-    TritonFuture,
+    torch_key,
 )
 from torch._inductor.compile_worker.subproc_pool import AnyPool, SubprocPool
 from torch._inductor.compile_worker.watchdog import _async_compile_initializer

@@ -42,6 +43,7 @@ from torch.utils._triton import has_triton_package

 if TYPE_CHECKING:
     from torch._inductor.runtime.hints import HalideMeta
+    from torch._inductor.runtime.triton_heuristics import CachingAutotuner

 # timing metrics for time spent in the compilation
 _cumulative_compile_time = 0.0

@@ -130,9 +132,49 @@ def get_compile_threads() -> int:


-@clear_on_fresh_inductor_cache
-@functools.lru_cache(None)
-def get_future_cache():
-    return {}
+class CompiledTritonKernels:
+    """
+    In memory cache for storing compiled triton kernels.
+
+    Each triton kernel is keyed by the hash of its source code. Each value stored
+    in the cache is a return value of AsyncCompile.triton().
+
+    Currently, the cache stores Future objects, but it should be generalizable for any kernels.
+    """
+
+    _cache: Dict[str, LambdaFuture] = {}
+
+    @staticmethod
+    def key(kernel_src: str):
+        """
+        Generates a cache key given a triton kernel's full source code.
+        This source includes the inductor meta, compilation metadata, the kernel itself, etc.
+        `kernel_src` should be the exact string passed to async_compile.triton()'s first argument.
+        """
+        # Hashes the kernel source with torch_key into a single hash key
+        return code_hash(kernel_src, extra=torch_key())
+
+    @staticmethod
+    def save(kernel_src: str, future: LambdaFuture):
+        """
+        Saves a compiled triton kernel to the cache.
+        TODO: We store a LambdaFuture as that's the callable returned by async_compile.triton,
+        but the real type we want to return here is actually an abstract triton kernel.
+
+        TODO: Source code here is not just the kernel's source code, but also includes the inductor preamble, etc.
+        so it could be less strict.
+        """
+        key = CompiledTritonKernels.key(kernel_src)
+        CompiledTritonKernels._cache[key] = future
+
+    @staticmethod
+    def get(kernel_src: str, default: Any) -> LambdaFuture:
+        key = CompiledTritonKernels.key(kernel_src)
+        return CompiledTritonKernels._cache.get(key, default)
+
+    @staticmethod
+    def cache_clear():
+        CompiledTritonKernels._cache = {}


 class AsyncCompile:

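The docstrings above describe an in-memory cache of compile futures keyed by a hash of the full kernel source plus a version salt (torch_key). A minimal standalone sketch of that pattern, with hypothetical names (compile_async, _VERSION_SALT) rather than the real inductor internals:

import hashlib
from concurrent.futures import Future, ThreadPoolExecutor

_VERSION_SALT = b"torch-key-stand-in"  # stands in for torch_key() in the real code
_cache: dict[str, Future] = {}
_pool = ThreadPoolExecutor(max_workers=2)

def _key(kernel_src: str) -> str:
    return hashlib.sha256(kernel_src.encode("utf-8") + b"||" + _VERSION_SALT).hexdigest()

def compile_async(kernel_src: str) -> Future:
    key = _key(kernel_src)
    if key in _cache:  # cache hit: reuse the in-flight or finished future
        return _cache[key]
    fut = _pool.submit(len, kernel_src)  # placeholder for the real triton compilation
    _cache[key] = fut
    return fut

assert compile_async("kernel source") is compile_async("kernel source")

Keying on the full source (including the inductor preamble) is stricter than necessary, as the TODO in the diff notes, but it guarantees that identical generated code maps to the same future.
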
@@ -208,51 +250,84 @@ class AsyncCompile:
         )

     def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
+        """
+        Async_compile.triton is more complicated than the other backends because
+        we're trying to optimize compile time as much as possible for this hot callsite.
+
+        First of all, the function is cached by CompiledTritonKernels; if there's a kernel
+        already compiled, we grab it directly from the cache and return.
+
+        Otherwise, if we have multiple compile threads, we kick off triton compilations on each
+        worker process by giving it a kernel and source code to compile. The worker initializes
+        a CachingAutotuner, runs triton compilation, and pickles the kernel back to us.
+        We use TritonCompileResult to represent the objects being pickled back to us by each
+        worker.
+
+        Some maybe not obvious things that are pickled back to us:
+        - Most of the time, we can avoid sending back CachingAutotuner.fn and other metadata
+        and do not have to pay the cost of loading the triton kernel on the parent. But certain
+        cases, like coordesc tuning and dynamic_scale_rblock, require us to reload the function
+        in the parent lazily when we require it.
+        - The AutotuneCache, if enabled, is constructed on each worker per triton config
+        and pickled back to us via `CachingAutotuner.save_cache_hook`.
+        """
+        if future := CompiledTritonKernels.get(source_code, None):
+            counters["inductor"]["async_compile_cache_hit"] += 1
+            return future
+
+        counters["inductor"]["async_compile_cache_miss"] += 1
+
         kernel_code_log.info("Triton Kernel:\n%s", source_code)
         _compile_start()
-        _set_triton_ptxas_path()

         if os.environ.get("TRITON_INTERPRET", "0") == "1":
             return getattr(
                 torch._inductor.codecache.PyCodeCache.load(source_code), kernel_name
             )

-        kernel = TritonCodeCache.load(kernel_name, source_code)
-        if self.use_process_pool():
-            set_feature_use("parallel_compile_post_warmup", True)
+        load_kernel = functools.partial(
+            _load_triton_kernel_from_source, kernel_name, source_code
+        )
+        is_parallel = self.use_process_pool()
+        set_feature_use("parallel_compile_post_warmup", is_parallel)
+        if is_parallel:
             # We want to support changing these env vars after (and while) the
             # process pool is running, so pass them to the subprocess to reset.
             env_vars = ["TORCHINDUCTOR_CACHE_DIR", "TRITON_CACHE_DIR"]
             extra_env = {v: os.environ[v] for v in env_vars if v in os.environ}

-            future_cache = get_future_cache()
-
-            if future := future_cache.get(source_code, None):
-                counters["inductor"]["async_compile_cache_hit"] += 1
-                return future
-
-            counters["inductor"]["async_compile_cache_miss"] += 1
-            future = TritonFuture(
-                kernel,
-                self.process_pool().submit(
-                    _worker_compile_triton,
-                    kernel._reload_in_subproc,
-                    extra_env,
-                ),
+            task = self.process_pool().submit(
+                _worker_compile_triton,
+                load_kernel,
+                extra_env,
             )
-            future_cache[source_code] = future
-            return future
+
+            def reload_kernel_in_parent():
+                # Benchmark how often this happens
+                with dynamo_timed("reload_kernel_in_parent"):
+                    return load_kernel()
+
+            def get_result() -> CachingAutotuner:
+                kernel = task.result()
+                kernel.precompile(
+                    warm_cache_only=False, reload_kernel=reload_kernel_in_parent
+                )
+                return kernel
+
+            future = LambdaFuture(get_result, future=task)
+            CompiledTritonKernels.save(source_code, future)
+            return future
         else:
-            set_feature_use("parallel_compile_post_warmup", False)
             with dynamo_timed(
                 "async_compile.precompile",
                 log_pt2_compile_event=True,
                 dynamo_compile_column_us="triton_compile_time_us",
                 log_waitcounter=True,
             ):
-                kernel.precompile()
-            return kernel
+                _set_triton_ptxas_path()
+                kernel = load_kernel()
+                kernel.precompile(warm_cache_only=False)
+                return kernel

     def multi_kernel(self, *args, **kwargs) -> Any:
         from torch._inductor.codegen.multi_kernel import MultiKernelCall

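The new docstring outlines the parallel path: a worker loads and compiles the kernel, strips unpicklable state, and pickles it back; the parent wraps the pending task in a LambdaFuture and only reloads the heavy function lazily if a later step needs it. A simplified, self-contained sketch of that shape (ToyKernel, ToyLambdaFuture, worker_compile, and compile_async are hypothetical stand-ins, not the real inductor classes):

from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass
from typing import Callable, Optional

@dataclass
class ToyKernel:
    src: str
    fn: Optional[Callable] = None  # dropped before pickling, like prepare_for_pickle()

def worker_compile(src: str) -> ToyKernel:
    # pretend expensive triton compilation happens here, then pickle the result back
    return ToyKernel(src=src, fn=None)

class ToyLambdaFuture:
    def __init__(self, result_fn: Callable[[], ToyKernel]) -> None:
        self.result_fn = result_fn

    def result(self) -> ToyKernel:
        return self.result_fn()

def compile_async(pool: ProcessPoolExecutor, src: str) -> ToyLambdaFuture:
    task = pool.submit(worker_compile, src)

    def get_result() -> ToyKernel:
        kernel = task.result()  # unpickled kernel, without its heavy fn
        # the real code would call kernel.precompile(..., reload_kernel=...) here
        return kernel

    return ToyLambdaFuture(get_result)

if __name__ == "__main__":
    with ProcessPoolExecutor(max_workers=1) as pool:
        fut = compile_async(pool, "def triton_(...): ...")
        print(fut.result().src)
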
@@ -68,7 +68,6 @@ from torch._inductor.cpu_vec_isa import pick_vec_isa
 from torch._inductor.custom_graph_pass import CustomGraphPass, CustomGraphPassType
 from torch._inductor.freezing_utils import has_frozen_params, is_frozen_param
 from torch._inductor.runtime.compile_tasks import (
-    _module_to_triton_kernel,
     _reload_python_module,
     _reload_python_module_in_subproc,
 )

@@ -358,10 +357,11 @@ def sha256_hash(data: bytes) -> str:
     return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()


-def code_hash(code: Union[str, bytes], extra: str = "") -> str:
+def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
     hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
-    if extra != "":
-        hashing_str = hashing_str + b"||" + extra.encode("utf-8")
+    if extra:
+        extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
+        hashing_str = hashing_str + b"||" + extra_b
     return "c" + sha256_hash(hashing_str)

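The hunk above widens the `extra` salt of code_hash to accept bytes as well as str (torch_key() returns bytes). A standalone copy of the new behavior, shown only for illustration:

import base64
import hashlib
from typing import Union

def sha256_hash(data: bytes) -> str:
    return base64.b32encode(hashlib.sha256(data).digest())[:51].decode("utf-8").lower()

def code_hash(code: Union[str, bytes], extra: Union[str, bytes] = "") -> str:
    hashing_str = code if isinstance(code, bytes) else code.encode("utf-8")
    if extra:
        extra_b = extra if isinstance(extra, bytes) else extra.encode("utf-8")
        hashing_str = hashing_str + b"||" + extra_b
    return "c" + sha256_hash(hashing_str)

# str and bytes salts now hash identically, so torch_key() (bytes) can be passed directly
assert code_hash("kernel src", "salt") == code_hash(b"kernel src", b"salt")
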
@@ -2815,10 +2815,10 @@ class PyCodeCache:
         return parse_stack_trace(entry)


-class TritonCodeCache:
-    @classmethod
-    def load(cls, kernel_name: str, source_code: str) -> ModuleType:
-        return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
+def _load_triton_kernel_from_source(
+    kernel_name: str, source_code: str
+) -> CachingAutotuner:
+    return getattr(PyCodeCache.load(source_code), kernel_name)


 def _cuda_compiler() -> Optional[str]:

@@ -3222,30 +3222,12 @@ class CodeCacheFuture:
         raise NotImplementedError


-class TritonFuture(CodeCacheFuture):
-    kernel: CachingAutotuner
-
-    def __init__(
-        self,
-        kernel: Any,
-        future: Optional[Future[Any]],
-    ) -> None:
-        self.kernel = kernel
-        self.future = future
-
-    def result(self) -> Callable[..., Any]:
-        if self.future is not None:
-            # If the worker failed this will throw an exception.
-            result = self.future.result()
-            assert result is None
-            self.future = None
-            self.kernel.precompile()
-        return self.kernel
-
-
 class LambdaFuture(CodeCacheFuture):
-    def __init__(self, result_fn: Callable[..., Any]) -> None:
+    def __init__(
+        self, result_fn: Callable[..., Any], future: Optional[Future[Any]] = None
+    ) -> None:
         self.result_fn = result_fn
+        self.future = future

     def result(self) -> Callable[..., Any]:  # type: ignore[override]
         return self.result_fn()

@@ -29,6 +29,7 @@ from torch.utils._triton import has_triton_package
 from ...utils._sympy.symbol import free_symbol_is_type, prefix_str, symbol_is_type, SymT
 from ...utils._sympy.value_ranges import ValueRanges
 from .. import config, ir, metrics
+from ..async_compile import AsyncCompile
 from ..codecache import code_hash, get_path, PyCodeCache
 from ..runtime.benchmarking import benchmarker
 from ..runtime.hints import (

@@ -110,6 +111,7 @@ log = logging.getLogger(__name__)
 perf_hint_log = torch._logging.getArtifactLogger(__name__, "perf_hints")
 schedule_log = torch._logging.getArtifactLogger(__name__, "schedule")
 fusion_log = torch._logging.getArtifactLogger(__name__, "fusion")
+async_compile = AsyncCompile()


 class OpDtypeSupport:

@@ -3939,9 +3941,16 @@ class TritonScheduling(SIMDScheduling):
         src_code = src_code.replace("#pragma CMT", "#")

         _basename, _, kernel_path = get_path(code_hash(src_code.strip()), "py")

         compile_wrapper = IndentedBuffer()
+
+        if async_compile.use_process_pool():
+            # The process pool is warm, we can shell out to workers right away. This
+            # allows us to save the result in async_compile.CompiledTritonKernels,
+            # so that the second time we call async_compile.triton, we do no work.
+            async_compile.triton(subs_name, src_code)
+
         compile_wrapper.writeline(f"async_compile.triton({subs_name!r}, '''")
         compile_wrapper.splice(src_code, strip=True)
         current_device = V.graph.get_current_device_or_throw()
         compile_wrapper.writeline(f"''', device_str='{current_device.type}')")

@@ -6,7 +6,7 @@ import logging
 import os
 import os.path
 import re
-from typing import Optional, TYPE_CHECKING
+from typing import Any, Optional, TYPE_CHECKING
 from typing_extensions import override

 import torch

@@ -154,8 +154,47 @@ class AutotuneCache:
         if not remote_cache:
             return

+        # Save the args passed to create_cache
+        # in case AutotuneCache needs to be pickled
+        self.remote_cache_full_key = key
+        self.is_fbcode = is_fbcode
         self.remote_cache = (remote_cache, cache_key)

+    # The AutotuneCache may be serialized/deserialized if we're using
+    # AsyncCompile worker processes to run triton compilation.
+    # This is because AutotuneCache instances are created on the worker
+    # process, but we need to run AutotuneCache.save on the parent process
+    # when actually doing autotuning.
+    def __getstate__(self) -> dict[str, Any]:
+        # The remote cache handles themselves may not be serializable
+        # So clear it and reconstruct it on setstate
+        remote_cache = getattr(self, "remote_cache", None)
+        return {
+            **self.__dict__,
+            # Save the cache_key portion
+            "remote_cache": remote_cache and remote_cache[1],
+        }
+
+    def __setstate__(self, state: dict[str, Any]) -> None:
+        # Reconstruct the remote cache on the parent class
+        self.__dict__.update(state)
+        if self.remote_cache is not None:
+            assert isinstance(self.remote_cache, str)
+            assert hasattr(self, "remote_cache_full_key")
+            assert hasattr(self, "is_fbcode")
+            cache_key = self.remote_cache
+            remote_cache = create_cache(
+                self.remote_cache_full_key,
+                self.is_fbcode,
+                "FbRemoteAutotuneCache",
+                "RemoteAutotuneCache",
+            )
+            if remote_cache is not None:
+                self.remote_cache = (remote_cache, cache_key)
+            else:
+                log.warning("Warning, failed to recreate remote cache after pickling")
+                self.remote_cache = None
+
     # Save the config in the caches
     def save(
         self, config: Config, time_taken_ns: int, found_by_coordesc: bool = False

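The comments above explain that AutotuneCache instances are created on AsyncCompile workers but saved on the parent, so they must survive pickling even though the remote-cache handle itself may not be picklable. The general pattern, shown as a standalone sketch with hypothetical names (Cache, FakeRemoteHandle, create_handle): drop the handle in __getstate__ but keep the key needed to rebuild it, then reconstruct it in __setstate__:

import pickle
from typing import Any, Optional

class FakeRemoteHandle:
    """Stands in for a remote cache client that cannot be pickled."""
    def __reduce__(self):
        raise TypeError("handle is not picklable")

def create_handle(key: str) -> FakeRemoteHandle:
    return FakeRemoteHandle()

class Cache:
    def __init__(self, key: str) -> None:
        self.cache_key = key
        self.handle: Optional[FakeRemoteHandle] = create_handle(key)

    def __getstate__(self) -> dict[str, Any]:
        # drop the unpicklable handle; keep cache_key so it can be rebuilt
        return {**self.__dict__, "handle": None}

    def __setstate__(self, state: dict[str, Any]) -> None:
        self.__dict__.update(state)
        # reconstruct the handle on the receiving side of the pickle boundary
        self.handle = create_handle(self.cache_key)

roundtripped = pickle.loads(pickle.dumps(Cache("my-key")))
assert isinstance(roundtripped.handle, FakeRemoteHandle)
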
@@ -68,7 +68,10 @@ def _set_triton_ptxas_path() -> None:

 def _worker_compile_triton(
     load_kernel: Callable[[], CachingAutotuner], extra_env: dict[str, str]
-) -> None:
+) -> CachingAutotuner:
     _set_triton_ptxas_path()
     os.environ.update(extra_env)
-    load_kernel().precompile(warm_cache_only=True)
+    kernel = load_kernel()
+    kernel.precompile(warm_cache_only=True)
+    kernel.prepare_for_pickle()
+    return kernel

@@ -256,18 +256,29 @@ class CachingAutotuner(KernelInterface):
     def precompile(
         self,
         warm_cache_only=False,
-        reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None,
+        reload_kernel: Optional[Callable[[], CachingAutotuner]] = None,
     ):
         if warm_cache_only:
             self._precompile_worker()
             return
         with self.lock:
+            # Helper function for reloading a kernel generated in a worker
+            # in the parent class. Normally we don't need to reload the kernel
+            # in the parent process, but in certain cases (coordesc tuning, dynamic_scale_rblock),
+            # we need to actually run compilation on the parent process
+            if reload_kernel is not None:
+                self._reload_kernel = reload_kernel
             self._precompile_worker()
             self._make_launchers()
-            self._dynamic_scale_rblock(reload_in_parent)
+            self._dynamic_scale_rblock()

     def _precompile_worker(self):
         if self.compile_results:
+            for result in self.compile_results:
+                TritonBundler.put(
+                    triton_hash_to_path_key(result.kernel.hash),
+                    self.triton_meta.get("device", 0),
+                )
             return
         assert not self.launchers
         if not self.configs:

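The comment added in precompile() describes a lazy-reload hook: the parent remembers a callable that can rebuild the kernel from source and only invokes it when a path such as coordesc tuning or dynamic_scale_rblock actually needs the fn that was dropped before pickling. A hypothetical standalone sketch of that idea (ToyAutotuner and rebuild are illustrative names, not the real classes):

from typing import Callable, Optional

class ToyAutotuner:
    def __init__(self) -> None:
        self.fn: Optional[Callable] = None  # dropped in the worker before pickling
        self._reload_kernel: Optional[Callable[[], "ToyAutotuner"]] = None

    def precompile(self, reload_kernel: Optional[Callable[[], "ToyAutotuner"]] = None) -> None:
        if reload_kernel is not None:
            self._reload_kernel = reload_kernel  # remember how to rebuild, but don't do it yet

    def ensure_fn(self) -> Callable:
        if self.fn is None:  # e.g. coordesc tuning needs the real function
            assert callable(self._reload_kernel)
            self.fn = self._reload_kernel().fn  # reload lazily, exactly when required
        return self.fn

def rebuild() -> ToyAutotuner:
    fresh = ToyAutotuner()
    fresh.fn = lambda: "recompiled in parent"
    return fresh

kernel = ToyAutotuner()
kernel.precompile(reload_kernel=rebuild)
print(kernel.ensure_fn()())  # triggers the lazy reload exactly once
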
@@ -285,9 +296,7 @@ class CachingAutotuner(KernelInterface):
         self.compile_results = compile_results
         self.configs = None

-    def _dynamic_scale_rblock(
-        self, reload_in_parent: Optional[Callable[[], CachingAutotuner]] = None
-    ):
+    def _dynamic_scale_rblock(self):
         # TODO(jansel): we should find a way to move this extra compile into the worker process
         # Currently it relies on _make_launchers(), which requires a cuda context, to populate nreg.
         device_prop = self.device_props

@@ -392,8 +401,9 @@ class CachingAutotuner(KernelInterface):
                 and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
                 containing the real fn yet.
                 """
-                assert reload_in_parent
-                self.fn = reload_in_parent().fn
+                assert hasattr(self, "_reload_kernel")
+                assert callable(self._reload_kernel)
+                self.fn = self._reload_kernel().fn
             self.compile_results.append(self._precompile_config(new_config))

             self._make_launchers()

@@ -415,6 +425,7 @@ class CachingAutotuner(KernelInterface):
         for result in self.compile_results:
             try:
                 launchers.append(result.make_launcher())
+
             except (OutOfResources, PTXASError) as e:
                 exc = e
         if len(launchers) == 0:

@@ -519,7 +530,6 @@ class CachingAutotuner(KernelInterface):
                     compile_meta,
                 )
                 raise
-
         TritonBundler.put(
             triton_hash_to_path_key(binary.hash), self.triton_meta.get("device", 0)
         )

@@ -819,6 +829,17 @@ class CachingAutotuner(KernelInterface):

         config2launcher = {launcher.config: launcher}

+        # TODO: should we just load the kernels ahead of time if we know we're going to call this?
+        if self.fn.fn is None:
+            """
+            We are in the parent process, while this program was compiled in a worker
+            and the fn was dropped in prepare_for_pickle(). We haven't loaded the module
+            containing the real fn yet.
+            """
+            assert hasattr(self, "_reload_kernel")
+            assert callable(self._reload_kernel)
+            self.fn = self._reload_kernel().fn
+
         def benchmark_one_config(config):
             with self.lock:
                 launcher = self._precompile_config(config).make_launcher()

@@ -957,12 +978,50 @@ class TritonCompileResult:
         self.compile_meta = compile_meta
         self.inductor_meta = inductor_meta

+    @staticmethod
+    def _serialize_metadata(metadata):
+        """
+        Triton uses a nested class called KernelMetadata to store metadata information.
+        Pickle does not work well with nested namedtuples, as the namedtuple doesn't appear
+        in the toplevel namespace of the module. So these serialization/deser functions
+        are used to convert the namedtuples to a dict and back.
+
+        As for packed_metadata, depending on the triton backend, KernelMetadata can be
+        a namedtuple, or a regular tuple! So the serialization function branches on whether
+        the metadata to be serialized is a namedtuple or regular, serializable one.
+        """
+
+        def is_namedtuple(obj) -> bool:
+            return (
+                isinstance(obj, tuple)
+                and hasattr(obj, "_asdict")
+                and hasattr(obj, "_fields")
+            )
+
+        if is_namedtuple(metadata):
+            return metadata._asdict()
+        else:
+            return metadata
+
+    @staticmethod
+    def _deserialize_metadata(metadata):
+        if isinstance(metadata, dict):
+            return TritonCompileResult._kernel_metadata_cls(tuple(metadata.keys()))(
+                **metadata
+            )
+        else:
+            return metadata
+
     def __getstate__(self) -> dict[str, Any]:
         kernel = self.kernel
         # replace the fields that don't pickle nicely
         kernel_state = {
             **kernel.__dict__,
-            "metadata": kernel.metadata._asdict(),
+            # See doc about serializing metadata above
+            "metadata": self._serialize_metadata(kernel.metadata),
+            "packed_metadata": self._serialize_metadata(
+                getattr(kernel, "packed_metadata", None)
+            ),
             "module": None,  # regenerated by kernel._init_handles()
             "function": None,  # regenerated by kernel._init_handles()
             "run": None,  # regenerated by kernel._init_handles()

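The docstring above notes that Triton's KernelMetadata is a nested namedtuple (or, for some backends, a plain tuple), which pickles poorly, so it is flattened to a dict and rebuilt from its keys. A standalone sketch of that round trip using a toy namedtuple:

from collections import namedtuple

def is_namedtuple(obj) -> bool:
    return isinstance(obj, tuple) and hasattr(obj, "_asdict") and hasattr(obj, "_fields")

def serialize(metadata):
    # namedtuples become plain dicts; anything else passes through
    return metadata._asdict() if is_namedtuple(metadata) else metadata

def deserialize(metadata):
    if isinstance(metadata, dict):
        # rebuild a namedtuple class from the dict keys, mirroring _kernel_metadata_cls
        return namedtuple("KernelMetadata", tuple(metadata.keys()))(**metadata)
    return metadata

Meta = namedtuple("KernelMetadata", ["name", "num_warps"])
meta = Meta("triton_", 4)
assert deserialize(serialize(meta)) == meta        # namedtuple survives the round trip
assert deserialize(serialize((1, 2))) == (1, 2)    # plain tuples pass through unchanged
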
@@ -975,13 +1034,13 @@ class TritonCompileResult:
         # TODO(jansel): need to fixup src.fn which is now None
         kernel = CompiledKernel.__new__(CompiledKernel)
         metadata = state["kernel"]["metadata"]
+        packed_metadata = state["kernel"]["packed_metadata"]
         kernel.__dict__.update(
             {
                 **state["kernel"],
                 # "src": src,
-                "metadata": self._kernel_metadata_cls(tuple(metadata.keys()))(
-                    **metadata
-                ),
+                "metadata": self._deserialize_metadata(metadata),
+                "packed_metadata": self._deserialize_metadata(packed_metadata),
             }
         )
         self.__dict__.update(state)

@@ -36,7 +36,7 @@ import sympy
 import torch
 import torch._inductor.async_compile  # noqa: F401 required to warm up AsyncCompile pools
 from torch._dynamo.utils import counters, dynamo_timed
-from torch._inductor.codecache import PyCodeCache, TritonFuture
+from torch._inductor.codecache import LambdaFuture, PyCodeCache
 from torch._inductor.metrics import get_metric_table, is_metric_table_enabled
 from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
 from torch.utils._ordered_set import OrderedSet

@@ -2734,7 +2734,7 @@ class Scheduler:
         def compile_kernel(
             nodes: Sequence[BaseSchedulerNode],
-        ) -> tuple[Optional[TritonFuture], ModuleType]:
+        ) -> tuple[Optional[LambdaFuture], ModuleType]:
             src_code = self.generate_kernel_code_from_nodes(
                 nodes, benchmark_kernel=True
             )

@@ -2743,7 +2743,7 @@ class Scheduler:
                 fut = None
             else:
                 fut = async_compile.triton(kernel_name="triton_", source_code=src_code)
-                assert isinstance(fut, TritonFuture)
+                assert isinstance(fut, LambdaFuture)

             return (fut, mod)

@@ -2772,7 +2772,7 @@ class Scheduler:
             )

             # Start compiling choices in parallel
-            future_choices: List[tuple[Any, Optional[TritonFuture], ModuleType]] = []
+            future_choices: List[tuple[Any, Optional[LambdaFuture], ModuleType]] = []
             triton_choices = 0
             for choice, unfused_time in sorted(
                 choice_timings.items(), key=lambda x: x[1]