mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[Easy] factor out inductor ophandler decompositions (#142400)
Factor out inductor operator decompositions Pull Request resolved: https://github.com/pytorch/pytorch/pull/142400 Approved by: https://github.com/Chillee, https://github.com/jansel
This commit is contained in:
parent
c170248b78
commit
0b75b7ff2b
2 changed files with 80 additions and 75 deletions
|
|
@ -605,51 +605,16 @@ class PythonPrinter(_PythonPrinter):
|
|||
return super().doprint(expr)
|
||||
|
||||
|
||||
class OpOverrides:
|
||||
def __init__(self, parent):
|
||||
super().__init__()
|
||||
self._parent = parent
|
||||
|
||||
@staticmethod
|
||||
def paren(string: str) -> str:
|
||||
def all_in_parens(string: str) -> bool:
|
||||
if string[0] != "(" or len(string) < 2:
|
||||
return False
|
||||
count = 1
|
||||
for i, char in enumerate(string[1:]):
|
||||
if char == "(":
|
||||
count += 1
|
||||
elif char == ")":
|
||||
count -= 1
|
||||
if count == 0 and i != len(string) - 2:
|
||||
return False
|
||||
assert count == 0
|
||||
return True
|
||||
|
||||
if (
|
||||
isinstance(string, CSEVariable)
|
||||
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
|
||||
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
|
||||
or string == ""
|
||||
):
|
||||
return string
|
||||
# don't put extra parens for strings that are already wrapped in parens
|
||||
if all_in_parens(string):
|
||||
return string
|
||||
return f"({string})"
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._parent, item)
|
||||
class OpDecompositions:
|
||||
"""
|
||||
Decomposes inductor ops
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def identity(value):
|
||||
# used to trigger cse
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def constant(value, dtype):
|
||||
return repr(value)
|
||||
|
||||
@staticmethod
|
||||
def reciprocal(x):
|
||||
return ops.truediv(ops.constant(1, torch.int32), x)
|
||||
|
|
@ -691,15 +656,86 @@ class OpOverrides:
|
|||
one = ops.constant(1, torch.int32)
|
||||
return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x))))
|
||||
|
||||
@staticmethod
|
||||
def relu(x):
|
||||
return ops.maximum(x, ops.constant(0, torch.int32))
|
||||
|
||||
@staticmethod
|
||||
def fma(x, y, z):
|
||||
# for backends that don't override this (halide)
|
||||
return ops.add(ops.mul(x, y), z)
|
||||
|
||||
@staticmethod
|
||||
def floor_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.floor(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def ceil_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.ceil(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def trunc_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.trunc(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def remainder(a, b):
|
||||
r = ops.mod(a, b)
|
||||
cond = ops.and_(
|
||||
ops.ne(r, ops.constant(0, torch.int32)),
|
||||
ops.ne(ops.signbit(r), ops.signbit(b)),
|
||||
)
|
||||
return ops.where(cond, ops.add(r, b), r)
|
||||
|
||||
@staticmethod
|
||||
def round_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.round(a), dtype)
|
||||
|
||||
|
||||
class OpOverrides(OpDecompositions):
|
||||
def __init__(self, parent):
|
||||
super().__init__()
|
||||
self._parent = parent
|
||||
|
||||
@staticmethod
|
||||
def paren(string: str) -> str:
|
||||
def all_in_parens(string: str) -> bool:
|
||||
if string[0] != "(" or len(string) < 2:
|
||||
return False
|
||||
count = 1
|
||||
for i, char in enumerate(string[1:]):
|
||||
if char == "(":
|
||||
count += 1
|
||||
elif char == ")":
|
||||
count -= 1
|
||||
if count == 0 and i != len(string) - 2:
|
||||
return False
|
||||
assert count == 0
|
||||
return True
|
||||
|
||||
if (
|
||||
isinstance(string, CSEVariable)
|
||||
or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE)
|
||||
or re.match(r"^\([^)]*\)$", string, re.IGNORECASE)
|
||||
or string == ""
|
||||
):
|
||||
return string
|
||||
# don't put extra parens for strings that are already wrapped in parens
|
||||
if all_in_parens(string):
|
||||
return string
|
||||
return f"({string})"
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._parent, item)
|
||||
|
||||
@staticmethod
|
||||
def constant(value, dtype):
|
||||
return repr(value)
|
||||
|
||||
@staticmethod
|
||||
def libdevice_sigmoid(x):
|
||||
one = ops.constant(1, torch.int32)
|
||||
return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x))))
|
||||
|
||||
@staticmethod
|
||||
def relu(x):
|
||||
return ops.maximum(x, ops.constant(0, torch.int32))
|
||||
|
||||
@staticmethod
|
||||
def libdevice_abs(x):
|
||||
return ops.abs(x)
|
||||
|
|
@ -752,36 +788,6 @@ class OpOverrides:
|
|||
def bitwise_right_shift(x, y):
|
||||
return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}"
|
||||
|
||||
@staticmethod
|
||||
def remainder(a, b):
|
||||
r = ops.mod(a, b)
|
||||
cond = ops.and_(
|
||||
ops.ne(r, ops.constant(0, torch.int32)),
|
||||
ops.ne(ops.signbit(r), ops.signbit(b)),
|
||||
)
|
||||
return ops.where(cond, ops.add(r, b), r)
|
||||
|
||||
@staticmethod
|
||||
def fma(x, y, z):
|
||||
# for backends that don't override this (halide)
|
||||
return ops.add(ops.mul(x, y), z)
|
||||
|
||||
@staticmethod
|
||||
def trunc_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.trunc(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def floor_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.floor(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def ceil_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.ceil(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def round_to_int(a, dtype):
|
||||
return ops.to_dtype(ops.round(a), dtype)
|
||||
|
||||
@staticmethod
|
||||
def int_truediv(a, b):
|
||||
# TODO: this is wrong
|
||||
|
|
|
|||
|
|
@ -51,8 +51,7 @@ def _arg_str(a) -> str:
|
|||
# implementations make heavy use of __getattr__ magic, and pre-existing
|
||||
# stubs for methods would interfere with this mechanism.
|
||||
#
|
||||
# TODO: A superclass that does desugaring for operations like
|
||||
# reciprocal/square might be useful.
|
||||
# See OpDecompositions for superclass that desugars operations like reciprocal/square.
|
||||
class OpsHandler(Protocol[T]):
|
||||
"""
|
||||
Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,
|
||||
|
|
|
|||
Loading…
Reference in a new issue