[Easy] factor out inductor ophandler decompositions (#142400)

Factor out inductor operator decompositions

Pull Request resolved: https://github.com/pytorch/pytorch/pull/142400
Approved by: https://github.com/Chillee, https://github.com/jansel
This commit is contained in:
eellison 2024-12-11 16:36:48 -08:00 committed by PyTorch MergeBot
parent c170248b78
commit 0b75b7ff2b
2 changed files with 80 additions and 75 deletions

View file

@ -605,51 +605,16 @@ class PythonPrinter(_PythonPrinter):
return super().doprint(expr)
class OpOverrides:
def __init__(self, parent):
super().__init__()
self._parent = parent
@staticmethod
def paren(string: str) -> str:
def all_in_parens(string: str) -> bool:
if string[0] != "(" or len(string) < 2:
return False
count = 1
for i, char in enumerate(string[1:]):
if char == "(":
count += 1
elif char == ")":
count -= 1
if count == 0 and i != len(string) - 2:
return False
assert count == 0
return True
if (
isinstance(string, CSEVariable)
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
or string == ""
):
return string
# don't put extra parens for strings that are already wrapped in parens
if all_in_parens(string):
return string
return f"({string})"
def __getattr__(self, item):
return getattr(self._parent, item)
class OpDecompositions:
"""
Decomposes inductor ops
"""
@staticmethod
def identity(value):
# used to trigger cse
return value
@staticmethod
def constant(value, dtype):
return repr(value)
@staticmethod
def reciprocal(x):
return ops.truediv(ops.constant(1, torch.int32), x)
@ -691,15 +656,86 @@ class OpOverrides:
one = ops.constant(1, torch.int32)
return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
@staticmethod
def relu(x):
return ops.maximum(x, ops.constant(0, torch.int32))
@staticmethod
def fma(x, y, z):
# for backends that don't override this (halide)
return ops.add(ops.mul(x, y), z)
@staticmethod
def floor_to_int(a, dtype):
return ops.to_dtype(ops.floor(a), dtype)
@staticmethod
def ceil_to_int(a, dtype):
return ops.to_dtype(ops.ceil(a), dtype)
@staticmethod
def trunc_to_int(a, dtype):
return ops.to_dtype(ops.trunc(a), dtype)
@staticmethod
def remainder(a, b):
r = ops.mod(a, b)
cond = ops.and_(
ops.ne(r, ops.constant(0, torch.int32)),
ops.ne(ops.signbit(r), ops.signbit(b)),
)
return ops.where(cond, ops.add(r, b), r)
@staticmethod
def round_to_int(a, dtype):
return ops.to_dtype(ops.round(a), dtype)
class OpOverrides(OpDecompositions):
def __init__(self, parent):
super().__init__()
self._parent = parent
@staticmethod
def paren(string: str) -> str:
def all_in_parens(string: str) -> bool:
if string[0] != "(" or len(string) < 2:
return False
count = 1
for i, char in enumerate(string[1:]):
if char == "(":
count += 1
elif char == ")":
count -= 1
if count == 0 and i != len(string) - 2:
return False
assert count == 0
return True
if (
isinstance(string, CSEVariable)
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
or string == ""
):
return string
# don't put extra parens for strings that are already wrapped in parens
if all_in_parens(string):
return string
return f"({string})"
def __getattr__(self, item):
return getattr(self._parent, item)
@staticmethod
def constant(value, dtype):
return repr(value)
@staticmethod
def libdevice_sigmoid(x):
one = ops.constant(1, torch.int32)
return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))
@staticmethod
def relu(x):
return ops.maximum(x, ops.constant(0, torch.int32))
@staticmethod
def libdevice_abs(x):
return ops.abs(x)
@ -752,36 +788,6 @@ class OpOverrides:
def bitwise_right_shift(x, y):
return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
@staticmethod
def remainder(a, b):
r = ops.mod(a, b)
cond = ops.and_(
ops.ne(r, ops.constant(0, torch.int32)),
ops.ne(ops.signbit(r), ops.signbit(b)),
)
return ops.where(cond, ops.add(r, b), r)
@staticmethod
def fma(x, y, z):
# for backends that don't override this (halide)
return ops.add(ops.mul(x, y), z)
@staticmethod
def trunc_to_int(a, dtype):
return ops.to_dtype(ops.trunc(a), dtype)
@staticmethod
def floor_to_int(a, dtype):
return ops.to_dtype(ops.floor(a), dtype)
@staticmethod
def ceil_to_int(a, dtype):
return ops.to_dtype(ops.ceil(a), dtype)
@staticmethod
def round_to_int(a, dtype):
return ops.to_dtype(ops.round(a), dtype)
@staticmethod
def int_truediv(a, b):
# TODO: this is wrong

View file

@ -51,8 +51,7 @@ def _arg_str(a) -> str:
# implementations make heavy use of __getattr__ magic, and pre-existing
# stubs for methods would interfere with this mechanism.
#
# TODO: A superclass that does desugaring for operations like
# reciprocal/square might be useful.
# See OpDecompositions for superclass that desugars operations like reciprocal/square.
class OpsHandler(Protocol[T]):
"""
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,