[inductor] Refactor op handlers part 2

ghstack-source-id: b0e2f58719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146252
This commit is contained in:
Jason Ansel 2025-02-07 13:32:53 -08:00
parent 579b9f2ed9
commit 4e40642e01
7 changed files with 169 additions and 102 deletions

View file

@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions):
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
)
def output(self, x0: OpVarT) -> None:
raise AssertionError(
f"{type(self).__name__}: ops.output should not appear at codegen time"
)
def placeholder(self, index: int) -> OpVarT:
raise AssertionError(
f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
)
@staticmethod
def _unimplemented(name: str) -> Callable[..., OpVarT]:
def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
@ -2570,6 +2580,7 @@ class CSEProxy:
)
# Use mypy to check protocol implemented correctly
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
return h
if TYPE_CHECKING:
class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]):
pass

View file

@ -368,6 +368,16 @@ class DtypePropagationOpsHandler:
) -> None:
return None
def output(self, x: DTypeArg) -> None:
raise AssertionError(
f"{type(self).__name__}: ops.output should not appear here"
)
def placeholder(self, index: int) -> torch.dtype:
raise AssertionError(
f"{type(self).__name__}: ops.placeholder should not appear here"
)
if TYPE_CHECKING:

View file

@ -449,7 +449,7 @@ class LoopBodyBlock:
)
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
self.name = "CaptureIndexing"
name = "CaptureIndexing"
def load(self, name: str, index: sympy.Expr):
index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)

View file

@ -3,7 +3,17 @@ from __future__ import annotations
import itertools
import re
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
import warnings
from typing import (
Any,
Callable,
Literal,
NamedTuple,
Optional,
TYPE_CHECKING,
TypeVar,
Union,
)
from typing_extensions import Protocol
from unittest.mock import patch
@ -752,6 +762,14 @@ class OpsHandler(Protocol[T]):
) -> T:
...
def output(self, x0: T) -> None:
"""This is a fake op used in analysis but not codegen"""
...
def placeholder(self, index: int) -> T:
"""This is a fake op used in analysis but not codegen"""
...
_ignore_op_re = re.compile(r"_.*|paren").fullmatch
@ -763,15 +781,53 @@ def list_ops(cls: type[Any]):
OP_NAMES = list_ops(OpsHandler)
def _return_none(*args, **kwargs):
return None
class DefaultHandler:
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
"""
Default implementation for all ops. Override in a subclass to
provide generic op behavior.
Args:
target: name of the op, see OpHandler.target
args: positional args passed to the op
kwargs: keyword args passed to the op
Returns:
return value of the op
"""
raise NotImplementedError
def __getattr__(self, name: str) -> Any:
def fallback(*args: Any, **kwargs: Any) -> Any:
return self._default(name, args, kwargs)
# would like to remove this function entirely, but it's used in MTIA backend
warnings.warn(f"undefined OpHandler.{name}, please add missing op schema")
return fallback
@staticmethod
def _call_default(target: str):
def call_default(self, *args, **kwargs):
return self._default(target, args, kwargs)
call_default.__name__ = target
return call_default
@classmethod
def _init_cls(cls):
for target in OP_NAMES:
setattr(cls, target, cls._call_default(target))
class NoopHandler:
DefaultHandler._init_cls()
class NoopHandler(DefaultHandler):
name = "NoopHandler"
def __getattr__(self, name: str) -> Callable[..., None]:
return _return_none
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
return None
@staticmethod
def masked(mask, body, other) -> None:
@ -794,9 +850,10 @@ class NoopHandler:
return sympy.S.Zero
# Use mypy to check protocol implemented correctly
def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
return h
if TYPE_CHECKING:
class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]):
pass # mypy will error if we got any of the signatures wrong
class BasicMathOps:
@ -878,16 +935,14 @@ class BasicMathOps:
return f"-{a}"
class MockHandler(BasicMathOps):
class MockHandler(BasicMathOps, DefaultHandler):
name = "MockHandler"
def __getattr__(self, name):
def inner(*args, **kwargs):
fargs = [_arg_str(a) for a in args]
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
return f"ops.{name}({', '.join(fargs)})"
return inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
fargs = [*map(_arg_str, args)]
for k, v in kwargs.items():
fargs.append(f"{k}={_arg_str(v)}")
return f"ops.{name}({', '.join(fargs)})"
@staticmethod
def masked(mask, body, other) -> str:
@ -916,15 +971,16 @@ class MockHandler(BasicMathOps):
return sympy_index_symbol(str(index_var))
# Use mypy to check protocol implemented correctly
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
return h
if TYPE_CHECKING:
class _typecheck_MockHandler(MockHandler, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class KernelFormatterHandler:
class KernelFormatterHandler(DefaultHandler):
def __init__(self, parent_handler):
self.parent_handler = parent_handler
self.output = IndentedBuffer(1)
self._output = IndentedBuffer(1)
self.var_counter = itertools.count()
@staticmethod
@ -936,8 +992,8 @@ class KernelFormatterHandler:
names = ["index", "rindex"] if rindex is not None else ["index"]
formatter = KernelFormatterHandler(MockHandler())
with formatter.output.indent(-1):
formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
with formatter._output.indent(-1):
formatter._output.writeline(f"def inner_fn({', '.join(names)}):")
for name, arg in zip(names, args):
if arg:
lhs = ", ".join(
@ -946,7 +1002,7 @@ class KernelFormatterHandler:
for v in arg
]
)
formatter.output.writeline(f"{lhs} = {name}")
formatter._output.writeline(f"{lhs} = {name}")
with V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
@ -954,21 +1010,19 @@ class KernelFormatterHandler:
result = ir_fn(*args)
return formatter.getvalue(result)
def __getattr__(self, name) -> Callable[..., Any]:
def inner(*args, **kwargs):
line = getattr(self.parent_handler, name)(*args, **kwargs)
if name == "indirect_indexing":
return line
def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol:
return self.parent_handler.indirect_indexing(*args, **kwargs)
def write(line):
# replace line with a new variable name
varname = f"tmp{next(self.var_counter)}"
self.output.writeline(f"{varname} = {line}")
return varname
def _write(self, line):
# replace line with a new variable name
varname = f"tmp{next(self.var_counter)}"
self._output.writeline(f"{varname} = {line}")
return varname
return pytree.tree_map(write, line)
return inner
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
return pytree.tree_map(
self._write, getattr(self.parent_handler, name)(*args, **kwargs)
)
def reduction(
self,
@ -980,44 +1034,34 @@ class KernelFormatterHandler:
line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
num_values = reduction_num_outputs(reduction_type)
varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
self.output.writeline(f"{','.join(varnames)} = {line}")
self._output.writeline(f"{','.join(varnames)} = {line}")
return tuple(varnames) if num_values > 1 else varnames[0]
def getvalue(self, result):
self.output.writeline(f"return {result}")
return self.output.getvalue()
self._output.writeline(f"return {result}")
return self._output.getvalue()
# Use mypy to check protocol implemented correctly
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
return h
if TYPE_CHECKING:
class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class WrapperHandler(Generic[T]):
class WrapperHandler(DefaultHandler):
def __init__(self, inner: Any):
self._inner = inner
def __getattr__(self, item):
return getattr(self._inner, item)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
return getattr(self._inner, name)(*args, **kwargs)
# Use mypy to check protocol implemented correctly
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
return h
class AddParenHandler(WrapperHandler[T]):
def __getattr__(self, name):
def inner(*args, **kwargs):
val = getattr(self._inner, name)(*args, **kwargs)
return f"({val})"
return inner
# Use mypy to check protocol implemented correctly
def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
return h
class AddParenHandler(WrapperHandler):
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
val = getattr(self._inner, name)(*args, **kwargs)
if not val or isinstance(val, (sympy.Expr, tuple, list)):
return val
return f"({val})"
class OpCountResult(NamedTuple):
@ -1027,7 +1071,7 @@ class OpCountResult(NamedTuple):
nontrivial_read_count: int
class OpCounterCSE:
class OpCounterCSE(DefaultHandler):
"""Shim to count how many ops are used"""
def __init__(self, inner):
@ -1035,18 +1079,15 @@ class OpCounterCSE:
self.parent_handler = inner
self.op_count = 0
self.var_names = {}
self._used_ops = OrderedSet[str]()
self._used_ops: OrderedSet[str] = OrderedSet()
self._read_names: list[str] = []
self._nontrivial_read_count = 0
def __getattr__(self, name):
def inner(*args, **kwargs):
return pytree.tree_map(
self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
)
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
self._used_ops.add(name)
return inner
return pytree.tree_map(
self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
)
def _update_count(self, val):
varname = self.var_names.get(val)
@ -1111,8 +1152,10 @@ class OpCounterCSE:
)
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
return h
if TYPE_CHECKING:
class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]):
pass # mypy will error if we got any of the signatures wrong
class ExtractConstantsHandler(NoopHandler):
@ -1125,44 +1168,45 @@ class ExtractConstantsHandler(NoopHandler):
return ir.Constant(value=value, dtype=dtype, device=self.device)
def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]:
return h
if TYPE_CHECKING:
class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong
class SimpleCSEHandler(WrapperHandler[T]):
class SimpleCSEHandler(WrapperHandler):
"""Wraps the underlying handler with a CSE pass
NOTE: Compared to codegen level CSE this is simplified as it
doesn't support stores which require load cache invalidation.
"""
def __init__(self, inner: OpsHandler[T]):
def __init__(self, inner: Any):
super().__init__(inner)
self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {}
self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {}
self.mock = MockHandler()
def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:
return super().indirect_indexing(*args, **kwargs) # type: ignore[misc]
def store(self, *args, **kwargs) -> T:
def store(self, *args, **kwargs) -> None:
raise NotImplementedError("store not implemented")
def store_reduction(self, *args, **kwargs) -> T:
def store_reduction(self, *args, **kwargs) -> None:
raise NotImplementedError("store not implemented")
def __getattr__(self, name) -> Callable[..., Any]:
def inner(*args, **kwargs):
key = getattr(self.mock, name)(*args, **kwargs)
val = self.cse_cache.get(key)
if val is not None:
return val
val = getattr(self._inner, name)(*args, **kwargs)
self.cse_cache[key] = val
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
key = getattr(self.mock, name)(*args, **kwargs)
val = self.cse_cache.get(key)
if val is not None:
return val
return inner
val = getattr(self._inner, name)(*args, **kwargs)
self.cse_cache[key] = val
return val
def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]:
return h
if TYPE_CHECKING:
class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]):
pass # mypy will error if we got any of the signatures wrong

View file

@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel):
template_mask = self.template_mask
class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined]
self.name = name
name = "StoreOutputSubstitution"
def store(
self,

View file

@ -135,7 +135,7 @@ class InputDescriptor:
device: torch.device
class TracingOpsHandler(WrapperHandler[T]):
class TracingOpsHandler(WrapperHandler):
def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
parent = tracer.create_proxy("placeholder", "ops", (), {})
super().__init__(parent)

View file

@ -59,7 +59,7 @@ from __future__ import annotations
from contextlib import AbstractContextManager, contextmanager
from threading import local
from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
from torch.utils._ordered_set import OrderedSet
@ -154,7 +154,9 @@ class NullKernelHandler(NullHandler):
self.index_dtype = "tl.int64"
_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
_ops: Virtualized[OpsHandler[Any]] = Virtualized(
"ops", cast(type[OpsHandler[Any]], MockHandler)
)
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)