This commit is contained in:
Jack Taylor 2025-02-06 13:11:03 +00:00 committed by PyTorch MergeBot
parent 93834a5581
commit 92d0a0d9d3
3 changed files with 38 additions and 27 deletions

View file

@@ -1,7 +1,8 @@
from __future__ import annotations
import typing
from typing import Any, TYPE_CHECKING
from functools import partial
from typing import Any, Generator, Optional, TYPE_CHECKING
import sympy
@@ -24,6 +25,9 @@ from .virtualized import V
if TYPE_CHECKING:
from functools import partial
from triton import Config as TritonConfig
from torch.utils._ordered_set import OrderedSet
from .codegen.simd_kernel_features import SIMDKernelFeatures
@@ -50,7 +54,9 @@ class InductorChoices:
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
"""
def get_config_heuristics(self, device_type="cuda") -> BaseConfigHeuristic:
def get_config_heuristics(
self, device_type: Optional[str] = "cuda"
) -> BaseConfigHeuristic:
if device_type == "cuda":
if torch.version.hip is None:
return CUDAConfigHeuristic()
@@ -64,43 +70,61 @@ class InductorChoices:
return BaseConfigHeuristic()
# GEMM configs
def get_base_mm_configs(self, device_type="cuda") -> partial:
def get_base_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
return mm_heuristics.get_mm_configs()
else:
return mm_heuristics.get_exhaustive_mm_configs()
def get_extra_mm_configs(self, device_type="cuda") -> partial:
def get_extra_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_extra_mm_configs()
def get_int8_mm_configs(self, device_type="cuda") -> partial:
def get_int8_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_int8_mm_configs()
def get_mixed_mm_configs(self, device_type="cuda") -> partial:
def get_mixed_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mixed_mm_configs()
def get_persistent_mm_configs(self, device_type="cuda") -> partial:
def get_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_persistent_mm_configs()
def get_scaled_mm_configs(self, device_type="cuda") -> partial:
def get_scaled_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_mm_configs()
def get_scaled_persistent_mm_configs(self, device_type="cuda") -> partial:
def get_scaled_persistent_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_scaled_persistent_mm_configs()
def get_mm_plus_mm_configs(self, device_type="cuda") -> partial:
def get_mm_plus_mm_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
mm_heuristics = self.get_config_heuristics(device_type)
return mm_heuristics.get_mm_plus_mm_configs()
# Conv configs
def get_conv_configs(self, device_type="cuda") -> partial:
def get_conv_configs(
self, device_type: Optional[str] = "cuda"
) -> partial[Generator[TritonConfig, None, None]]:
conv_heuristics = self.get_config_heuristics(device_type)
return conv_heuristics.get_conv_configs()

View file

@@ -1,7 +1,6 @@
# mypy: allow-untyped-defs
import logging
from collections.abc import Sequence
from typing import Any, cast, Dict, Tuple
from typing import Any
import sympy

View file

@@ -4,19 +4,7 @@ import itertools
from collections import namedtuple
from functools import partial
from threading import Lock
from typing import (
Any,
Callable,
cast,
Dict,
Generator,
List,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
TypeVar,
)
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
from torch.utils._ordered_set import OrderedSet
@@ -43,7 +31,7 @@ class BaseConfigSingleton(type):
) -> BaseConfigHeuristic:
with cls._lock:
if cls not in cls._instances:
instance = super().__call__(cls, *args, **kwargs)
instance = super().__call__()
cls._instances[cls] = instance
return cls._instances[cls]