diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index c254aacdb17..6e45c02b68e 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions): f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend" ) + def output(self, x0: OpVarT) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear at codegen time" + ) + + def placeholder(self, index: int) -> OpVarT: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear at codegen time" + ) + @staticmethod def _unimplemented(name: str) -> Callable[..., OpVarT]: def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT: @@ -2570,6 +2580,7 @@ class CSEProxy: ) -# Use mypy to check protocol implemented correctly -def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]: - return h +if TYPE_CHECKING: + + class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]): + pass diff --git a/torch/_inductor/dtype_propagation.py b/torch/_inductor/dtype_propagation.py index efe0ebe2caf..5b45943b940 100644 --- a/torch/_inductor/dtype_propagation.py +++ b/torch/_inductor/dtype_propagation.py @@ -368,6 +368,16 @@ class DtypePropagationOpsHandler: ) -> None: return None + def output(self, x: DTypeArg) -> None: + raise AssertionError( + f"{type(self).__name__}: ops.output should not appear here" + ) + + def placeholder(self, index: int) -> torch.dtype: + raise AssertionError( + f"{type(self).__name__}: ops.placeholder should not appear here" + ) + if TYPE_CHECKING: diff --git a/torch/_inductor/loop_body.py b/torch/_inductor/loop_body.py index 21f63a11b67..aab9d318762 100644 --- a/torch/_inductor/loop_body.py +++ b/torch/_inductor/loop_body.py @@ -449,7 +449,7 @@ class LoopBodyBlock: ) class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined] - self.name = "CaptureIndexing" + name = "CaptureIndexing" def load(self, name: str, index: sympy.Expr): index = add_index(index, MemoryUsageType.LOAD, buffer_name=name) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 935c5f6fc36..22ce7154c6b 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -3,7 +3,17 @@ from __future__ import annotations import itertools import re -from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union +import warnings +from typing import ( + Any, + Callable, + Literal, + NamedTuple, + Optional, + TYPE_CHECKING, + TypeVar, + Union, +) from typing_extensions import Protocol from unittest.mock import patch @@ -752,6 +762,14 @@ class OpsHandler(Protocol[T]): ) -> T: ... + def output(self, x0: T) -> None: + """This is a fake op used in analysis but not codegen""" + ... + + def placeholder(self, index: int) -> T: + """This is a fake op used in analysis but not codegen""" + ... + _ignore_op_re = re.compile(r"_.*|paren").fullmatch @@ -763,15 +781,53 @@ def list_ops(cls: type[Any]): OP_NAMES = list_ops(OpsHandler) -def _return_none(*args, **kwargs): - return None +class DefaultHandler: + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + """ + Default implementation for all ops. Override in a subclass to + provide generic op behavior. + + Args: + target: name of the op, see OpHandler.target + args: positional args passed to the op + kwargs: keyword args passed to the op + + Returns: + return value of the op + + """ + raise NotImplementedError + + def __getattr__(self, name: str) -> Any: + def fallback(*args: Any, **kwargs: Any) -> Any: + return self._default(name, args, kwargs) + + # would like to remove this function entirely, but it's used in MTIA backend + warnings.warn(f"undefined OpHandler.{name}, please add missing op schema") + return fallback + + @staticmethod + def _call_default(target: str): + def call_default(self, *args, **kwargs): + return self._default(target, args, kwargs) + + call_default.__name__ = target + return call_default + + @classmethod + def _init_cls(cls): + for target in OP_NAMES: + setattr(cls, target, cls._call_default(target)) -class NoopHandler: +DefaultHandler._init_cls() + + +class NoopHandler(DefaultHandler): name = "NoopHandler" - def __getattr__(self, name: str) -> Callable[..., None]: - return _return_none + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return None @staticmethod def masked(mask, body, other) -> None: @@ -794,9 +850,10 @@ class NoopHandler: return sympy.S.Zero -# Use mypy to check protocol implemented correctly -def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]: - return h +if TYPE_CHECKING: + + class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]): + pass # mypy will error if we got any of the signatures wrong class BasicMathOps: @@ -878,16 +935,14 @@ class BasicMathOps: return f"-{a}" -class MockHandler(BasicMathOps): +class MockHandler(BasicMathOps, DefaultHandler): name = "MockHandler" - def __getattr__(self, name): - def inner(*args, **kwargs): - fargs = [_arg_str(a) for a in args] - fargs.extend(f"{k}={v}" for k, v in kwargs.items()) - return f"ops.{name}({', '.join(fargs)})" - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + fargs = [*map(_arg_str, args)] + for k, v in kwargs.items(): + fargs.append(f"{k}={_arg_str(v)}") + return f"ops.{name}({', '.join(fargs)})" @staticmethod def masked(mask, body, other) -> str: @@ -916,15 +971,16 @@ class MockHandler(BasicMathOps): return sympy_index_symbol(str(index_var)) -# Use mypy to check protocol implemented correctly -def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]: - return h +if TYPE_CHECKING: + + class _typecheck_MockHandler(MockHandler, OpsHandler[str]): + pass # mypy will error if we got any of the signatures wrong -class KernelFormatterHandler: +class KernelFormatterHandler(DefaultHandler): def __init__(self, parent_handler): self.parent_handler = parent_handler - self.output = IndentedBuffer(1) + self._output = IndentedBuffer(1) self.var_counter = itertools.count() @staticmethod @@ -936,8 +992,8 @@ class KernelFormatterHandler: names = ["index", "rindex"] if rindex is not None else ["index"] formatter = KernelFormatterHandler(MockHandler()) - with formatter.output.indent(-1): - formatter.output.writeline(f"def inner_fn({', '.join(names)}):") + with formatter._output.indent(-1): + formatter._output.writeline(f"def inner_fn({', '.join(names)}):") for name, arg in zip(names, args): if arg: lhs = ", ".join( @@ -946,7 +1002,7 @@ class KernelFormatterHandler: for v in arg ] ) - formatter.output.writeline(f"{lhs} = {name}") + formatter._output.writeline(f"{lhs} = {name}") with V.set_ops_handler(formatter), patch.object( FlexibleLayout, "allow_indexing", True @@ -954,21 +1010,19 @@ class KernelFormatterHandler: result = ir_fn(*args) return formatter.getvalue(result) - def __getattr__(self, name) -> Callable[..., Any]: - def inner(*args, **kwargs): - line = getattr(self.parent_handler, name)(*args, **kwargs) - if name == "indirect_indexing": - return line + def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol: + return self.parent_handler.indirect_indexing(*args, **kwargs) - def write(line): - # replace line with a new variable name - varname = f"tmp{next(self.var_counter)}" - self.output.writeline(f"{varname} = {line}") - return varname + def _write(self, line): + # replace line with a new variable name + varname = f"tmp{next(self.var_counter)}" + self._output.writeline(f"{varname} = {line}") + return varname - return pytree.tree_map(write, line) - - return inner + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return pytree.tree_map( + self._write, getattr(self.parent_handler, name)(*args, **kwargs) + ) def reduction( self, @@ -980,44 +1034,34 @@ class KernelFormatterHandler: line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value) num_values = reduction_num_outputs(reduction_type) varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)] - self.output.writeline(f"{','.join(varnames)} = {line}") + self._output.writeline(f"{','.join(varnames)} = {line}") return tuple(varnames) if num_values > 1 else varnames[0] def getvalue(self, result): - self.output.writeline(f"return {result}") - return self.output.getvalue() + self._output.writeline(f"return {result}") + return self._output.getvalue() -# Use mypy to check protocol implemented correctly -def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]: - return h +if TYPE_CHECKING: + + class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]): + pass # mypy will error if we got any of the signatures wrong -class WrapperHandler(Generic[T]): +class WrapperHandler(DefaultHandler): def __init__(self, inner: Any): self._inner = inner - def __getattr__(self, item): - return getattr(self._inner, item) + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + return getattr(self._inner, name)(*args, **kwargs) -# Use mypy to check protocol implemented correctly -def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]: - return h - - -class AddParenHandler(WrapperHandler[T]): - def __getattr__(self, name): - def inner(*args, **kwargs): - val = getattr(self._inner, name)(*args, **kwargs) - return f"({val})" - - return inner - - -# Use mypy to check protocol implemented correctly -def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]: - return h +class AddParenHandler(WrapperHandler): + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + val = getattr(self._inner, name)(*args, **kwargs) + if not val or isinstance(val, (sympy.Expr, tuple, list)): + return val + return f"({val})" class OpCountResult(NamedTuple): @@ -1027,7 +1071,7 @@ class OpCountResult(NamedTuple): nontrivial_read_count: int -class OpCounterCSE: +class OpCounterCSE(DefaultHandler): """Shim to count how many ops are used""" def __init__(self, inner): @@ -1035,18 +1079,15 @@ class OpCounterCSE: self.parent_handler = inner self.op_count = 0 self.var_names = {} - self._used_ops = OrderedSet[str]() + self._used_ops: OrderedSet[str] = OrderedSet() self._read_names: list[str] = [] self._nontrivial_read_count = 0 - def __getattr__(self, name): - def inner(*args, **kwargs): - return pytree.tree_map( - self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) - ) - + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: self._used_ops.add(name) - return inner + return pytree.tree_map( + self._update_count, getattr(self.parent_handler, name)(*args, **kwargs) + ) def _update_count(self, val): varname = self.var_names.get(val) @@ -1111,8 +1152,10 @@ class OpCounterCSE: ) -def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]: - return h +if TYPE_CHECKING: + + class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]): + pass # mypy will error if we got any of the signatures wrong class ExtractConstantsHandler(NoopHandler): @@ -1125,44 +1168,45 @@ class ExtractConstantsHandler(NoopHandler): return ir.Constant(value=value, dtype=dtype, device=self.device) -def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]: - return h +if TYPE_CHECKING: + + class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]): + pass # mypy will error if we got any of the signatures wrong -class SimpleCSEHandler(WrapperHandler[T]): +class SimpleCSEHandler(WrapperHandler): """Wraps the underlying handler with a CSE pass NOTE: Compared to codegen level CSE this is simplified as it doesn't support stores which require load cache invalidation. """ - def __init__(self, inner: OpsHandler[T]): + def __init__(self, inner: Any): super().__init__(inner) - self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {} + self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {} self.mock = MockHandler() def indirect_indexing(self, *args, **kwargs) -> sympy.Expr: return super().indirect_indexing(*args, **kwargs) # type: ignore[misc] - def store(self, *args, **kwargs) -> T: + def store(self, *args, **kwargs) -> None: raise NotImplementedError("store not implemented") - def store_reduction(self, *args, **kwargs) -> T: + def store_reduction(self, *args, **kwargs) -> None: raise NotImplementedError("store not implemented") - def __getattr__(self, name) -> Callable[..., Any]: - def inner(*args, **kwargs): - key = getattr(self.mock, name)(*args, **kwargs) - val = self.cse_cache.get(key) - if val is not None: - return val - - val = getattr(self._inner, name)(*args, **kwargs) - self.cse_cache[key] = val + def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any: + key = getattr(self.mock, name)(*args, **kwargs) + val = self.cse_cache.get(key) + if val is not None: return val - return inner + val = getattr(self._inner, name)(*args, **kwargs) + self.cse_cache[key] = val + return val -def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]: - return h +if TYPE_CHECKING: + + class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]): + pass # mypy will error if we got any of the signatures wrong diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index bd96e830bcc..d016af99954 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel): template_mask = self.template_mask class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined] - self.name = name + name = "StoreOutputSubstitution" def store( self, diff --git a/torch/_inductor/subgraph_lowering.py b/torch/_inductor/subgraph_lowering.py index ce35959c532..0166534e5fb 100644 --- a/torch/_inductor/subgraph_lowering.py +++ b/torch/_inductor/subgraph_lowering.py @@ -135,7 +135,7 @@ class InputDescriptor: device: torch.device -class TracingOpsHandler(WrapperHandler[T]): +class TracingOpsHandler(WrapperHandler): def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None: parent = tracer.create_proxy("placeholder", "ops", (), {}) super().__init__(parent) diff --git a/torch/_inductor/virtualized.py b/torch/_inductor/virtualized.py index 4de41846166..55ff6ce32b3 100644 --- a/torch/_inductor/virtualized.py +++ b/torch/_inductor/virtualized.py @@ -59,7 +59,7 @@ from __future__ import annotations from contextlib import AbstractContextManager, contextmanager from threading import local -from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union +from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union from torch.utils._ordered_set import OrderedSet @@ -154,7 +154,9 @@ class NullKernelHandler(NullHandler): self.index_dtype = "tl.int64" -_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler) +_ops: Virtualized[OpsHandler[Any]] = Virtualized( + "ops", cast(type[OpsHandler[Any]], MockHandler) +) _graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler) _real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler) _fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)