Mirror of https://github.com/saymrwulf/pytorch.git (synced 2026-05-14 20:57:59 +00:00)
[inductor] Refactor op handlers part 1 (#146235)
This enforces the invariant that every backend implements the same set of ops and removes a layer of indirection for BasicMathOps.
Interestingly, this is a small compile-time win:
```
...
WIN: benchmark ('add_loop_inductor', 'compile_time_instruction_count') failed, actual result 30151159301 is -6.13% lower than expected 32120000000 ±1.50% please update the expected results.
please update all results that changed significantly, and not only the failed ones
PASS: benchmark ('add_loop_inductor_dynamic_gpu', 'compile_time_instruction_count') pass, actual result 44447549162 -1.69% is within expected 45210000000 ±2.50%
WIN: benchmark ('add_loop_inductor_gpu', 'compile_time_instruction_count') failed, actual result 26743557195 is -2.25% lower than expected 27360000000 ±1.50% please update the expected results.
please update all results that changed significantly, and not only the failed ones
PASS: benchmark ('basic_modules_ListOfLinears_eager', 'compile_time_instruction_count') pass, actual result 945129734 +0.93% is within expected 936400000 ±1.50%
WIN: benchmark ('basic_modules_ListOfLinears_inductor', 'compile_time_instruction_count') failed, actual result 18984384503 is -3.19% lower than expected 19610000000 ±1.50% please update the expected results.
please update all results that changed significantly, and not only the failed ones
WIN: benchmark ('basic_modules_ListOfLinears_inductor_gpu_force_shape_pad', 'compile_time_instruction_count') failed, actual result 17258025389 is -1.94% lower than expected 17600000000 ±1.50% please update the expected results.
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146235
Approved by: https://github.com/shunting314
ghstack dependencies: #146225, #146226
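For illustration, here is a minimal sketch of the indirection being removed (these are not the real PyTorch classes; the names mirror the diff below and the bodies are simplified). Previously, ops that `OpOverrides` did not define were forwarded at runtime to a wrapped parent handler through `__getattr__`; after this change the shared `BasicMathOps` methods sit directly on the class, so every backend exposes the same op set statically.

```python
# Hypothetical, simplified stand-ins for the classes touched by this PR.
class BasicMathOps:
    # Shared string-rendering ops, now inherited directly by handlers.
    @staticmethod
    def add(a, b):
        return f"{a} + {b}"

    @staticmethod
    def sub(a, b):
        return f"{a} - {b}"


class OldStyleOpOverrides:
    # Old pattern: unknown ops were delegated to a wrapped parent handler.
    def __init__(self, parent):
        self._parent = parent

    def __getattr__(self, item):
        return getattr(self._parent, item)


class NewStyleOpOverrides(BasicMathOps):
    # New pattern: the basic math ops are ordinary methods on the class itself,
    # so there is no runtime delegation and the op set is visible statically.
    pass


print(OldStyleOpOverrides(BasicMathOps()).add("x", "y"))  # x + y (via __getattr__)
print(NewStyleOpOverrides.add("x", "y"))                  # x + y (direct)
```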
parent 18380ab877 · commit 204be4e0a2
9 changed files with 397 additions and 159 deletions
Updated expected compile-time benchmark results:

@@ -1,4 +1,4 @@
-add_loop_eager,compile_time_instruction_count,3066000000,0.015
+add_loop_eager,compile_time_instruction_count,3096000000,0.015
@@ -6,27 +6,27 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025
-add_loop_inductor,compile_time_instruction_count,32120000000,0.015
+add_loop_inductor,compile_time_instruction_count,30150000000,0.015
-add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45210000000,0.025
+add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025
-add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015
+add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015
-basic_modules_ListOfLinears_eager,compile_time_instruction_count,936400000,0.015
+basic_modules_ListOfLinears_eager,compile_time_instruction_count,945100000,0.015
-basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19610000000,0.015
+basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,0.015
-basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17600000000,0.015
+basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015
@@ -46,7 +46,7 @@ symint_sum,compile_time_instruction_count,3324000000,0.015
-aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2018000000,0.015
+aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2028000000,0.015
@@ -54,7 +54,7 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5836000000,0
-aotdispatcher_partitioner_cpu,compile_time_instruction_count,9103000000,0.015
+aotdispatcher_partitioner_cpu,compile_time_instruction_count,9167000000,0.015
New file test/inductor/test_op_completeness.py (42 lines):

@@ -0,0 +1,42 @@
+# Owner(s): ["module: inductor"]
+import unittest
+
+from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides
+from torch._inductor.codegen.halide import HalideOverrides
+from torch._inductor.codegen.mps import MetalOverrides
+from torch._inductor.codegen.triton import TritonKernelOverrides
+from torch._inductor.ops_handler import list_ops, OP_NAMES
+from torch._inductor.test_case import TestCase
+
+
+class TestOpCompleteness(TestCase):
+    def verify_ops_handler_completeness(self, handler):
+        op_names = list_ops(handler)
+        if OP_NAMES == op_names:
+            return
+        print(f"Missing ops: {OP_NAMES - op_names}")
+        print(f"Extra ops: {op_names - OP_NAMES}")
+        self.assertEqual(", ".join(OP_NAMES - op_names), "")
+        self.assertEqual(", ".join(op_names - OP_NAMES), "")
+
+    def test_triton_overrides(self):
+        self.verify_ops_handler_completeness(TritonKernelOverrides)
+
+    def test_cpp_overrides(self):
+        self.verify_ops_handler_completeness(CppOverrides)
+
+    def test_cpp_vec_overrides(self):
+        self.verify_ops_handler_completeness(CppVecOverrides)
+
+    def test_halide_overrides(self):
+        self.verify_ops_handler_completeness(HalideOverrides)
+
+    @unittest.skip("MPS backend not yet finished")
+    def test_metal_overrides(self):
+        self.verify_ops_handler_completeness(MetalOverrides)
+
+
+if __name__ == "__main__":
+    from torch._inductor.test_case import run_tests
+
+    run_tests()
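The new test above boils down to a set comparison between the ops declared on the `OpsHandler` protocol and the ops actually present on each backend's overrides class. A small usage sketch (assuming a PyTorch build that already contains this PR, so that `list_ops` and `OP_NAMES` exist in `torch._inductor.ops_handler`):

```python
# Sketch: reproduce the completeness check by hand for one backend.
from torch._inductor.codegen.triton import TritonKernelOverrides
from torch._inductor.ops_handler import OP_NAMES, list_ops

backend_ops = list_ops(TritonKernelOverrides)
# Both differences are expected to be empty once every backend implements the full op set.
print("missing:", OP_NAMES - backend_ops)
print("extra:", backend_ops - OP_NAMES)
```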
torch/_inductor/codegen/common.py:

@@ -41,6 +41,7 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val
 from .. import config, metrics
 from ..dtype_propagation import DtypePropagationOpsHandler
+from ..ops_handler import BasicMathOps
 from ..utils import (
     boolean_ops,
     DeferredLineBase,
@@ -49,6 +50,7 @@ from ..utils import (
     ir_dataclass,
     ScopedDict,
     sympy_dot,
+    sympy_index_symbol,
     sympy_subs,
     unique,
 )
@@ -761,11 +763,7 @@ def _all_in_parens(string: str) -> bool:
     return True


-class OpOverrides(OpDecompositions):
-    def __init__(self, parent: OpsHandler[OpVarT]) -> None:
-        super().__init__()
-        self._parent = parent
-
+class OpOverrides(BasicMathOps, OpDecompositions):
     @staticmethod
     def paren(string: OpVarT) -> OpVarT:
         if (
@@ -777,9 +775,6 @@ class OpOverrides(OpDecompositions):
            return string
        return f"({string})"

-    def __getattr__(self, item: str) -> Callable[..., Any]:
-        return getattr(self._parent, item)
-
     @staticmethod
     def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT:
         return repr(value)
@@ -852,15 +847,138 @@ class OpOverrides(OpDecompositions):
     def load_seed(name: str, offset: OpVarT) -> OpVarT:
         return ops.load(name, sympy.Integer(offset))

+    def indirect_indexing(
+        self,
+        var: OpVarT,
+        size: Union[sympy.Expr, int],
+        check: bool = True,
+        wrap_neg: bool = True,
+    ) -> sympy.Symbol:
+        return sympy_index_symbol(str(var))
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: check_bounds should be handled by CSEProxy"
+        )
+
+    def load(self, name: str, index: sympy.Expr) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: load should be handled by CSEProxy"
+        )
+
+    def store(
+        self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None
+    ) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: store should be handled by CSEProxy"
+        )
+
+    def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None:
+        raise NotImplementedError(
+            f"{type(self).__name__}: store_reduction should be handled by CSEProxy"
+        )
+
+    def reduction(
+        self,
+        dtype: torch.dtype,
+        src_dtype: torch.dtype,
+        reduction_type: ReductionType,
+        value: Union[OpVarT, tuple[OpVarT, ...]],
+    ) -> Union[OpVarT, tuple[OpVarT, ...]]:
+        raise NotImplementedError(
+            f"{type(self).__name__}: reduction should be handled by CSEProxy"
+        )
+
+    def scan(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        combine_fn: Callable[
+            [tuple[OpVarT, ...], tuple[OpVarT, ...]],
+            tuple[OpVarT, ...],
+        ],
+        values: tuple[OpVarT, ...],
+    ) -> tuple[OpVarT, ...]:
+        raise NotImplementedError(
+            f"{type(self).__name__}: scan should be handled by CSEProxy"
+        )
+
+    def sort(
+        self,
+        dtypes: tuple[torch.dtype, ...],
+        values: tuple[OpVarT, ...],
+        stable: bool,
+        descending: bool,
+    ) -> tuple[OpVarT, ...]:
+        raise NotImplementedError(
+            f"{type(self).__name__}: sort should be handled by CSEProxy"
+        )
+
+    def bucketize(
+        self,
+        values: OpVarT,
+        boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr],
+        boundary_indices: OpVarT,
+        indexing_dtype: torch.dtype,
+        right: bool,
+        sorter: Optional[tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[OpVarT] = None,
+    ) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: bucketize should be handled by CSEProxy"
+        )
+
+    def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: halide_clamp only implemented for Halide backend"
+        )
+
+    def inline_asm_elementwise(
+        self,
+        *inputs: OpVarT,
+        asm: str,
+        constraints: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
+        is_pure: bool = True,
+        pack: int = 1,
+    ) -> OpVarT:
+        raise NotImplementedError(
+            f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
+        )
+
+    @staticmethod
+    def _unimplemented(name: str) -> Callable[..., OpVarT]:
+        def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
+            raise NotImplementedError(
+                f"{type(self).__name__} does not implement ops.{name}"
+            )
+
+        unimplemented.__name__ = name
+        unimplemented.is_unimplemented = True  # type: ignore[attr-defined]
+        return unimplemented
+
+    @classmethod
+    def _is_unimplemented(cls, name: str) -> bool:
+        fn = getattr(cls, name, None)
+        default_fn = getattr(OpsHandler, name, None)
+        return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False)
+
     @classmethod
     def _initialize_pointwise_overrides(cls, target: str) -> None:
-        assert target in ("triton", "cpp", "cppvec"), target
+        assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target

         for funcname, data in pointwise_overrides_data.items():
             impl = getattr(data, target)
             if impl is None:
-                continue
-            setattr(cls, funcname, staticmethod(impl))
+                if cls._is_unimplemented(funcname):
+                    setattr(cls, funcname, cls._unimplemented(funcname))
+            else:
+                assert (
+                    funcname not in cls.__dict__
+                ), f"multiple definitions of {funcname} on {cls.__name__}"
+                impl.__name__ = funcname
+                setattr(cls, funcname, staticmethod(impl))


 @dataclasses.dataclass
@@ -874,6 +992,8 @@ class OverridesData:
     type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = (
         ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
     )
+    halide: Optional[Callable[..., str]] = None
+    mps: Optional[Callable[..., str]] = None


 # NB: if you add a new special function, don't forget to update
@@ -1104,9 +1224,10 @@ pointwise_overrides_data: dict[str, OverridesData] = dict(
 )


-# Use mypy to check protocol implemented correctly
-def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]:
-    return h
+if TYPE_CHECKING:
+
+    class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]):
+        pass  # mypy will error if we got any of the signatures wrong


 class DeferredLine(DeferredLineBase):
@@ -1704,7 +1825,7 @@ class CodeGen:
 class Kernel(CodeGen, Generic[CSEVariableType]):
     newvar_prefix: str = ""
     suffix: str = ""
-    overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None
+    overrides: Optional[Callable[[], OpsHandler[Any]]] = None

     def __init__(
         self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True
@@ -1891,8 +2012,9 @@ class Kernel(CodeGen, Generic[CSEVariableType]):
     def __enter__(self) -> typing.Self:
         super().__enter__()
         assert self.overrides
-        parent_handler = self.overrides(V.get_ops_handler())
-        self.exit_stack.enter_context(V.set_ops_handler(CSEProxy(self, parent_handler)))
+        self.exit_stack.enter_context(
+            V.set_ops_handler(CSEProxy(self, self.overrides()))
+        )
         self.exit_stack.enter_context(V.set_kernel_handler(self))
         return self
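The `_unimplemented` / `_is_unimplemented` helpers added above are what let a backend satisfy the completeness invariant even for ops it cannot generate code for: the op name still exists on the class, but calling it raises. A hedged, self-contained sketch of that pattern (the backend class and op name below are hypothetical; only the stub-factory shape mirrors the diff):

```python
from typing import Any, Callable


def _unimplemented(name: str) -> Callable[..., Any]:
    # Build a stand-in method that raises when the op is actually used.
    def unimplemented(self, *args: Any, **kwargs: Any) -> Any:
        raise NotImplementedError(f"{type(self).__name__} does not implement ops.{name}")

    unimplemented.__name__ = name
    unimplemented.is_unimplemented = True  # lets _is_unimplemented-style checks spot the stub
    return unimplemented


class FakeBackendOverrides:
    # Hypothetical backend with no real "erfinv" lowering.
    pass


setattr(FakeBackendOverrides, "erfinv", _unimplemented("erfinv"))

assert "erfinv" in dir(FakeBackendOverrides)  # a dir()-based completeness check still sees the op
try:
    FakeBackendOverrides().erfinv(0.5)
except NotImplementedError as exc:
    print(exc)  # FakeBackendOverrides does not implement ops.erfinv
```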
torch/_inductor/codegen/cpp.py:

@@ -1140,9 +1140,7 @@ class CppVecOverrides(CppOverrides):
                else:
                    # fallback to scalar ops
                    scalar_ops = super(CppVecOverrides, self)
-                    scalar_func = getattr(
-                        scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__)  # type: ignore[attr-defined]
-                    )
+                    scalar_func = getattr(scalar_ops, func.__name__)
                    assert scalar_func is not None
                    return scalar_func(*args, **kwargs)

@@ -1646,8 +1644,11 @@ class CppVecOverrides(CppOverrides):
        assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
        body_vec_var.dtype = dtype
        other_vec_var.dtype = dtype
+        overrides: type[
+            Union[CppOverrides, CppVecOverrides]
+        ] = V.kernel.overrides  # type: ignore[has-type]
        code.writeline(
-            f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};"
+            f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
        )
        code.writeline("()")
        csevar = V.kernel.cse.generate(
@@ -1748,7 +1749,7 @@ class CppVecOverrides(CppOverrides):
        return mantissa, exponent

    @classmethod
-    def scalarize(cls, scalar_func):
+    def _scalarize(cls, scalar_func):
        def inner(*args, **kwargs):
            assert not kwargs
            kernel = V.kernel
@@ -1810,11 +1811,10 @@ class CppVecOverrides(CppOverrides):

    @classmethod
    def _initialize_scalarize(cls):
+        vec_vars = vars(CppVecOverrides)
        for name, method in vars(CppOverrides).items():
-            if getattr(method, "__class__", None) == staticmethod and name not in vars(
-                CppVecOverrides
-            ):
-                func = cls.scalarize(method.__func__)
+            if isinstance(method, staticmethod) and name not in vec_vars:
+                func = cls._scalarize(method.__func__)
                func.__name__ = name
                setattr(cls, name, staticmethod(func))
torch/_inductor/codegen/halide.py:

@@ -24,7 +24,7 @@ from .. import config, ir
 from ..codecache import HalideCodeCache
 from ..ir import get_reduction_combine_fn
 from ..metrics import is_metric_table_enabled, log_kernel_metadata
-from ..ops_handler import AddParenHandler, MockHandler
+from ..ops_handler import AddParenHandler
 from ..runtime.hints import HalideInputSpec, HalideMeta
 from ..utils import (
     get_bounds_index_expr,
@@ -555,10 +555,18 @@ class HalideOverrides(OpOverrides):
        # TODO(jansel): look into removing the where in the same places triton does
        return ops.where(new_mask, result, other)

+    @staticmethod
+    def frexp(x):
+        raise NotImplementedError("frexp")
+

-# Use mypy to check protocol implemented correctly
-def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]:
-    return h
+HalideOverrides._initialize_pointwise_overrides("halide")
+
+
+if TYPE_CHECKING:
+
+    class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]):
+        pass  # mypy will error if we got any of the signatures wrong


 class HalideCSEVariable(CSEVariable):
@@ -1217,7 +1225,7 @@ class HalideKernel(SIMDKernel):
            result_var = self.welford_reduce_fallback(dtype, value)
        else:
            combine_fn = get_reduction_combine_fn(reduction_type, acc_type)
-            with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))):
+            with V.set_ops_handler(AddParenHandler(HalideOverrides())):
                combine_str = combine_fn(result_var, value_str)  # type: ignore[arg-type]
            default_str = f"hl.cast({acc_type}, {halide_constant(default)})"
            self.body.writeline(f"{result_var} = {default_str}")
@@ -1333,7 +1341,7 @@ class HalideKernel(SIMDKernel):
        self.body.writeline(f"{result_var} = {maybe_tuple(initial)}")

        # Disable CSE for update fn
-        with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))):
+        with V.set_ops_handler(AddParenHandler(HalideOverrides())):
            combine_str = combine_fn(read_left, read_right)  # type: ignore[arg-type]
        self.body.writeline(
            f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}"
torch/_inductor/codegen/mps.py:

@@ -26,7 +26,7 @@ if TYPE_CHECKING:

     import sympy

-    from ..ops_handler import StoreMode
+    from ..ops_handler import OpsHandler, StoreMode
     from ..scheduler import Scheduler, SchedulerNode
     from .common import OpVarT

@@ -357,6 +357,15 @@ class MetalOverrides(OpOverrides):
        return f"metal::pow({cast_a}, {cast_b})"


+MetalOverrides._initialize_pointwise_overrides("mps")
+
+
+if TYPE_CHECKING:
+
+    class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]):
+        pass  # mypy will error if we got any of the signatures wrong
+
+
 class MetalKernel(SIMDKernel):
     overrides = MetalOverrides  # type: ignore[assignment]
     suffix = ";"
torch/_inductor/codegen/triton.py:

@@ -1283,11 +1283,6 @@ class TritonOverrides(OpOverrides):
 TritonOverrides._initialize_pointwise_overrides("triton")


-# Use mypy to check protocol implemented correctly
-def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]:
-    return h
-
-
 class TritonKernelOverrides(TritonOverrides):
     """Map element-wise ops to Triton within a TritonKernel
@@ -1412,9 +1407,10 @@ class TritonKernelOverrides(TritonOverrides):
        return (mantissa, exponent)


-# Use mypy to check protocol implemented correctly
-def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]:
-    return h
+if TYPE_CHECKING:
+
+    class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
+        pass  # mypy will error if we got any of the signatures wrong


 class HelperFunctions:
@@ -2812,7 +2808,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
        signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args))
        helper.writeline(f"def {{name}}({signature}):")

-        overrides = TritonOverrides(V.MockHandler())
+        overrides = TritonOverrides()

        # Build a name that changes depending on fn to workaround a triton bug
        # where the combine_fn to reduce and scan is not hashed, and so different
torch/_inductor/dtype_propagation.py:

@@ -1,22 +1,17 @@
 # mypy: allow-untyped-defs
 import functools
 from collections.abc import Sequence
-from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union
+from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union

 import sympy

+import torch
+from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
 from torch.utils._ordered_set import OrderedSet

-
-if TYPE_CHECKING:
-    from torch._inductor.loop_body import LoopBodyBlock
-
-import torch
-from torch._inductor.virtualized import V
-from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype
-
+from .ops_handler import OP_NAMES, OpsHandler
 from .utils import upcast_compute_type
-from .virtualized import OpsValue
+from .virtualized import OpsValue, V


 T = TypeVar("T")
@@ -128,10 +123,7 @@ class DtypePropagationOpsHandler:
                self, op, functools.partial(self.return_dtype, dtype=torch.bool)
            )

-        from torch._inductor.ops_handler import OpsHandler
-
-        ops_set = OrderedSet(s for s in dir(OpsHandler) if s[0] != "_")
-        unimplemented_ops = ops_set - OrderedSet(dir(self))
+        unimplemented_ops = OP_NAMES - OrderedSet(dir(self))
        torch._check(
            len(unimplemented_ops) == 0,
            lambda: f"Unimplemented dtype rule for ops: {unimplemented_ops}",
@@ -164,7 +156,12 @@ class DtypePropagationOpsHandler:
        return torch.int64

    @staticmethod
-    def masked(mask: DTypeArg, body: "LoopBodyBlock", other: DTypeArg) -> torch.dtype:
+    def masked(
+        mask: DTypeArg, body: Callable[[], DTypeArg], other: DTypeArg
+    ) -> torch.dtype:
+        from .loop_body import LoopBodyBlock
+
+        assert isinstance(body, LoopBodyBlock), "body must be a LoopBodyBlock"
        # TODO - we avoid calling this in codegen, needs work for non codegen use cases
        loads = body.graph.find_nodes(op="call_method", target="load")
        if len(loads) <= 1:
@@ -210,10 +207,6 @@ class DtypePropagationOpsHandler:
    def mul(a: DTypeArg, b: DTypeArg) -> torch.dtype:
        return promote_types([a, b])

-    @staticmethod
-    def div(a: DTypeArg, b: DTypeArg) -> torch.dtype:
-        return promote_types([a, b])
-
    @staticmethod
    def truediv(a: DTypeArg, b: DTypeArg) -> torch.dtype:
        return promote_types([a, b])
@@ -319,6 +312,8 @@ class DtypePropagationOpsHandler:
        boundary_indices: DTypeArg,
        indexing_dtype: torch.dtype,
        right: bool,
+        sorter: Optional[tuple[str, sympy.Expr]] = None,
+        sorter_indices: Optional[T] = None,
    ) -> torch.dtype:
        return indexing_dtype
@@ -332,10 +327,6 @@ class DtypePropagationOpsHandler:
            [x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
        )

-    @staticmethod
-    def getitem(x: DTypeArg, y: DTypeArg) -> torch.dtype:
-        raise RuntimeError("Unexpected op: getitem")
-
    @staticmethod
    def trunc_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype:
        return dtype
@@ -352,11 +343,6 @@ class DtypePropagationOpsHandler:
    def floordiv(x: DTypeArg, y: DTypeArg) -> torch.dtype:
        return promote_types([x, y])

-    @staticmethod
-    def round_decimal(x: DTypeArg, y: DTypeArg) -> torch.dtype:
-        # TODO - dont see it anywhere..
-        return promote_types([x])
-
    @staticmethod
    def halide_clamp(value, size, check):
        # TODO - way of registering dtype for op in backend
@@ -377,9 +363,13 @@ class DtypePropagationOpsHandler:
        return promote_types([x])

    @staticmethod
-    def invert(x: DTypeArg) -> torch.dtype:
-        raise RuntimeError("Unexpected op: invert")
+    def check_bounds(
+        expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
+        return None

-    @staticmethod
-    def matmul(x: DTypeArg, y: DTypeArg) -> torch.dtype:
-        raise RuntimeError("Unexpected op: matmul")
+
+if TYPE_CHECKING:
+
+    class _typecheck_DtypePropagation(DtypePropagationOpsHandler, OpsHandler[Any]):
+        pass  # mypy will error if we got any of the signatures wrong
torch/_inductor/ops_handler.py:

@@ -2,6 +2,7 @@
 from __future__ import annotations

 import itertools
+import re
 from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
 from typing_extensions import Protocol
 from unittest.mock import patch
@@ -65,16 +66,29 @@ class OpsHandler(Protocol[T]):
     Note that this often describes a class of static methods, for stateless
     ops handlers.

-    Handlers are often defined using ``__getattr__`` metaprogramming, which means
-    that you cannot declare that a type implements a protocol by inheriting from
-    it (as the type stubs count as attribute declarations and impede the getattr
-    magic method from being called). Instead, define a function that casts an
-    argument of your type to the protocol, which is sufficient to induce mypy to
-    test that the protocol is implemented correctly. Search for ``_typecheck_``
-    in this file to see some examples. If you see an obscure error where a
-    class doesn't implement a Protocol, but mypy doesn't say why, check to see
-    that ``__getattr__`` is typed correctly (typically, it is not possible to
-    type ``__getattr__`` without typing it as ``Callable[..., Any]``)
+    Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides),
+    which means you will get type errors if you subclass OpsHandler since mypy doesn't know
+    about the methods added via metaprogramming and thinks the class is still abstract.
+    Instead, you should add a block like:
+
+        if TYPE_CHECKING:
+
+            class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]):
+                pass  # mypy will error if we got any of the signatures wrong
+
+    Which will check the signatures of non-meta-programmed methods and gives decent error messages.
+
+    Some older parts of the code use a pattern like:
+
+        def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
+            return h
+
+    This pattern only works if the class defines a __getattr__ method, which we are moving away from.
+    Additionally, this pattern generates horrible error messages if the signatures are wrong.
+    It gives zero information about what the problem is, which makes the pattern harmful.
+
+    Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all
+    operators are implemented after all the metaprogramming has run.
     """

     def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T:
@@ -517,17 +531,6 @@ class OpsHandler(Protocol[T]):
     def rshift(self, x0: T, x1: T) -> T:
         ...

-    def getitem(self, x0: T, x1: T) -> T:
-        # TODO: this is probably just illegal lol
-        ...
-
-    def matmul(self, x0: T, x1: T) -> T:
-        # TODO: this is probably just illegal lol
-        ...
-
-    def invert(self, x0: T) -> T:
-        ...
-
     # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
     # These are "special" operators. These only exist if the target
     # language actually supports the operator. Keep this in sync with
@@ -683,11 +686,6 @@ class OpsHandler(Protocol[T]):
        """
        ...

-    def div(self, x0: T, x1: T) -> T:
-        """TODO: to be removed. This renders as / no matter what the backend is
-        which is incoherent."""
-        ...
-
    def mod(self, x0: T, x1: T) -> T:
        """C-style modulus, take sign from LHS (x0)."""
        ...
@@ -696,8 +694,12 @@ class OpsHandler(Protocol[T]):
        """Python-style modulus, take sign from RHS (x1)."""
        ...

-    def round_decimal(self, x0: T, x1: T) -> T:
-        """Python-style round with decimal argument"""
+    def square(self, x0: T) -> T:
+        ...
+
+    def check_bounds(
+        self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool
+    ) -> None:
        ...

    # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -734,16 +736,42 @@ class OpsHandler(Protocol[T]):
    def libdevice_log(self, x0: T) -> T:
        ...

+    # halide-only
+    def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T:
+        raise NotImplementedError
+
+    # triton-only
+    def inline_asm_elementwise(
+        self,
+        *inputs: T,
+        asm: str,
+        constraints: Optional[str] = None,
+        dtype: torch.dtype = torch.float32,
+        is_pure: bool = True,
+        pack: int = 1,
+    ) -> T:
+        ...
+
+
+_ignore_op_re = re.compile(r"_.*|paren").fullmatch
+
+
+def list_ops(cls: type[Any]):
+    return OrderedSet([x for x in dir(cls) if not _ignore_op_re(x)])
+
+
+OP_NAMES = list_ops(OpsHandler)
+
+
+def _return_none(*args, **kwargs):
+    return None
+

 class NoopHandler:
-    def __getattr__(self, name):
-        if name == "name":
-            return "NoopHandler"
+    name = "NoopHandler"

-        def inner(*args, **kwargs):
-            return None
-
-        return inner
+    def __getattr__(self, name: str) -> Callable[..., None]:
+        return _return_none

    @staticmethod
    def masked(mask, body, other) -> None:
@@ -771,11 +799,89 @@ def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
     return h


-class MockHandler:
-    def __getattr__(self, name):
-        if name == "name":
-            return "MockHandler"
+class BasicMathOps:
+    @staticmethod
+    def add(a, b):
+        return f"{a} + {b}"
+
+    @staticmethod
+    def sub(a, b):
+        return f"{a} - {b}"
+
+    @staticmethod
+    def mul(a, b):
+        return f"{a} * {b}"
+
+    @staticmethod
+    def floordiv(a, b):
+        return f"{a} // {b}"
+
+    @staticmethod
+    def truediv(a, b):
+        return f"{a} / {b}"
+
+    @staticmethod
+    def mod(a, b):
+        # careful, depending on target semantics varies
+        return f"{a} % {b}"
+
+    @staticmethod
+    def pow(a, b):
+        return f"{a} ** {b}"
+
+    @staticmethod
+    def lshift(a, b):
+        return f"{a} << {b}"
+
+    @staticmethod
+    def rshift(a, b):
+        return f"{a} >> {b}"
+
+    @staticmethod
+    def and_(a, b):
+        return f"{a} & {b}"
+
+    @staticmethod
+    def or_(a, b):
+        return f"{a} | {b}"
+
+    @staticmethod
+    def xor(a, b):
+        return f"{a} ^ {b}"
+
+    @staticmethod
+    def eq(a, b):
+        return f"{a} == {b}"
+
+    @staticmethod
+    def ne(a, b):
+        return f"{a} != {b}"
+
+    @staticmethod
+    def lt(a, b):
+        return f"{a} < {b}"
+
+    @staticmethod
+    def gt(a, b):
+        return f"{a} > {b}"
+
+    @staticmethod
+    def le(a, b):
+        return f"{a} <= {b}"
+
+    @staticmethod
+    def ge(a, b):
+        return f"{a} >= {b}"
+
+    @staticmethod
+    def neg(a):
+        return f"-{a}"
+
+
+class MockHandler(BasicMathOps):
+    name = "MockHandler"
+
+    def __getattr__(self, name):
        def inner(*args, **kwargs):
            fargs = [_arg_str(a) for a in args]
            fargs.extend(f"{k}={v}" for k, v in kwargs.items())
@@ -809,41 +915,6 @@ class MockHandler:
     def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol:
         return sympy_index_symbol(str(index_var))

-    @classmethod
-    def _init_cls(cls):
-        def make_handler(format_string):
-            @staticmethod  # type: ignore[misc]
-            def inner(*args):
-                return format_string.format(*args)
-
-            return inner
-
-        for name, format_string in {
-            "add": "{} + {}",
-            "sub": "{} - {}",
-            "mul": "{} * {}",
-            "floordiv": "{} // {}",
-            "truediv": "{} / {}",
-            "mod": "{} % {}",  # careful, depending on target semantics varies
-            "pow": "{} ** {}",
-            "lshift": "{} << {}",
-            "rshift": "{} >> {}",
-            "and_": "{} & {}",
-            "or_": "{} | {}",
-            "xor": "{} ^ {}",
-            "eq": "{} == {}",
-            "ne": "{} != {}",
-            "lt": "{} < {}",
-            "gt": "{} > {}",
-            "le": "{} <= {}",
-            "ge": "{} >= {}",
-            "neg": "-{}",
-        }.items():
-            setattr(cls, name, make_handler(format_string))
-
-
-MockHandler._init_cls()
-
-
 # Use mypy to check protocol implemented correctly
 def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
@@ -923,7 +994,7 @@ def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[s


 class WrapperHandler(Generic[T]):
-    def __init__(self, inner: OpsHandler[T]):
+    def __init__(self, inner: Any):
         self._inner = inner

     def __getattr__(self, item):