From 2f40f789dafeaa62c4e4b90dbf4a900ff6da2ca4 Mon Sep 17 00:00:00 2001 From: PyTorch MergeBot Date: Tue, 4 Feb 2025 00:00:08 +0000 Subject: [PATCH] Revert "[inductor] Refactor op handlers part 1 (#146235)" This reverts commit 204be4e0a2e4509bd2457bfb295c429dd92c241f. Reverted https://github.com/pytorch/pytorch/pull/146235 on behalf of https://github.com/atalman due to Breaks lint, sorry: Definition of polygamma in base class MetalOverrides is incompatible with definition in base class OpsHandler. Please rebase fix lint and reland ([comment](https://github.com/pytorch/pytorch/pull/146235#issuecomment-2632444514)) --- .../pr_time_benchmarks/expected_results.csv | 18 +- test/inductor/test_op_completeness.py | 42 ---- torch/_inductor/codegen/common.py | 156 ++----------- torch/_inductor/codegen/cpp.py | 18 +- torch/_inductor/codegen/halide.py | 20 +- torch/_inductor/codegen/mps.py | 11 +- torch/_inductor/codegen/triton.py | 14 +- torch/_inductor/dtype_propagation.py | 56 +++-- torch/_inductor/ops_handler.py | 221 ++++++------------ 9 files changed, 159 insertions(+), 397 deletions(-) delete mode 100644 test/inductor/test_op_completeness.py diff --git a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv index 388b8d1a5f6..11a2117a644 100644 --- a/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv +++ b/benchmarks/dynamo/pr_time_benchmarks/expected_results.csv @@ -1,4 +1,4 @@ -add_loop_eager,compile_time_instruction_count,3096000000,0.015 +add_loop_eager,compile_time_instruction_count,3066000000,0.015 @@ -6,27 +6,27 @@ add_loop_eager_dynamic,compile_time_instruction_count,5703000000,0.025 -add_loop_inductor,compile_time_instruction_count,30150000000,0.015 +add_loop_inductor,compile_time_instruction_count,32120000000,0.015 -add_loop_inductor_dynamic_gpu,compile_time_instruction_count,44440000000,0.025 +add_loop_inductor_dynamic_gpu,compile_time_instruction_count,45210000000,0.025 -add_loop_inductor_gpu,compile_time_instruction_count,26740000000,0.015 +add_loop_inductor_gpu,compile_time_instruction_count,27360000000,0.015 -basic_modules_ListOfLinears_eager,compile_time_instruction_count,945100000,0.015 +basic_modules_ListOfLinears_eager,compile_time_instruction_count,936400000,0.015 -basic_modules_ListOfLinears_inductor,compile_time_instruction_count,18980000000,0.015 +basic_modules_ListOfLinears_inductor,compile_time_instruction_count,19610000000,0.015 -basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17250000000,0.015 +basic_modules_ListOfLinears_inductor_gpu_force_shape_pad,compile_time_instruction_count,17600000000,0.015 @@ -46,7 +46,7 @@ symint_sum,compile_time_instruction_count,3324000000,0.015 -aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2028000000,0.015 +aotdispatcher_inference_nosubclass_cpu,compile_time_instruction_count,2018000000,0.015 @@ -54,7 +54,7 @@ aotdispatcher_inference_subclass_cpu,compile_time_instruction_count,5836000000,0 -aotdispatcher_partitioner_cpu,compile_time_instruction_count,9167000000,0.015 +aotdispatcher_partitioner_cpu,compile_time_instruction_count,9103000000,0.015 diff --git a/test/inductor/test_op_completeness.py b/test/inductor/test_op_completeness.py deleted file mode 100644 index 04fac4870fd..00000000000 --- a/test/inductor/test_op_completeness.py +++ /dev/null @@ -1,42 +0,0 @@ -# Owner(s): ["module: inductor"] -import unittest - -from torch._inductor.codegen.cpp import CppOverrides, CppVecOverrides -from torch._inductor.codegen.halide import HalideOverrides -from torch._inductor.codegen.mps import MetalOverrides -from torch._inductor.codegen.triton import TritonKernelOverrides -from torch._inductor.ops_handler import list_ops, OP_NAMES -from torch._inductor.test_case import TestCase - - -class TestOpCompleteness(TestCase): - def verify_ops_handler_completeness(self, handler): - op_names = list_ops(handler) - if OP_NAMES == op_names: - return - print(f"Missing ops: {OP_NAMES - op_names}") - print(f"Extra ops: {op_names - OP_NAMES}") - self.assertEqual(", ".join(OP_NAMES - op_names), "") - self.assertEqual(", ".join(op_names - OP_NAMES), "") - - def test_triton_overrides(self): - self.verify_ops_handler_completeness(TritonKernelOverrides) - - def test_cpp_overrides(self): - self.verify_ops_handler_completeness(CppOverrides) - - def test_cpp_vec_overrides(self): - self.verify_ops_handler_completeness(CppVecOverrides) - - def test_halide_overrides(self): - self.verify_ops_handler_completeness(HalideOverrides) - - @unittest.skip("MPS backend not yet finished") - def test_metal_overrides(self): - self.verify_ops_handler_completeness(MetalOverrides) - - -if __name__ == "__main__": - from torch._inductor.test_case import run_tests - - run_tests() diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 85ed6958b99..a883ebc4bc6 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -41,7 +41,6 @@ from torch.utils._sympy.value_ranges import bound_sympy, ValueRangeAnalysis, Val from .. import config, metrics from ..dtype_propagation import DtypePropagationOpsHandler -from ..ops_handler import BasicMathOps from ..utils import ( boolean_ops, DeferredLineBase, @@ -50,7 +49,6 @@ from ..utils import ( ir_dataclass, ScopedDict, sympy_dot, - sympy_index_symbol, sympy_subs, unique, ) @@ -763,7 +761,11 @@ def _all_in_parens(string: str) -> bool: return True -class OpOverrides(BasicMathOps, OpDecompositions): +class OpOverrides(OpDecompositions): + def __init__(self, parent: OpsHandler[OpVarT]) -> None: + super().__init__() + self._parent = parent + @staticmethod def paren(string: OpVarT) -> OpVarT: if ( @@ -775,6 +777,9 @@ class OpOverrides(BasicMathOps, OpDecompositions): return string return f"({string})" + def __getattr__(self, item: str) -> Callable[..., Any]: + return getattr(self._parent, item) + @staticmethod def constant(value: Union[bool, float, int], dtype: torch.dtype) -> OpVarT: return repr(value) @@ -847,138 +852,15 @@ class OpOverrides(BasicMathOps, OpDecompositions): def load_seed(name: str, offset: OpVarT) -> OpVarT: return ops.load(name, sympy.Integer(offset)) - def indirect_indexing( - self, - var: OpVarT, - size: Union[sympy.Expr, int], - check: bool = True, - wrap_neg: bool = True, - ) -> sympy.Symbol: - return sympy_index_symbol(str(var)) - - def check_bounds( - self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ) -> None: - raise NotImplementedError( - f"{type(self).__name__}: check_bounds should be handled by CSEProxy" - ) - - def load(self, name: str, index: sympy.Expr) -> OpVarT: - raise NotImplementedError( - f"{type(self).__name__}: load should be handled by CSEProxy" - ) - - def store( - self, name: str, index: sympy.Expr, value: OpVarT, mode: StoreMode = None - ) -> None: - raise NotImplementedError( - f"{type(self).__name__}: store should be handled by CSEProxy" - ) - - def store_reduction(self, name: str, index: sympy.Expr, value: OpVarT) -> None: - raise NotImplementedError( - f"{type(self).__name__}: store_reduction should be handled by CSEProxy" - ) - - def reduction( - self, - dtype: torch.dtype, - src_dtype: torch.dtype, - reduction_type: ReductionType, - value: Union[OpVarT, tuple[OpVarT, ...]], - ) -> Union[OpVarT, tuple[OpVarT, ...]]: - raise NotImplementedError( - f"{type(self).__name__}: reduction should be handled by CSEProxy" - ) - - def scan( - self, - dtypes: tuple[torch.dtype, ...], - combine_fn: Callable[ - [tuple[OpVarT, ...], tuple[OpVarT, ...]], - tuple[OpVarT, ...], - ], - values: tuple[OpVarT, ...], - ) -> tuple[OpVarT, ...]: - raise NotImplementedError( - f"{type(self).__name__}: scan should be handled by CSEProxy" - ) - - def sort( - self, - dtypes: tuple[torch.dtype, ...], - values: tuple[OpVarT, ...], - stable: bool, - descending: bool, - ) -> tuple[OpVarT, ...]: - raise NotImplementedError( - f"{type(self).__name__}: sort should be handled by CSEProxy" - ) - - def bucketize( - self, - values: OpVarT, - boundaries: tuple[str, sympy.Expr, sympy.Expr, sympy.Expr], - boundary_indices: OpVarT, - indexing_dtype: torch.dtype, - right: bool, - sorter: Optional[tuple[str, sympy.Expr]] = None, - sorter_indices: Optional[OpVarT] = None, - ) -> OpVarT: - raise NotImplementedError( - f"{type(self).__name__}: bucketize should be handled by CSEProxy" - ) - - def halide_clamp(self, value: OpVarT, size: sympy.Expr, check: bool) -> OpVarT: - raise NotImplementedError( - f"{type(self).__name__}: halide_clamp only implemented for Halide backend" - ) - - def inline_asm_elementwise( - self, - *inputs: OpVarT, - asm: str, - constraints: Optional[str] = None, - dtype: torch.dtype = torch.float32, - is_pure: bool = True, - pack: int = 1, - ) -> OpVarT: - raise NotImplementedError( - f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" - ) - - @staticmethod - def _unimplemented(name: str) -> Callable[..., OpVarT]: - def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: - raise NotImplementedError( - f"{type(self).__name__} does not implement ops.{name}" - ) - - unimplemented.__name__ = name - unimplemented.is_unimplemented = True # type: ignore[attr-defined] - return unimplemented - - @classmethod - def _is_unimplemented(cls, name: str) -> bool: - fn = getattr(cls, name, None) - default_fn = getattr(OpsHandler, name, None) - return not fn or fn == default_fn or getattr(fn, "is_unimplemented", False) - @classmethod def _initialize_pointwise_overrides(cls, target: str) -> None: - assert target in ("triton", "cpp", "cppvec", "halide", "mps"), target + assert target in ("triton", "cpp", "cppvec"), target for funcname, data in pointwise_overrides_data.items(): impl = getattr(data, target) if impl is None: - if cls._is_unimplemented(funcname): - setattr(cls, funcname, cls._unimplemented(funcname)) - else: - assert ( - funcname not in cls.__dict__ - ), f"multiple definitions of {funcname} on {cls.__name__}" - impl.__name__ = funcname - setattr(cls, funcname, staticmethod(impl)) + continue + setattr(cls, funcname, staticmethod(impl)) @dataclasses.dataclass @@ -992,8 +874,6 @@ class OverridesData: type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND = ( ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT ) - halide: Optional[Callable[..., str]] = None - mps: Optional[Callable[..., str]] = None # NB: if you add a new special function, don't forget to update @@ -1224,10 +1104,9 @@ pointwise_overrides_data: dict[str, OverridesData] = dict( ) -if TYPE_CHECKING: - - class _typecheck_OpOverrides(OpOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong +# Use mypy to check protocol implemented correctly +def _typecheck_OpOverrides(h: OpOverrides) -> OpsHandler[OpVarT]: + return h class DeferredLine(DeferredLineBase): @@ -1825,7 +1704,7 @@ class CodeGen: class Kernel(CodeGen, Generic[CSEVariableType]): newvar_prefix: str = "" suffix: str = "" - overrides: Optional[Callable[[], OpsHandler[Any]]] = None + overrides: Optional[Callable[[OpsHandler[Any]], OpsHandler[Any]]] = None def __init__( self, args: Optional[KernelArgs] = None, increase_kernel_count: bool = True @@ -2012,9 +1891,8 @@ class Kernel(CodeGen, Generic[CSEVariableType]): def __enter__(self) -> typing.Self: super().__enter__() assert self.overrides - self.exit_stack.enter_context( - V.set_ops_handler(CSEProxy(self, self.overrides())) - ) + parent_handler = self.overrides(V.get_ops_handler()) + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy(self, parent_handler))) self.exit_stack.enter_context(V.set_kernel_handler(self)) return self diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index b7c06fac902..d21414d7e24 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -1140,7 +1140,9 @@ class CppVecOverrides(CppOverrides): else: # fallback to scalar ops scalar_ops = super(CppVecOverrides, self) - scalar_func = getattr(scalar_ops, func.__name__) + scalar_func = getattr( + scalar_ops, func.__name__, scalar_ops.__getattr__(func.__name__) # type: ignore[attr-defined] + ) assert scalar_func is not None return scalar_func(*args, **kwargs) @@ -1644,11 +1646,8 @@ class CppVecOverrides(CppOverrides): assert isinstance(other_vec_var, CppCSEVariable), other_vec_var body_vec_var.dtype = dtype other_vec_var.dtype = dtype - overrides: type[ - Union[CppOverrides, CppVecOverrides] - ] = V.kernel.overrides # type: ignore[has-type] code.writeline( - f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};" + f"return {V.kernel.overrides.where(new_mask, body_vec_var, other_vec_var)};" ) code.writeline("()") csevar = V.kernel.cse.generate( @@ -1749,7 +1748,7 @@ class CppVecOverrides(CppOverrides): return mantissa, exponent @classmethod - def _scalarize(cls, scalar_func): + def scalarize(cls, scalar_func): def inner(*args, **kwargs): assert not kwargs kernel = V.kernel @@ -1811,10 +1810,11 @@ class CppVecOverrides(CppOverrides): @classmethod def _initialize_scalarize(cls): - vec_vars = vars(CppVecOverrides) for name, method in vars(CppOverrides).items(): - if isinstance(method, staticmethod) and name not in vec_vars: - func = cls._scalarize(method.__func__) + if getattr(method, "__class__", None) == staticmethod and name not in vars( + CppVecOverrides + ): + func = cls.scalarize(method.__func__) func.__name__ = name setattr(cls, name, staticmethod(func)) diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py index 560e75c648f..f09bb94f2a0 100644 --- a/torch/_inductor/codegen/halide.py +++ b/torch/_inductor/codegen/halide.py @@ -24,7 +24,7 @@ from .. import config, ir from ..codecache import HalideCodeCache from ..ir import get_reduction_combine_fn from ..metrics import is_metric_table_enabled, log_kernel_metadata -from ..ops_handler import AddParenHandler +from ..ops_handler import AddParenHandler, MockHandler from ..runtime.hints import HalideInputSpec, HalideMeta from ..utils import ( get_bounds_index_expr, @@ -555,18 +555,10 @@ class HalideOverrides(OpOverrides): # TODO(jansel): look into removing the where in the same places triton does return ops.where(new_mask, result, other) - @staticmethod - def frexp(x): - raise NotImplementedError("frexp") - -HalideOverrides._initialize_pointwise_overrides("halide") - - -if TYPE_CHECKING: - - class _typecheck_HalideOverrides(HalideOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong +# Use mypy to check protocol implemented correctly +def _typecheck_HalideOverrides(h: HalideOverrides) -> OpsHandler[str]: + return h class HalideCSEVariable(CSEVariable): @@ -1225,7 +1217,7 @@ class HalideKernel(SIMDKernel): result_var = self.welford_reduce_fallback(dtype, value) else: combine_fn = get_reduction_combine_fn(reduction_type, acc_type) - with V.set_ops_handler(AddParenHandler(HalideOverrides())): + with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))): combine_str = combine_fn(result_var, value_str) # type: ignore[arg-type] default_str = f"hl.cast({acc_type}, {halide_constant(default)})" self.body.writeline(f"{result_var} = {default_str}") @@ -1341,7 +1333,7 @@ class HalideKernel(SIMDKernel): self.body.writeline(f"{result_var} = {maybe_tuple(initial)}") # Disable CSE for update fn - with V.set_ops_handler(AddParenHandler(HalideOverrides())): + with V.set_ops_handler(AddParenHandler(HalideOverrides(MockHandler()))): combine_str = combine_fn(read_left, read_right) # type: ignore[arg-type] self.body.writeline( f"{result_var.subs_str(scan_renames_cur)} = {maybe_tuple(combine_str)}" diff --git a/torch/_inductor/codegen/mps.py b/torch/_inductor/codegen/mps.py index a93caacd09a..63b4d2c17b3 100644 --- a/torch/_inductor/codegen/mps.py +++ b/torch/_inductor/codegen/mps.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: import sympy - from ..ops_handler import OpsHandler, StoreMode + from ..ops_handler import StoreMode from ..scheduler import Scheduler, SchedulerNode from .common import OpVarT @@ -357,15 +357,6 @@ class MetalOverrides(OpOverrides): return f"metal::pow({cast_a}, {cast_b})" -MetalOverrides._initialize_pointwise_overrides("mps") - - -if TYPE_CHECKING: - - class _typecheck_MetalOverrides(MetalOverrides, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong - - class MetalKernel(SIMDKernel): overrides = MetalOverrides # type: ignore[assignment] suffix = ";" diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 70a89e2e355..3b30cd40a09 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -1283,6 +1283,11 @@ class TritonOverrides(OpOverrides): TritonOverrides._initialize_pointwise_overrides("triton") +# Use mypy to check protocol implemented correctly +def _typecheck_TritonOverrides(h: TritonOverrides) -> OpsHandler[str]: + return h + + class TritonKernelOverrides(TritonOverrides): """Map element-wise ops to Triton within a TritonKernel @@ -1407,10 +1412,9 @@ class TritonKernelOverrides(TritonOverrides): return (mantissa, exponent) -if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong +# Use mypy to check protocol implemented correctly +def _typecheck_TritonKernelOverrides(h: TritonKernelOverrides) -> OpsHandler[str]: + return h class HelperFunctions: @@ -2808,7 +2812,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]): signature = ", ".join(str(x) for x in itertools.chain.from_iterable(args)) helper.writeline(f"def {{name}}({signature}):") - overrides = TritonOverrides() + overrides = TritonOverrides(V.MockHandler()) # Build a name that changes depending on fn to workaround a triton bug # where the combine_fn to reduce and scan is not hashed, and so different diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index efe0ebe2caf..45948ce6919 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -1,17 +1,22 @@ # mypy: allow-untyped-defs import functools from collections.abc import Sequence -from typing import Any, Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union +from typing import Callable, Optional, Protocol, TYPE_CHECKING, TypeVar, Union import sympy -import torch -from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype from torch.utils._ordered_set import OrderedSet -from .ops_handler import OP_NAMES, OpsHandler + +if TYPE_CHECKING: + from torch._inductor.loop_body import LoopBodyBlock + +import torch +from torch._inductor.virtualized import V +from torch._prims_common import ELEMENTWISE_TYPE_PROMOTION_KIND, type_to_dtype + from .utils import upcast_compute_type -from .virtualized import OpsValue, V +from .virtualized import OpsValue T = TypeVar("T") @@ -123,7 +128,10 @@ class DtypePropagationOpsHandler: self, op, functools.partial(self.return_dtype, dtype=torch.bool) ) - unimplemented_ops = OP_NAMES - OrderedSet(dir(self)) + from torch._inductor.ops_handler import OpsHandler + + ops_set = OrderedSet(s for s in dir(OpsHandler) if s[0] != "_") + unimplemented_ops = ops_set - OrderedSet(dir(self)) torch._check( len(unimplemented_ops) == 0, lambda: f"Unimplemented dtype rule for ops: {unimplemented_ops}", @@ -156,12 +164,7 @@ class DtypePropagationOpsHandler: return torch.int64 @staticmethod - def masked( - mask: DTypeArg, body: Callable[[], DTypeArg], other: DTypeArg - ) -> torch.dtype: - from .loop_body import LoopBodyBlock - - assert isinstance(body, LoopBodyBlock), "body must be a LoopBodyBlock" + def masked(mask: DTypeArg, body: "LoopBodyBlock", other: DTypeArg) -> torch.dtype: # TODO - we avoid calling this in codegen, needs work for non codegen use cases loads = body.graph.find_nodes(op="call_method", target="load") if len(loads) <= 1: @@ -207,6 +210,10 @@ class DtypePropagationOpsHandler: def mul(a: DTypeArg, b: DTypeArg) -> torch.dtype: return promote_types([a, b]) + @staticmethod + def div(a: DTypeArg, b: DTypeArg) -> torch.dtype: + return promote_types([a, b]) + @staticmethod def truediv(a: DTypeArg, b: DTypeArg) -> torch.dtype: return promote_types([a, b]) @@ -312,8 +319,6 @@ class DtypePropagationOpsHandler: boundary_indices: DTypeArg, indexing_dtype: torch.dtype, right: bool, - sorter: Optional[tuple[str, sympy.Expr]] = None, - sorter_indices: Optional[T] = None, ) -> torch.dtype: return indexing_dtype @@ -327,6 +332,10 @@ class DtypePropagationOpsHandler: [x], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT ) + @staticmethod + def getitem(x: DTypeArg, y: DTypeArg) -> torch.dtype: + raise RuntimeError("Unexpected op: getitem") + @staticmethod def trunc_to_int(x: DTypeArg, dtype: torch.dtype) -> torch.dtype: return dtype @@ -343,6 +352,11 @@ class DtypePropagationOpsHandler: def floordiv(x: DTypeArg, y: DTypeArg) -> torch.dtype: return promote_types([x, y]) + @staticmethod + def round_decimal(x: DTypeArg, y: DTypeArg) -> torch.dtype: + # TODO - dont see it anywhere.. + return promote_types([x]) + @staticmethod def halide_clamp(value, size, check): # TODO - way of registering dtype for op in backend @@ -363,13 +377,9 @@ class DtypePropagationOpsHandler: return promote_types([x]) @staticmethod - def check_bounds( - expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ) -> None: - return None + def invert(x: DTypeArg) -> torch.dtype: + raise RuntimeError("Unexpected op: invert") - -if TYPE_CHECKING: - - class _typecheck_DtypePropagation(DtypePropagationOpsHandler, OpsHandler[Any]): - pass # mypy will error if we got any of the signatures wrong + @staticmethod + def matmul(x: DTypeArg, y: DTypeArg) -> torch.dtype: + raise RuntimeError("Unexpected op: matmul") diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 935c5f6fc36..95c5fa87730 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -2,7 +2,6 @@ from __future__ import annotations import itertools -import re from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union from typing_extensions import Protocol from unittest.mock import patch @@ -66,29 +65,16 @@ class OpsHandler(Protocol[T]): Note that this often describes a class of static methods, for stateless ops handlers. - Handlers are often defined using metaprogramming (e.g. _initialize_pointwise_overrides), - which means you will get type errors if you subclass OpsHandler since mypy doesn't know - about the methods added via metaprogramming and thinks the class is still abstract. - Instead, you should add a block like: - - if TYPE_CHECKING: - - class _typecheck_TritonKernelOverrides(TritonKernelOverrides, OpsHandler[str]): - pass # mypy will error if we got any of the signatures wrong - - Which will check the signatures of non-meta-programmed methods and gives decent error messages. - - Some older parts of the code use a pattern like: - - def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h - - This pattern only works if the class defines a __getattr__ method, which we are moving away from. - Additionally, this pattern generates horrible error messages if the signatures are wrong. - It gives zero information about what the problem is, which makes the pattern harmful. - - Instead of that, we have tests in test/inductor/test_op_completeness.py which check that all - operators are implemented after all the metaprogramming has run. + Handlers are often defined using ``__getattr__`` metaprogramming, which means + that you cannot declare that a type implements a protocol by inheriting from + it (as the type stubs count as attribute declarations and impede the getattr + magic method from being called). Instead, define a function that casts an + argument of your type to the protocol, which is sufficient to induce mypy to + test that the protocol is implemented correctly. Search for ``_typecheck_`` + in this file to see some examples. If you see an obscure error where a + class doesn't implement a Protocol, but mypy doesn't say why, check to see + that ``__getattr__`` is typed correctly (typically, it is not possible to + type ``__getattr__`` without typing it as ``Callable[..., Any]``) """ def constant(self, value: Union[bool, float, int], dtype: torch.dtype) -> T: @@ -531,6 +517,17 @@ class OpsHandler(Protocol[T]): def rshift(self, x0: T, x1: T) -> T: ... + def getitem(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol + ... + + def matmul(self, x0: T, x1: T) -> T: + # TODO: this is probably just illegal lol + ... + + def invert(self, x0: T) -> T: + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # These are "special" operators. These only exist if the target # language actually supports the operator. Keep this in sync with @@ -686,6 +683,11 @@ class OpsHandler(Protocol[T]): """ ... + def div(self, x0: T, x1: T) -> T: + """TODO: to be removed. This renders as / no matter what the backend is + which is incoherent.""" + ... + def mod(self, x0: T, x1: T) -> T: """C-style modulus, take sign from LHS (x0).""" ... @@ -694,12 +696,8 @@ class OpsHandler(Protocol[T]): """Python-style modulus, take sign from RHS (x1).""" ... - def square(self, x0: T) -> T: - ... - - def check_bounds( - self, expr: sympy.Expr, size: sympy.Expr, lower: bool, upper: bool - ) -> None: + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" ... # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ @@ -736,42 +734,16 @@ class OpsHandler(Protocol[T]): def libdevice_log(self, x0: T) -> T: ... - # halide-only - def halide_clamp(self, value: T, size: sympy.Expr, check: bool) -> T: - raise NotImplementedError - - # triton-only - def inline_asm_elementwise( - self, - *inputs: T, - asm: str, - constraints: Optional[str] = None, - dtype: torch.dtype = torch.float32, - is_pure: bool = True, - pack: int = 1, - ) -> T: - ... - - -_ignore_op_re = re.compile(r"_.*|paren").fullmatch - - -def list_ops(cls: type[Any]): - return OrderedSet([x for x in dir(cls) if not _ignore_op_re(x)]) - - -OP_NAMES = list_ops(OpsHandler) - - -def _return_none(*args, **kwargs): - return None - class NoopHandler: - name = "NoopHandler" + def __getattr__(self, name): + if name == "name": + return "NoopHandler" - def __getattr__(self, name: str) -> Callable[..., None]: - return _return_none + def inner(*args, **kwargs): + return None + + return inner @staticmethod def masked(mask, body, other) -> None: @@ -799,89 +771,11 @@ def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]: return h -class BasicMathOps: - @staticmethod - def add(a, b): - return f"{a} + {b}" - - @staticmethod - def sub(a, b): - return f"{a} - {b}" - - @staticmethod - def mul(a, b): - return f"{a} * {b}" - - @staticmethod - def floordiv(a, b): - return f"{a} // {b}" - - @staticmethod - def truediv(a, b): - return f"{a} / {b}" - - @staticmethod - def mod(a, b): - # careful, depending on target semantics varies - return f"{a} % {b}" - - @staticmethod - def pow(a, b): - return f"{a} ** {b}" - - @staticmethod - def lshift(a, b): - return f"{a} << {b}" - - @staticmethod - def rshift(a, b): - return f"{a} >> {b}" - - @staticmethod - def and_(a, b): - return f"{a} & {b}" - - @staticmethod - def or_(a, b): - return f"{a} | {b}" - - @staticmethod - def xor(a, b): - return f"{a} ^ {b}" - - @staticmethod - def eq(a, b): - return f"{a} == {b}" - - @staticmethod - def ne(a, b): - return f"{a} != {b}" - - @staticmethod - def lt(a, b): - return f"{a} < {b}" - - @staticmethod - def gt(a, b): - return f"{a} > {b}" - - @staticmethod - def le(a, b): - return f"{a} <= {b}" - - @staticmethod - def ge(a, b): - return f"{a} >= {b}" - - @staticmethod - def neg(a): - return f"-{a}" - - -class MockHandler(BasicMathOps): - name = "MockHandler" - +class MockHandler: def __getattr__(self, name): + if name == "name": + return "MockHandler" + def inner(*args, **kwargs): fargs = [_arg_str(a) for a in args] fargs.extend(f"{k}={v}" for k, v in kwargs.items()) @@ -915,6 +809,41 @@ class MockHandler(BasicMathOps): def indirect_indexing(index_var, size, check=True, wrap_neg=True) -> sympy.Symbol: return sympy_index_symbol(str(index_var)) + @classmethod + def _init_cls(cls): + def make_handler(format_string): + @staticmethod # type: ignore[misc] + def inner(*args): + return format_string.format(*args) + + return inner + + for name, format_string in { + "add": "{} + {}", + "sub": "{} - {}", + "mul": "{} * {}", + "floordiv": "{} // {}", + "truediv": "{} / {}", + "mod": "{} % {}", # careful, depending on target semantics varies + "pow": "{} ** {}", + "lshift": "{} << {}", + "rshift": "{} >> {}", + "and_": "{} & {}", + "or_": "{} | {}", + "xor": "{} ^ {}", + "eq": "{} == {}", + "ne": "{} != {}", + "lt": "{} < {}", + "gt": "{} > {}", + "le": "{} <= {}", + "ge": "{} >= {}", + "neg": "-{}", + }.items(): + setattr(cls, name, make_handler(format_string)) + + +MockHandler._init_cls() + # Use mypy to check protocol implemented correctly def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: @@ -994,7 +923,7 @@ def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[s class WrapperHandler(Generic[T]): - def __init__(self, inner: Any): + def __init__(self, inner: OpsHandler[T]): self._inner = inner def __getattr__(self, item):