mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[inductor] Refactor op handlers part 2
ghstack-source-id: b0e2f58719
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146252
This commit is contained in:
parent
579b9f2ed9
commit
4e40642e01
7 changed files with 169 additions and 102 deletions
|
|
@ -948,6 +948,16 @@ class OpOverrides(BasicMathOps, OpDecompositions):
|
|||
f"{type(self).__name__}: inline_asm_elementwise only implemented for Triton backend"
|
||||
)
|
||||
|
||||
def output(self, x0: OpVarT) -> None:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.output should not appear at codegen time"
|
||||
)
|
||||
|
||||
def placeholder(self, index: int) -> OpVarT:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.placeholder should not appear at codegen time"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _unimplemented(name: str) -> Callable[..., OpVarT]:
|
||||
def unimplemented(self: OpOverrides, *args: Any, **kwargs: Any) -> OpVarT:
|
||||
|
|
@ -2570,6 +2580,7 @@ class CSEProxy:
|
|||
)
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_CSEProxy(h: CSEProxy) -> OpsHandler[CSEVariable]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_CSEProxy(CSEProxy, OpsHandler[CSEVariable]):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -368,6 +368,16 @@ class DtypePropagationOpsHandler:
|
|||
) -> None:
|
||||
return None
|
||||
|
||||
def output(self, x: DTypeArg) -> None:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.output should not appear here"
|
||||
)
|
||||
|
||||
def placeholder(self, index: int) -> torch.dtype:
|
||||
raise AssertionError(
|
||||
f"{type(self).__name__}: ops.placeholder should not appear here"
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
|
|
|
|||
|
|
@ -449,7 +449,7 @@ class LoopBodyBlock:
|
|||
)
|
||||
|
||||
class CaptureIndexing(V.WrapperHandler): # type: ignore[name-defined]
|
||||
self.name = "CaptureIndexing"
|
||||
name = "CaptureIndexing"
|
||||
|
||||
def load(self, name: str, index: sympy.Expr):
|
||||
index = add_index(index, MemoryUsageType.LOAD, buffer_name=name)
|
||||
|
|
|
|||
|
|
@ -3,7 +3,17 @@ from __future__ import annotations
|
|||
|
||||
import itertools
|
||||
import re
|
||||
from typing import Any, Callable, Generic, Literal, NamedTuple, Optional, TypeVar, Union
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Literal,
|
||||
NamedTuple,
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from typing_extensions import Protocol
|
||||
from unittest.mock import patch
|
||||
|
||||
|
|
@ -752,6 +762,14 @@ class OpsHandler(Protocol[T]):
|
|||
) -> T:
|
||||
...
|
||||
|
||||
def output(self, x0: T) -> None:
|
||||
"""This is a fake op used in analysis but not codegen"""
|
||||
...
|
||||
|
||||
def placeholder(self, index: int) -> T:
|
||||
"""This is a fake op used in analysis but not codegen"""
|
||||
...
|
||||
|
||||
|
||||
_ignore_op_re = re.compile(r"_.*|paren").fullmatch
|
||||
|
||||
|
|
@ -763,15 +781,53 @@ def list_ops(cls: type[Any]):
|
|||
OP_NAMES = list_ops(OpsHandler)
|
||||
|
||||
|
||||
def _return_none(*args, **kwargs):
|
||||
return None
|
||||
class DefaultHandler:
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
"""
|
||||
Default implementation for all ops. Override in a subclass to
|
||||
provide generic op behavior.
|
||||
|
||||
Args:
|
||||
target: name of the op, see OpHandler.target
|
||||
args: positional args passed to the op
|
||||
kwargs: keyword args passed to the op
|
||||
|
||||
Returns:
|
||||
return value of the op
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def __getattr__(self, name: str) -> Any:
|
||||
def fallback(*args: Any, **kwargs: Any) -> Any:
|
||||
return self._default(name, args, kwargs)
|
||||
|
||||
# would like to remove this function entirely, but it's used in MTIA backend
|
||||
warnings.warn(f"undefined OpHandler.{name}, please add missing op schema")
|
||||
return fallback
|
||||
|
||||
@staticmethod
|
||||
def _call_default(target: str):
|
||||
def call_default(self, *args, **kwargs):
|
||||
return self._default(target, args, kwargs)
|
||||
|
||||
call_default.__name__ = target
|
||||
return call_default
|
||||
|
||||
@classmethod
|
||||
def _init_cls(cls):
|
||||
for target in OP_NAMES:
|
||||
setattr(cls, target, cls._call_default(target))
|
||||
|
||||
|
||||
class NoopHandler:
|
||||
DefaultHandler._init_cls()
|
||||
|
||||
|
||||
class NoopHandler(DefaultHandler):
|
||||
name = "NoopHandler"
|
||||
|
||||
def __getattr__(self, name: str) -> Callable[..., None]:
|
||||
return _return_none
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def masked(mask, body, other) -> None:
|
||||
|
|
@ -794,9 +850,10 @@ class NoopHandler:
|
|||
return sympy.S.Zero
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_NoopHandler(h: NoopHandler) -> OpsHandler[None]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_NoopHandler(NoopHandler, OpsHandler[None]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class BasicMathOps:
|
||||
|
|
@ -878,16 +935,14 @@ class BasicMathOps:
|
|||
return f"-{a}"
|
||||
|
||||
|
||||
class MockHandler(BasicMathOps):
|
||||
class MockHandler(BasicMathOps, DefaultHandler):
|
||||
name = "MockHandler"
|
||||
|
||||
def __getattr__(self, name):
|
||||
def inner(*args, **kwargs):
|
||||
fargs = [_arg_str(a) for a in args]
|
||||
fargs.extend(f"{k}={v}" for k, v in kwargs.items())
|
||||
return f"ops.{name}({', '.join(fargs)})"
|
||||
|
||||
return inner
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
fargs = [*map(_arg_str, args)]
|
||||
for k, v in kwargs.items():
|
||||
fargs.append(f"{k}={_arg_str(v)}")
|
||||
return f"ops.{name}({', '.join(fargs)})"
|
||||
|
||||
@staticmethod
|
||||
def masked(mask, body, other) -> str:
|
||||
|
|
@ -916,15 +971,16 @@ class MockHandler(BasicMathOps):
|
|||
return sympy_index_symbol(str(index_var))
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_MockHandler(h: MockHandler) -> OpsHandler[str]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_MockHandler(MockHandler, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class KernelFormatterHandler:
|
||||
class KernelFormatterHandler(DefaultHandler):
|
||||
def __init__(self, parent_handler):
|
||||
self.parent_handler = parent_handler
|
||||
self.output = IndentedBuffer(1)
|
||||
self._output = IndentedBuffer(1)
|
||||
self.var_counter = itertools.count()
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -936,8 +992,8 @@ class KernelFormatterHandler:
|
|||
names = ["index", "rindex"] if rindex is not None else ["index"]
|
||||
formatter = KernelFormatterHandler(MockHandler())
|
||||
|
||||
with formatter.output.indent(-1):
|
||||
formatter.output.writeline(f"def inner_fn({', '.join(names)}):")
|
||||
with formatter._output.indent(-1):
|
||||
formatter._output.writeline(f"def inner_fn({', '.join(names)}):")
|
||||
for name, arg in zip(names, args):
|
||||
if arg:
|
||||
lhs = ", ".join(
|
||||
|
|
@ -946,7 +1002,7 @@ class KernelFormatterHandler:
|
|||
for v in arg
|
||||
]
|
||||
)
|
||||
formatter.output.writeline(f"{lhs} = {name}")
|
||||
formatter._output.writeline(f"{lhs} = {name}")
|
||||
|
||||
with V.set_ops_handler(formatter), patch.object(
|
||||
FlexibleLayout, "allow_indexing", True
|
||||
|
|
@ -954,21 +1010,19 @@ class KernelFormatterHandler:
|
|||
result = ir_fn(*args)
|
||||
return formatter.getvalue(result)
|
||||
|
||||
def __getattr__(self, name) -> Callable[..., Any]:
|
||||
def inner(*args, **kwargs):
|
||||
line = getattr(self.parent_handler, name)(*args, **kwargs)
|
||||
if name == "indirect_indexing":
|
||||
return line
|
||||
def indirect_indexing(self, *args, **kwargs) -> sympy.Symbol:
|
||||
return self.parent_handler.indirect_indexing(*args, **kwargs)
|
||||
|
||||
def write(line):
|
||||
# replace line with a new variable name
|
||||
varname = f"tmp{next(self.var_counter)}"
|
||||
self.output.writeline(f"{varname} = {line}")
|
||||
return varname
|
||||
def _write(self, line):
|
||||
# replace line with a new variable name
|
||||
varname = f"tmp{next(self.var_counter)}"
|
||||
self._output.writeline(f"{varname} = {line}")
|
||||
return varname
|
||||
|
||||
return pytree.tree_map(write, line)
|
||||
|
||||
return inner
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
return pytree.tree_map(
|
||||
self._write, getattr(self.parent_handler, name)(*args, **kwargs)
|
||||
)
|
||||
|
||||
def reduction(
|
||||
self,
|
||||
|
|
@ -980,44 +1034,34 @@ class KernelFormatterHandler:
|
|||
line = self.parent_handler.reduction(dtype, src_dtype, reduction_type, value)
|
||||
num_values = reduction_num_outputs(reduction_type)
|
||||
varnames = [f"tmp{next(self.var_counter)}" for _ in range(num_values)]
|
||||
self.output.writeline(f"{','.join(varnames)} = {line}")
|
||||
self._output.writeline(f"{','.join(varnames)} = {line}")
|
||||
return tuple(varnames) if num_values > 1 else varnames[0]
|
||||
|
||||
def getvalue(self, result):
|
||||
self.output.writeline(f"return {result}")
|
||||
return self.output.getvalue()
|
||||
self._output.writeline(f"return {result}")
|
||||
return self._output.getvalue()
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_KernelFormatterHandler(h: KernelFormatterHandler) -> OpsHandler[str]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_KernelFormatterHandler(KernelFormatterHandler, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class WrapperHandler(Generic[T]):
|
||||
class WrapperHandler(DefaultHandler):
|
||||
def __init__(self, inner: Any):
|
||||
self._inner = inner
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._inner, item)
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
return getattr(self._inner, name)(*args, **kwargs)
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_WrapperHandler(h: WrapperHandler[T]) -> OpsHandler[T]:
|
||||
return h
|
||||
|
||||
|
||||
class AddParenHandler(WrapperHandler[T]):
|
||||
def __getattr__(self, name):
|
||||
def inner(*args, **kwargs):
|
||||
val = getattr(self._inner, name)(*args, **kwargs)
|
||||
return f"({val})"
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
# Use mypy to check protocol implemented correctly
|
||||
def _typecheck_AddParenHandler(h: AddParenHandler[T]) -> OpsHandler[T]:
|
||||
return h
|
||||
class AddParenHandler(WrapperHandler):
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
val = getattr(self._inner, name)(*args, **kwargs)
|
||||
if not val or isinstance(val, (sympy.Expr, tuple, list)):
|
||||
return val
|
||||
return f"({val})"
|
||||
|
||||
|
||||
class OpCountResult(NamedTuple):
|
||||
|
|
@ -1027,7 +1071,7 @@ class OpCountResult(NamedTuple):
|
|||
nontrivial_read_count: int
|
||||
|
||||
|
||||
class OpCounterCSE:
|
||||
class OpCounterCSE(DefaultHandler):
|
||||
"""Shim to count how many ops are used"""
|
||||
|
||||
def __init__(self, inner):
|
||||
|
|
@ -1035,18 +1079,15 @@ class OpCounterCSE:
|
|||
self.parent_handler = inner
|
||||
self.op_count = 0
|
||||
self.var_names = {}
|
||||
self._used_ops = OrderedSet[str]()
|
||||
self._used_ops: OrderedSet[str] = OrderedSet()
|
||||
self._read_names: list[str] = []
|
||||
self._nontrivial_read_count = 0
|
||||
|
||||
def __getattr__(self, name):
|
||||
def inner(*args, **kwargs):
|
||||
return pytree.tree_map(
|
||||
self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
|
||||
)
|
||||
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
self._used_ops.add(name)
|
||||
return inner
|
||||
return pytree.tree_map(
|
||||
self._update_count, getattr(self.parent_handler, name)(*args, **kwargs)
|
||||
)
|
||||
|
||||
def _update_count(self, val):
|
||||
varname = self.var_names.get(val)
|
||||
|
|
@ -1111,8 +1152,10 @@ class OpCounterCSE:
|
|||
)
|
||||
|
||||
|
||||
def _typecheck_OpCounterCSE(h: OpCounterCSE) -> OpsHandler[str]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_OpCounterCSE(OpCounterCSE, OpsHandler[str]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class ExtractConstantsHandler(NoopHandler):
|
||||
|
|
@ -1125,44 +1168,45 @@ class ExtractConstantsHandler(NoopHandler):
|
|||
return ir.Constant(value=value, dtype=dtype, device=self.device)
|
||||
|
||||
|
||||
def _typecheck_ExtractConstantsHandler(h: ExtractConstantsHandler) -> OpsHandler[Any]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_ExtractConstantsHandler(ExtractConstantsHandler, OpsHandler[Any]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
||||
|
||||
class SimpleCSEHandler(WrapperHandler[T]):
|
||||
class SimpleCSEHandler(WrapperHandler):
|
||||
"""Wraps the underlying handler with a CSE pass
|
||||
|
||||
NOTE: Compared to codegen level CSE this is simplified as it
|
||||
doesn't support stores which require load cache invalidation.
|
||||
"""
|
||||
|
||||
def __init__(self, inner: OpsHandler[T]):
|
||||
def __init__(self, inner: Any):
|
||||
super().__init__(inner)
|
||||
self.cse_cache: dict[str, Union[T, tuple[T, ...]]] = {}
|
||||
self.cse_cache: dict[str, Union[Any, tuple[Any, ...]]] = {}
|
||||
self.mock = MockHandler()
|
||||
|
||||
def indirect_indexing(self, *args, **kwargs) -> sympy.Expr:
|
||||
return super().indirect_indexing(*args, **kwargs) # type: ignore[misc]
|
||||
|
||||
def store(self, *args, **kwargs) -> T:
|
||||
def store(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError("store not implemented")
|
||||
|
||||
def store_reduction(self, *args, **kwargs) -> T:
|
||||
def store_reduction(self, *args, **kwargs) -> None:
|
||||
raise NotImplementedError("store not implemented")
|
||||
|
||||
def __getattr__(self, name) -> Callable[..., Any]:
|
||||
def inner(*args, **kwargs):
|
||||
key = getattr(self.mock, name)(*args, **kwargs)
|
||||
val = self.cse_cache.get(key)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
val = getattr(self._inner, name)(*args, **kwargs)
|
||||
self.cse_cache[key] = val
|
||||
def _default(self, name: str, args: tuple[Any, ...], kwargs: dict[str, Any]) -> Any:
|
||||
key = getattr(self.mock, name)(*args, **kwargs)
|
||||
val = self.cse_cache.get(key)
|
||||
if val is not None:
|
||||
return val
|
||||
|
||||
return inner
|
||||
val = getattr(self._inner, name)(*args, **kwargs)
|
||||
self.cse_cache[key] = val
|
||||
return val
|
||||
|
||||
|
||||
def _typecheck_SimpleCSEHandler(h: SimpleCSEHandler[Any]) -> OpsHandler[Any]:
|
||||
return h
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class _typecheck_SimpleCSEHandler(SimpleCSEHandler, OpsHandler[Any]):
|
||||
pass # mypy will error if we got any of the signatures wrong
|
||||
|
|
|
|||
|
|
@ -742,7 +742,7 @@ class TritonTemplateKernel(TritonKernel):
|
|||
template_mask = self.template_mask
|
||||
|
||||
class StoreOutputSubstitution(V.WrapperHandler): # type: ignore[name-defined]
|
||||
self.name = name
|
||||
name = "StoreOutputSubstitution"
|
||||
|
||||
def store(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -135,7 +135,7 @@ class InputDescriptor:
|
|||
device: torch.device
|
||||
|
||||
|
||||
class TracingOpsHandler(WrapperHandler[T]):
|
||||
class TracingOpsHandler(WrapperHandler):
|
||||
def __init__(self, tracer: torch.fx.Tracer, num_inputs: int) -> None:
|
||||
parent = tracer.create_proxy("placeholder", "ops", (), {})
|
||||
super().__init__(parent)
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ from __future__ import annotations
|
|||
|
||||
from contextlib import AbstractContextManager, contextmanager
|
||||
from threading import local
|
||||
from typing import Any, Callable, Generic, TYPE_CHECKING, TypeVar, Union
|
||||
from typing import Any, Callable, cast, Generic, TYPE_CHECKING, TypeVar, Union
|
||||
|
||||
from torch.utils._ordered_set import OrderedSet
|
||||
|
||||
|
|
@ -154,7 +154,9 @@ class NullKernelHandler(NullHandler):
|
|||
self.index_dtype = "tl.int64"
|
||||
|
||||
|
||||
_ops: Virtualized[OpsHandler[Any]] = Virtualized("ops", MockHandler)
|
||||
_ops: Virtualized[OpsHandler[Any]] = Virtualized(
|
||||
"ops", cast(type[OpsHandler[Any]], MockHandler)
|
||||
)
|
||||
_graph: Virtualized[GraphLowering] = Virtualized("graph", NullHandler)
|
||||
_real_inputs: Virtualized[list[torch.Tensor]] = Virtualized("real_inputs", NullHandler)
|
||||
_fake_mode: Virtualized[FakeTensorMode] = Virtualized("fake_mode", NullHandler)
|
||||
|
|
|
|||
Loading…
Reference in a new issue