mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fixes
This commit is contained in:
parent
93834a5581
commit
92d0a0d9d3
3 changed files with 38 additions and 27 deletions
|
|
@ -1,7 +1,8 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
from typing import Any, TYPE_CHECKING
|
||||
from functools import partial
|
||||
from typing import Any, Generator, Optional, TYPE_CHECKING
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
@ -24,6 +25,9 @@ from .virtualized import V
|
|||
|
||||
if TYPE_CHECKING:
|
||||
from functools import partial
|
||||
|
||||
from triton import Config as TritonConfig
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
from .codegen.simd_kernel_features import SIMDKernelFeatures
|
||||
|
|
@ -50,7 +54,9 @@ class InductorChoices:
|
|||
torch._inductor.virtualized.V.set_choices_handler(MyHeuristics())
|
||||
"""
|
||||
|
||||
def get_config_heuristics(self, device_type="cuda") -> BaseConfigHeuristic:
|
||||
def get_config_heuristics(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> BaseConfigHeuristic:
|
||||
if device_type == "cuda":
|
||||
if torch.version.hip is None:
|
||||
return CUDAConfigHeuristic()
|
||||
|
|
@ -64,43 +70,61 @@ class InductorChoices:
|
|||
return BaseConfigHeuristic()
|
||||
|
||||
# GEMM configs
|
||||
def get_base_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_base_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
if config.max_autotune_gemm_search_space != "EXHAUSTIVE":
|
||||
return mm_heuristics.get_mm_configs()
|
||||
else:
|
||||
return mm_heuristics.get_exhaustive_mm_configs()
|
||||
|
||||
def get_extra_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_extra_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_extra_mm_configs()
|
||||
|
||||
def get_int8_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_int8_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_int8_mm_configs()
|
||||
|
||||
def get_mixed_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_mixed_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_mixed_mm_configs()
|
||||
|
||||
def get_persistent_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_persistent_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_persistent_mm_configs()
|
||||
|
||||
def get_scaled_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_scaled_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_scaled_mm_configs()
|
||||
|
||||
def get_scaled_persistent_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_scaled_persistent_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_scaled_persistent_mm_configs()
|
||||
|
||||
def get_mm_plus_mm_configs(self, device_type="cuda") -> partial:
|
||||
def get_mm_plus_mm_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
mm_heuristics = self.get_config_heuristics(device_type)
|
||||
return mm_heuristics.get_mm_plus_mm_configs()
|
||||
|
||||
# Conv configs
|
||||
def get_conv_configs(self, device_type="cuda") -> partial:
|
||||
def get_conv_configs(
|
||||
self, device_type: Optional[str] = "cuda"
|
||||
) -> partial[Generator[TritonConfig, None, None]]:
|
||||
conv_heuristics = self.get_config_heuristics(device_type)
|
||||
return conv_heuristics.get_conv_configs()
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,6 @@
|
|||
# mypy: allow-untyped-defs
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast, Dict, Tuple
|
||||
from typing import Any
|
||||
|
||||
import sympy
|
||||
|
||||
|
|
|
|||
|
|
@ -4,19 +4,7 @@ import itertools
|
|||
from collections import namedtuple
|
||||
from functools import partial
|
||||
from threading import Lock
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
Generator,
|
||||
List,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
)
|
||||
from typing import Any, Callable, Generator, List, Sequence, Tuple, Type, TYPE_CHECKING
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
|
@ -43,7 +31,7 @@ class BaseConfigSingleton(type):
|
|||
) -> BaseConfigHeuristic:
|
||||
with cls._lock:
|
||||
if cls not in cls._instances:
|
||||
instance = super().__call__(cls, *args, **kwargs)
|
||||
instance = super().__call__()
|
||||
cls._instances[cls] = instance
|
||||
return cls._instances[cls]
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue