diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index 6bf3104d150..77fe55d11a7 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -605,51 +605,16 @@ class PythonPrinter(_PythonPrinter): return super().doprint(expr) -class OpOverrides: - def __init__(self, parent): - super().__init__() - self._parent = parent - - @staticmethod - def paren(string: str) -> str: - def all_in_parens(string: str) -> bool: - if string[0] != "(" or len(string) < 2: - return False - count = 1 - for i, char in enumerate(string[1:]): - if char == "(": - count += 1 - elif char == ")": - count -= 1 - if count == 0 and i != len(string) - 2: - return False - assert count == 0 - return True - - if ( - isinstance(string, CSEVariable) - or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE) - or re.match(r"^\([^)]*\)$", string, re.IGNORECASE) - or string == "" - ): - return string - # don't put extra parens for strings that are already wrapped in parens - if all_in_parens(string): - return string - return f"({string})" - - def __getattr__(self, item): - return getattr(self._parent, item) +class OpDecompositions: + """ + Decomposes inductor ops + """ @staticmethod def identity(value): # used to trigger cse return value - @staticmethod - def constant(value, dtype): - return repr(value) - @staticmethod def reciprocal(x): return ops.truediv(ops.constant(1, torch.int32), x) @@ -691,15 +656,86 @@ class OpOverrides: one = ops.constant(1, torch.int32) return ops.truediv(one, ops.add(one, ops.exp(ops.neg(x)))) + @staticmethod + def relu(x): + return ops.maximum(x, ops.constant(0, torch.int32)) + + @staticmethod + def fma(x, y, z): + # for backends that don't override this (halide) + return ops.add(ops.mul(x, y), z) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def remainder(a, b): + r = ops.mod(a, b) + cond = ops.and_( + ops.ne(r, ops.constant(0, torch.int32)), + ops.ne(ops.signbit(r), ops.signbit(b)), + ) + return ops.where(cond, ops.add(r, b), r) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + +class OpOverrides(OpDecompositions): + def __init__(self, parent): + super().__init__() + self._parent = parent + + @staticmethod + def paren(string: str) -> str: + def all_in_parens(string: str) -> bool: + if string[0] != "(" or len(string) < 2: + return False + count = 1 + for i, char in enumerate(string[1:]): + if char == "(": + count += 1 + elif char == ")": + count -= 1 + if count == 0 and i != len(string) - 2: + return False + assert count == 0 + return True + + if ( + isinstance(string, CSEVariable) + or re.match(r"^[a-z0-9_.]+$", string, re.IGNORECASE) + or re.match(r"^\([^)]*\)$", string, re.IGNORECASE) + or string == "" + ): + return string + # don't put extra parens for strings that are already wrapped in parens + if all_in_parens(string): + return string + return f"({string})" + + def __getattr__(self, item): + return getattr(self._parent, item) + + @staticmethod + def constant(value, dtype): + return repr(value) + @staticmethod def libdevice_sigmoid(x): one = ops.constant(1, torch.int32) return ops.truediv(one, ops.add(one, ops.libdevice_exp(ops.neg(x)))) - @staticmethod - def relu(x): - return ops.maximum(x, ops.constant(0, torch.int32)) - @staticmethod def libdevice_abs(x): return ops.abs(x) @@ -752,36 +788,6 @@ class OpOverrides: def bitwise_right_shift(x, y): return f"{OpOverrides.paren(x)} >> {OpOverrides.paren(y)}" - @staticmethod - def remainder(a, b): - r = ops.mod(a, b) - cond = ops.and_( - ops.ne(r, ops.constant(0, torch.int32)), - ops.ne(ops.signbit(r), ops.signbit(b)), - ) - return ops.where(cond, ops.add(r, b), r) - - @staticmethod - def fma(x, y, z): - # for backends that don't override this (halide) - return ops.add(ops.mul(x, y), z) - - @staticmethod - def trunc_to_int(a, dtype): - return ops.to_dtype(ops.trunc(a), dtype) - - @staticmethod - def floor_to_int(a, dtype): - return ops.to_dtype(ops.floor(a), dtype) - - @staticmethod - def ceil_to_int(a, dtype): - return ops.to_dtype(ops.ceil(a), dtype) - - @staticmethod - def round_to_int(a, dtype): - return ops.to_dtype(ops.round(a), dtype) - @staticmethod def int_truediv(a, b): # TODO: this is wrong diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index b31f64872d8..1ae590d444f 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -51,8 +51,7 @@ def _arg_str(a) -> str: # implementations make heavy use of __getattr__ magic, and pre-existing # stubs for methods would interfere with this mechanism. # -# TODO: A superclass that does desugaring for operations like -# reciprocal/square might be useful. +# See OpDecompositions for superclass that desugars operations like reciprocal/square. class OpsHandler(Protocol[T]): """ Protocol describing the set of valid operations on ``torch._inductor.virtualized.ops``,