diff --git a/torch/_inductor/choices.py b/torch/_inductor/choices.py index eec0f1c591d..05ad3beb9aa 100644 --- a/torch/_inductor/choices.py +++ b/torch/_inductor/choices.py @@ -1,7 +1,8 @@ from __future__ import annotations import typing -from typing import Any, TYPE_CHECKING +from functools import partial +from typing import Any, Generator, Optional, TYPE_CHECKING import sympy @@ -24,6 +25,9 @@ from .virtualized import V if TYPE_CHECKING: from functools import partial + + from triton import Config as TritonConfig + from torch.utils._ordered_set import OrderedSet from .codegen.simd_kernel_features import SIMDKernelFeatures @@ -50,7 +54,9 @@ class InductorChoices: torch._inductor.virtualized.V.set_choices_handler(MyHeuristics()) """ - def get_config_heuristics(self, device_type="cuda") -> BaseConfigHeuristic: + def get_config_heuristics( + self, device_type: Optional[str] = "cuda" + ) -> BaseConfigHeuristic: if device_type == "cuda": if torch.version.hip is None: return CUDAConfigHeuristic() @@ -64,43 +70,61 @@ class InductorChoices: return BaseConfigHeuristic() # GEMM configs - def get_base_mm_configs(self, device_type="cuda") -> partial: + def get_base_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) if config.max_autotune_gemm_search_space != "EXHAUSTIVE": return mm_heuristics.get_mm_configs() else: return mm_heuristics.get_exhaustive_mm_configs() - def get_extra_mm_configs(self, device_type="cuda") -> partial: + def get_extra_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_extra_mm_configs() - def get_int8_mm_configs(self, device_type="cuda") -> partial: + def get_int8_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_int8_mm_configs() - def get_mixed_mm_configs(self, device_type="cuda") -> partial: + def get_mixed_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_mixed_mm_configs() - def get_persistent_mm_configs(self, device_type="cuda") -> partial: + def get_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_persistent_mm_configs() - def get_scaled_mm_configs(self, device_type="cuda") -> partial: + def get_scaled_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_scaled_mm_configs() - def get_scaled_persistent_mm_configs(self, device_type="cuda") -> partial: + def get_scaled_persistent_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_scaled_persistent_mm_configs() - def get_mm_plus_mm_configs(self, device_type="cuda") -> partial: + def get_mm_plus_mm_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: mm_heuristics = self.get_config_heuristics(device_type) return mm_heuristics.get_mm_plus_mm_configs() # Conv configs - def get_conv_configs(self, device_type="cuda") -> partial: + def get_conv_configs( + self, device_type: Optional[str] = "cuda" + ) -> partial[Generator[TritonConfig, None, None]]: conv_heuristics = self.get_config_heuristics(device_type) return conv_heuristics.get_conv_configs() diff --git a/torch/_inductor/kernel/mm_common.py b/torch/_inductor/kernel/mm_common.py index 05ae5543961..0bee451f148 100644 --- a/torch/_inductor/kernel/mm_common.py +++ b/torch/_inductor/kernel/mm_common.py @@ -1,7 +1,6 @@ # mypy: allow-untyped-defs import logging -from collections.abc import Sequence -from typing import Any, cast, Dict, Tuple +from typing import Any import sympy diff --git a/torch/_inductor/template_heuristics.py b/torch/_inductor/template_heuristics.py index c4980314036..5718e74f89e 100644 --- a/torch/_inductor/template_heuristics.py +++ b/torch/_inductor/template_heuristics.py @@ -4,19 +4,7 @@ import itertools from collections import namedtuple from functools import partial from threading import Lock -from typing import ( - Any, - Callable, - cast, - Dict, - Generator, - List, - Sequence, - Tuple, - Type, - TYPE_CHECKING, - TypeVar, -) +from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING from torch.utils._ordered_set import OrderedSet @@ -43,7 +31,7 @@ class BaseConfigSingleton(type): ) -> BaseConfigHeuristic: with cls._lock: if cls not in cls._instances: - instance = super().__call__(cls, *args, **kwargs) + instance = super().__call__() cls._instances[cls] = instance return cls._instances[cls]