Revert "[inductor] Refactor op handlers part 1 (#146235)"

This reverts commit 204be4e0a2.

Reverted https://github.com/pytorch/pytorch/pull/146235 on behalf of https://github.com/atalman due to Breaks lint, sorry: Definition of polygamma in base class MetalOverrides is incompatible with definition in base class OpsHandler. Please rebase fix lint and reland ([comment](https://github.com/pytorch/pytorch/pull/146235#issuecomment-2632444514))
This commit is contained in:
PyTorch MergeBot 2025-02-04 00:00:08 +00:00
parent 3aeccf2a28
commit 2f40f789da
9 changed files with 159 additions and 397 deletions

View file

@ -1,4 +1,4 @@
add_loop_eager,compile_time_instruction_count,3096000000,0.015
add_loop_eager,compile_time_instruction_count,3066000000,0.015
@ -6,27 +6,27 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
add_loop_inductor,compile_time_instruction_count,30150000000,0.015
add_loop_inductor,compile_time_instruction_count,32120000000,0.015
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45210000000,0.025
add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,945100000,0.015
basic_modules_ListOfLinears_eager,compile_time_instruction_count,936400000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,0.015
basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19610000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17600000000,0.015
@ -46,7 +46,7 @@ symint_sum,compile_time_instruction_count,3324000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2028000000,0.015
aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2018000000,0.015
@ -54,7 +54,7 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5836000000,0
aotdispatcher_partitioner_cpu,compile_time_instruction_count,9167000000,0.015
aotdispatcher_partitioner_cpu,compile_time_instruction_count,9103000000,0.015

1 add_loop_eager compile_time_instruction_count 3096000000 3066000000 0.015
2 add_loop_eager_dynamic compile_time_instruction_count 5703000000 5703000000 0.025
3 add_loop_inductor compile_time_instruction_count 30150000000 32120000000 0.015
4 add_loop_inductor_dynamic_gpu compile_time_instruction_count 44440000000 45210000000 0.025
6 basic_modules_ListOfLinears_eager compile_time_instruction_count 945100000 936400000 0.015
7 basic_modules_ListOfLinears_inductor compile_time_instruction_count 18980000000 19610000000 0.015
8 basic_modules_ListOfLinears_inductor_gpu_force_shape_pad compile_time_instruction_count 17250000000 17600000000 0.015
9 basic_modules_ListOfLinears_inductor_gpu compile_time_instruction_count 10885050825 10885050825 0.2
10 update_hint_regression compile_time_instruction_count 1686000000 1686000000 0.02
11 sum_floordiv_regression compile_time_instruction_count 1041000000 1041000000 0.015
12 symint_sum compile_time_instruction_count 3324000000 3324000000 0.015
13 aotdispatcher_inference_nosubclass_cpu compile_time_instruction_count 2028000000 2018000000 0.015
14 aotdispatcher_inference_subclass_cpu compile_time_instruction_count 5836000000 5836000000 0.015
15 aotdispatcher_partitioner_cpu compile_time_instruction_count 9167000000 9103000000 0.015
16 aotdispatcher_training_nosubclass_cpu compile_time_instruction_count 3863000000 3863000000 0.015
17 aotdispatcher_training_subclass_cpu compile_time_instruction_count 10340000000 10340000000 0.015
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
46
47
48
49
50
51
52
54
55
56
57
58
59
60

View file

@ -1,42 +0,0 @@
# Owner(s): ["module: inductor"]
import unittest
from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
from torch._inductor.codegen.halide import HalideOverrides
from torch._inductor.codegen.mps import MetalOverrides
from torch._inductor.codegen.triton import TritonKernelOverrides
from torch._inductor.ops_handler import list_ops, OP_NAMES
from torch._inductor.test_case import TestCase
class TestOpCompleteness(TestCase):
    """Check that each backend's overrides class implements every operator
    listed in the canonical OP_NAMES set from torch._inductor.ops_handler."""

    def verify_ops_handler_completeness(self, handler):
        # Collect the op names the handler actually exposes and compare
        # against the canonical set; bail out early on an exact match.
        op_names = list_ops(handler)
        if OP_NAMES == op_names:
            return
        # Print both directions of the mismatch for quick debugging, then
        # assert on joined strings so the failure message lists the op names.
        print(f"Missing ops: {OP_NAMES - op_names}")
        print(f"Extra ops: {op_names - OP_NAMES}")
        self.assertEqual(", ".join(OP_NAMES - op_names), "")
        self.assertEqual(", ".join(op_names - OP_NAMES), "")

    def test_triton_overrides(self):
        self.verify_ops_handler_completeness(TritonKernelOverrides)

    def test_cpp_overrides(self):
        self.verify_ops_handler_completeness(CppOverrides)

    def test_cpp_vec_overrides(self):
        self.verify_ops_handler_completeness(CppVecOverrides)

    def test_halide_overrides(self):
        self.verify_ops_handler_completeness(HalideOverrides)

    # MPS op coverage is known-incomplete, so this check is disabled for now.
    @unittest.skip("MPS backend not yet finished")
    def test_metal_overrides(self):
        self.verify_ops_handler_completeness(MetalOverrides)
if __name__ == "__main__":
    # Allow running this test file directly (outside a pytest invocation).
    from torch._inductor.test_case import run_tests

    run_tests()

View file

@ -41,7 +41,6 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val
from .. import config, metrics
from ..dtype_propagation import DtypePropagationOpsHandler
from ..ops_handler import BasicMathOps
from ..utils import (
boolean_ops,
DeferredLineBase,
@ -50,7 +49,6 @@ from ..utils import (
ir_dataclass,
ScopedDict,
sympy_dot,
sympy_index_symbol,
sympy_subs,
unique,
)
@ -763,7 +761,11 @@ def _all_in_parens(string: str) -> bool:
return True
class OpOverrides(BasicMathOps, OpDecompositions):
class OpOverrides(OpDecompositions):
def __init__(self, parent: OpsHandler[OpVarT]) -> None:
super().__init__()
self._parent = parent
@staticmethod
def paren(string: OpVarT) -> OpVarT:
if (
@ -775,6 +777,9 @@ class OpOverrides(BasicMathOps, OpDecompositions):
return string
return f"({string})"
def __getattr__(self, item: str) -> Callable[..., Any]:
return getattr(self._parent, item)
@staticmethod
def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT:
return repr(value)
@ -847,138 +852,15 @@ class OpOverrides(BasicMathOps, OpDecompositions):
def load_seed(name: str, offset: OpVarT) -> OpVarT:
return ops.load(name, sympy.Integer(offset))
def indirect_indexing(
self,
var: OpVarT,
size: Union[sympy.Expr, int],
check: bool = True,
wrap_neg: bool = True,
) -> sympy.Symbol:
return sympy_index_symbol(str(var))
def check_bounds(
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
) -> None:
raise NotImplementedError(
f"{type(self).__name__}: check_bounds should be handled by CSEProxy"
)
def load(self, name: str, index: sympy.Expr) -> OpVarT:
raise NotImplementedError(
f"{type(self).__name__}: load should be handled by CSEProxy"
)
def store(
self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None
) -> None:
raise NotImplementedError(
f"{type(self).__name__}: store should be handled by CSEProxy"
)
def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
raise NotImplementedError(
f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
)
def reduction(
self,
dtype: torch.dtype,
src_dtype: torch.dtype,
reduction_type: ReductionType,
value: Union[OpVarT, tuple[OpVarT, ...]],
) -> Union[OpVarT, tuple[OpVarT, ...]]:
raise NotImplementedError(
f"{type(self).__name__}: reduction should be handled by CSEProxy"
)
def scan(
self,
dtypes: tuple[torch.dtype, ...],
combine_fn: Callable[
[tuple[OpVarT, ...], tuple[OpVarT, ...]],
tuple[OpVarT, ...],
],
values: tuple[OpVarT, ...],
) -> tuple[OpVarT, ...]:
raise NotImplementedError(
f"{type(self).__name__}: scan should be handled by CSEProxy"
)
def sort(
self,
dtypes: tuple[torch.dtype, ...],
values: tuple[OpVarT, ...],
stable: bool,
descending: bool,
) -> tuple[OpVarT, ...]:
raise NotImplementedError(
f"{type(self).__name__}: sort should be handled by CSEProxy"
)
def bucketize(
self,
values: OpVarT,
boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
boundary_indices: OpVarT,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[OpVarT] = None,
) -> OpVarT:
raise NotImplementedError(
f"{type(self).__name__}: bucketize should be handled by CSEProxy"
)
def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT:
raise NotImplementedError(
f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
)
def inline_asm_elementwise(
self,
*inputs: OpVarT,
asm: str,
constraints: Optional[str] = None,
dtype: torch.dtype = torch.float32,
is_pure: bool = True,
pack: int = 1,
) -> OpVarT:
raise NotImplementedError(
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
)
@staticmethod
def _unimplemented(name: str) -> Callable[..., OpVarT]:
def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
raise NotImplementedError(
f"{type(self).__name__} does not implement ops.{name}"
)
unimplemented.__name__ = name
unimplemented.is_unimplemented = True # type: ignore[attr-defined]
return unimplemented
@classmethod
def _is_unimplemented(cls, name: str) -> bool:
fn = getattr(cls, name, None)
default_fn = getattr(OpsHandler, name, None)
return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False)
@classmethod
def _initialize_pointwise_overrides(cls, target: str) -> None:
assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target
assert target in ("triton", "cpp", "cppvec"), target
for funcname, data in pointwise_overrides_data.items():
impl = getattr(data, target)
if impl is None:
if cls._is_unimplemented(funcname):
setattr(cls, funcname, cls._unimplemented(funcname))
else:
assert (
funcname not in cls.__dict__
), f"multiple definitions of {funcname} on {cls.__name__}"
impl.__name__ = funcname
setattr(cls, funcname, staticmethod(impl))
continue
setattr(cls, funcname, staticmethod(impl))
@dataclasses.dataclass
@ -992,8 +874,6 @@ class OverridesData:
type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
)
halide: Optional[Callable[..., str]] = None
mps: Optional[Callable[..., str]] = None
# NB: if you add a new special function, don't forget to update
@ -1224,10 +1104,9 @@ pointwise_overrides_data: dict[str, OverridesData] = dict(
)
if TYPE_CHECKING:
class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
# Use mypy to check protocol implemented correctly
def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]:
return h
class DeferredLine(DeferredLineBase):
@ -1825,7 +1704,7 @@ class CodeGen:
class Kernel(CodeGen, Generic[CSEVariableType]):
newvar_prefix: str = ""
suffix: str = ""
overrides: Optional[Callable[[], OpsHandler[Any]]] = None
overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
def __init__(
self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
@ -2012,9 +1891,8 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
def __enter__(self) -> typing.Self:
super().__enter__()
assert self.overrides
self.exit_stack.enter_context(
V.set_ops_handler(CSEProxy(self, self.overrides()))
)
parent_handler = self.overrides(V.get_ops_handler())
self.exit_stack.enter_context(V.set_ops_handler(CSEProxy(self, parent_handler)))
self.exit_stack.enter_context(V.set_kernel_handler(self))
return self

View file

@ -1140,7 +1140,9 @@ class CppVecOverrides(CppOverrides):
else:
# fallback to scalar ops
scalar_ops = super(CppVecOverrides, self)
scalar_func = getattr(scalar_ops, func.__name__)
scalar_func = getattr(
scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined]
)
assert scalar_func is not None
return scalar_func(*args, **kwargs)
@ -1644,11 +1646,8 @@ class CppVecOverrides(CppOverrides):
assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[
Union[CppOverrides, CppVecOverrides]
] = V.kernel.overrides # type: ignore[has-type]
code.writeline(
f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};"
)
code.writeline("()")
csevar = V.kernel.cse.generate(
@ -1749,7 +1748,7 @@ class CppVecOverrides(CppOverrides):
return mantissa, exponent
@classmethod
def _scalarize(cls, scalar_func):
def scalarize(cls, scalar_func):
def inner(*args, **kwargs):
assert not kwargs
kernel = V.kernel
@ -1811,10 +1810,11 @@ class CppVecOverrides(CppOverrides):
@classmethod
def _initialize_scalarize(cls):
vec_vars = vars(CppVecOverrides)
for name, method in vars(CppOverrides).items():
if isinstance(method, staticmethod) and name not in vec_vars:
func = cls._scalarize(method.__func__)
if getattr(method, "__class__", None) == staticmethod and name not in vars(
CppVecOverrides
):
func = cls.scalarize(method.__func__)
func.__name__ = name
setattr(cls, name, staticmethod(func))

View file

@ -24,7 +24,7 @@ from .. import config, ir
from ..codecache import HalideCodeCache
from ..ir import get_reduction_combine_fn
from ..metrics import is_metric_table_enabled, log_kernel_metadata
from ..ops_handler import AddParenHandler
from ..ops_handler import AddParenHandler, MockHandler
from ..runtime.hints import HalideInputSpec, HalideMeta
from ..utils import (
get_bounds_index_expr,
@ -555,18 +555,10 @@ class HalideOverrides(OpOverrides):
# TODO(jansel): look into removing the where in the same places triton does
return ops.where(new_mask, result, other)
@staticmethod
def frexp(x):
raise NotImplementedError("frexp")
HalideOverrides._initialize_pointwise_overrides("halide")
if TYPE_CHECKING:
class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
# Use mypy to check protocol implemented correctly
def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]:
return h
class HalideCSEVariable(CSEVariable):
@ -1225,7 +1217,7 @@ class HalideKernel(SIMDKernel):
result_var = self.welford_reduce_fallback(dtype, value)
else:
combine_fn = get_reduction_combine_fn(reduction_type, acc_type)
with V.set_ops_handler(AddParenHandler(HalideOverrides())):
with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))):
combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type]
default_str = f"hl.cast({acc_type}, {halide_constant(default)})"
self.body.writeline(f"{result_var} = {default_str}")
@ -1341,7 +1333,7 @@ class HalideKernel(SIMDKernel):
self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")
# Disable CSE for update fn
with V.set_ops_handler(AddParenHandler(HalideOverrides())):
with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))):
combine_str = combine_fn(read_left, read_right) # type: ignore[arg-type]
self.body.writeline(
f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}"

View file

@ -26,7 +26,7 @@ if TYPE_CHECKING:
import sympy
from ..ops_handler import OpsHandler, StoreMode
from ..ops_handler import StoreMode
from ..scheduler import Scheduler, SchedulerNode
from .common import OpVarT
@ -357,15 +357,6 @@ class MetalOverrides(OpOverrides):
return f"metal::pow({cast_a}, {cast_b})"
MetalOverrides._initialize_pointwise_overrides("mps")
if TYPE_CHECKING:
class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong
class MetalKernel(SIMDKernel):
overrides = MetalOverrides # type: ignore[assignment]
suffix = ";"

View file

@ -1283,6 +1283,11 @@ class TritonOverrides(OpOverrides):
TritonOverrides._initialize_pointwise_overrides("triton")
# Use mypy to check protocol implemented correctly
def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
return h
class TritonKernelOverrides(TritonOverrides):
"""Map element-wise ops to Triton within a TritonKernel
@ -1407,10 +1412,9 @@ class TritonKernelOverrides(TritonOverrides):
return (mantissa, exponent)
if TYPE_CHECKING:
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
# Use mypy to check protocol implemented correctly
def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]:
return h
class HelperFunctions:
@ -2808,7 +2812,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args))
helper.writeline(f"def {{name}}({signature}):")
overrides = TritonOverrides()
overrides = TritonOverrides(V.MockHandler())
# Build a name that changes depending on fn to workaround a triton bug
# where the combine_fn to reduce and scan is not hashed, and so different

View file

@ -1,17 +1,22 @@
# mypy: allow-untyped-defs
import functools
from collections.abc import Sequence
from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
import sympy
import torch
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
from torch.utils._ordered_set import OrderedSet
from .ops_handler import OP_NAMES, OpsHandler
if TYPE_CHECKING:
from torch._inductor.loop_body import LoopBodyBlock
import torch
from torch._inductor.virtualized import V
from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
from .utils import upcast_compute_type
from .virtualized import OpsValue, V
from .virtualized import OpsValue
T = TypeVar("T")
@ -123,7 +128,10 @@ class DtypePropagationOpsHandler:
self, op, functools.partial(self.return_dtype, dtype=torch.bool)
)
unimplemented_ops = OP_NAMES - OrderedSet(dir(self))
from torch._inductor.ops_handler import OpsHandler
ops_set = OrderedSet(s for s in dir(OpsHandler) if s[0] != "_")
unimplemented_ops = ops_set - OrderedSet(dir(self))
torch._check(
len(unimplemented_ops) == 0,
lambda: f"Unimplemented dtype rule for ops: {unimplemented_ops}",
@ -156,12 +164,7 @@ class DtypePropagationOpsHandler:
return torch.int64
@staticmethod
def masked(
mask: DTypeArg, body: Callable[[], DTypeArg], other: DTypeArg
) -> torch.dtype:
from .loop_body import LoopBodyBlock
assert isinstance(body, LoopBodyBlock), "body must be a LoopBodyBlock"
def masked(mask: DTypeArg, body: "LoopBodyBlock", other: DTypeArg) -> torch.dtype:
# TODO - we avoid calling this in codegen, needs work for non codegen use cases
loads = body.graph.find_nodes(op="call_method", target="load")
if len(loads) <= 1:
@ -207,6 +210,10 @@ class DtypePropagationOpsHandler:
def mul(a: DTypeArg, b: DTypeArg) -> torch.dtype:
return promote_types([a, b])
@staticmethod
def div(a: DTypeArg, b: DTypeArg) -> torch.dtype:
return promote_types([a, b])
@staticmethod
def truediv(a: DTypeArg, b: DTypeArg) -> torch.dtype:
return promote_types([a, b])
@ -312,8 +319,6 @@ class DtypePropagationOpsHandler:
boundary_indices: DTypeArg,
indexing_dtype: torch.dtype,
right: bool,
sorter: Optional[tuple[str, sympy.Expr]] = None,
sorter_indices: Optional[T] = None,
) -> torch.dtype:
return indexing_dtype
@ -327,6 +332,10 @@ class DtypePropagationOpsHandler:
[x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)
@staticmethod
def getitem(x: DTypeArg, y: DTypeArg) -> torch.dtype:
raise RuntimeError("Unexpected op: getitem")
@staticmethod
def trunc_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype:
return dtype
@ -343,6 +352,11 @@ class DtypePropagationOpsHandler:
def floordiv(x: DTypeArg, y: DTypeArg) -> torch.dtype:
return promote_types([x, y])
@staticmethod
def round_decimal(x: DTypeArg, y: DTypeArg) -> torch.dtype:
# TODO - dont see it anywhere..
return promote_types([x])
@staticmethod
def halide_clamp(value, size, check):
# TODO - way of registering dtype for op in backend
@ -363,13 +377,9 @@ class DtypePropagationOpsHandler:
return promote_types([x])
@staticmethod
def check_bounds(
expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
) -> None:
return None
def invert(x: DTypeArg) -> torch.dtype:
raise RuntimeError("Unexpected op: invert")
if TYPE_CHECKING:
class _typecheck_DtypePropagation(DtypePropagationOpsHandler, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong
@staticmethod
def matmul(x: DTypeArg, y: DTypeArg) -> torch.dtype:
raise RuntimeError("Unexpected op: matmul")

View file

@ -2,7 +2,6 @@
from __future__ import annotations
import itertools
import re
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
from typing_extensions import Protocol
from unittest.mock import patch
@ -66,29 +65,16 @@ class OpsHandler(Protocol[T]):
Note that this often describes a class of static methods, for stateless
ops handlers.
Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides),
which means you will get type errors if you subclass OpsHandler since mypy doesn't know
about the methods added via metaprogramming and thinks the class is still abstract.
Instead, you should add a block like:
if TYPE_CHECKING:
class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
Which will check the signatures of non-meta-programmed methods and gives decent error messages.
Some older parts of the code use a pattern like:
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
return h
This pattern only works if the class defines a __getattr__ method, which we are moving away from.
Additionally, this pattern generates horrible error messages if the signatures are wrong.
It gives zero information about what the problem is, which makes the pattern harmful.
Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all
operators are implemented after all the metaprogramming has run.
Handlers are often defined using ``__getattr__`` metaprogramming, which means
that you cannot declare that a type implements a protocol by inheriting from
it (as the type stubs count as attribute declarations and impede the getattr
magic method from being called). Instead, define a function that casts an
argument of your type to the protocol, which is sufficient to induce mypy to
test that the protocol is implemented correctly. Search for ``_typecheck_``
in this file to see some examples. If you see an obscure error where a
class doesn't implement a Protocol, but mypy doesn't say why, check to see
that ``__getattr__`` is typed correctly (typically, it is not possible to
type ``__getattr__`` without typing it as ``Callable[..., Any]``)
"""
def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
@ -531,6 +517,17 @@ class OpsHandler(Protocol[T]):
def rshift(self, x0: T, x1: T) -> T:
...
def getitem(self, x0: T, x1: T) -> T:
# TODO: this is probably just illegal lol
...
def matmul(self, x0: T, x1: T) -> T:
# TODO: this is probably just illegal lol
...
def invert(self, x0: T) -> T:
...
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# These are "special" operators. These only exist if the target
# language actually supports the operator. Keep this in sync with
@ -686,6 +683,11 @@ class OpsHandler(Protocol[T]):
"""
...
def div(self, x0: T, x1: T) -> T:
"""TODO: to be removed. This renders as / no matter what the backend is
which is incoherent."""
...
def mod(self, x0: T, x1: T) -> T:
"""C-style modulus, take sign from LHS (x0)."""
...
@ -694,12 +696,8 @@ class OpsHandler(Protocol[T]):
"""Python-style modulus, take sign from RHS (x1)."""
...
def square(self, x0: T) -> T:
...
def check_bounds(
self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
) -> None:
def round_decimal(self, x0: T, x1: T) -> T:
"""Python-style round with decimal argument"""
...
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@ -736,42 +734,16 @@ class OpsHandler(Protocol[T]):
def libdevice_log(self, x0: T) -> T:
...
# halide-only
def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T:
raise NotImplementedError
# triton-only
def inline_asm_elementwise(
self,
*inputs: T,
asm: str,
constraints: Optional[str] = None,
dtype: torch.dtype = torch.float32,
is_pure: bool = True,
pack: int = 1,
) -> T:
...
_ignore_op_re = re.compile(r"_.*|paren").fullmatch
def list_ops(cls: type[Any]):
return OrderedSet([x for x in dir(cls) if not _ignore_op_re(x)])
OP_NAMES = list_ops(OpsHandler)
def _return_none(*args, **kwargs):
return None
class NoopHandler:
name = "NoopHandler"
def __getattr__(self, name):
if name == "name":
return "NoopHandler"
def __getattr__(self, name: str) -> Callable[..., None]:
return _return_none
def inner(*args, **kwargs):
return None
return inner
@staticmethod
def masked(mask, body, other) -> None:
@ -799,89 +771,11 @@ def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
return h
class BasicMathOps:
    """String-rendering implementations of the basic binary/unary math ops.

    Each method formats its operands into a Python-style infix (or prefix,
    for ``neg``) expression string.  Backends whose surface syntax matches
    Python's operators can inherit these directly.
    """

    @staticmethod
    def add(a, b):
        return "{} + {}".format(a, b)

    @staticmethod
    def sub(a, b):
        return "{} - {}".format(a, b)

    @staticmethod
    def mul(a, b):
        return "{} * {}".format(a, b)

    @staticmethod
    def floordiv(a, b):
        return "{} // {}".format(a, b)

    @staticmethod
    def truediv(a, b):
        return "{} / {}".format(a, b)

    @staticmethod
    def mod(a, b):
        # careful, depending on target semantics varies
        return "{} % {}".format(a, b)

    @staticmethod
    def pow(a, b):
        return "{} ** {}".format(a, b)

    @staticmethod
    def lshift(a, b):
        return "{} << {}".format(a, b)

    @staticmethod
    def rshift(a, b):
        return "{} >> {}".format(a, b)

    @staticmethod
    def and_(a, b):
        return "{} & {}".format(a, b)

    @staticmethod
    def or_(a, b):
        return "{} | {}".format(a, b)

    @staticmethod
    def xor(a, b):
        return "{} ^ {}".format(a, b)

    @staticmethod
    def eq(a, b):
        return "{} == {}".format(a, b)

    @staticmethod
    def ne(a, b):
        return "{} != {}".format(a, b)

    @staticmethod
    def lt(a, b):
        return "{} < {}".format(a, b)

    @staticmethod
    def gt(a, b):
        return "{} > {}".format(a, b)

    @staticmethod
    def le(a, b):
        return "{} <= {}".format(a, b)

    @staticmethod
    def ge(a, b):
        return "{} >= {}".format(a, b)

    @staticmethod
    def neg(a):
        return "-{}".format(a)
class MockHandler(BasicMathOps):
name = "MockHandler"
class MockHandler:
def __getattr__(self, name):
if name == "name":
return "MockHandler"
def inner(*args, **kwargs):
fargs = [_arg_str(a) for a in args]
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
@ -915,6 +809,41 @@ class MockHandler(BasicMathOps):
def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
return sympy_index_symbol(str(index_var))
@classmethod
def _init_cls(cls):
def make_handler(format_string):
@staticmethod # type: ignore[misc]
def inner(*args):
return format_string.format(*args)
return inner
for name, format_string in {
"add": "{} + {}",
"sub": "{} - {}",
"mul": "{} * {}",
"floordiv": "{} // {}",
"truediv": "{} / {}",
"mod": "{} % {}", # careful, depending on target semantics varies
"pow": "{} ** {}",
"lshift": "{} << {}",
"rshift": "{} >> {}",
"and_": "{} & {}",
"or_": "{} | {}",
"xor": "{} ^ {}",
"eq": "{} == {}",
"ne": "{} != {}",
"lt": "{} < {}",
"gt": "{} > {}",
"le": "{} <= {}",
"ge": "{} >= {}",
"neg": "-{}",
}.items():
setattr(cls, name, make_handler(format_string))
MockHandler._init_cls()
# Use mypy to check protocol implemented correctly
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
@ -994,7 +923,7 @@ def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[s
class WrapperHandler(Generic[T]):
def __init__(self, inner: Any):
def __init__(self, inner: OpsHandler[T]):
self._inner = inner
def __getattr__(self, item):