diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 781e2869e3c..fcc90bc102a 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -6338,6 +6338,21 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor): res = opt_mod(x) self.assertEqual(ref, res) + def test_symint_bitwise(self): + def fn(x): + z = x.shape[0] + z |= z >> 1 + z |= z << 1 + z &= z | (z > 1) + y = (z > 1) | (z <= 1) + # test composition with non-bitwise ops + z = (z | z) % 6 + return y, z + + opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True) + inp = torch.randn(3, 3) + self.assertEqual(fn(inp), opt_fn(inp)) + instantiate_parametrized_tests(ReproTests) diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index 0ec0d937d8d..7a2bf8b83e8 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -371,6 +371,39 @@ class TestPySymInt(TestCase): z = y.expand((y.shape[1],)) z = y.expand(y.shape[1]) + def test_symint_bitwise_and(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 0b1100) + b0 = create_symint(shape_env, 0b1010) + res_and = a0 & b0 + self.assertEqual(res_and, 0b1000) + self.assertIsInstance(res_and, torch.SymInt, msg=type(res_and)) + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_and(s0, s1), 8)""" + ) + + a1 = create_symint(shape_env, 3) + b1 = create_symbool(shape_env, True) + self.assertEqual(a1 & b1, 1) + + a2 = create_symint(shape_env, 0b1100) + self.assertEqual(a2 & 0b1010, 0b1000) + + a3 = create_symbool(shape_env, True) + b3 = create_symbool(shape_env, True) + self.assertEqual(a3 & b3, True) + + def test_symint_bitwise_or(self): + shape_env = ShapeEnv() + a0 = create_symint(shape_env, 0b1100) + b0 = create_symint(shape_env, 0b1010) + res_or = a0 | b0 + self.assertEqual(res_or, 0b1110) + self.assertIsInstance(res_or, torch.SymInt, msg=type(res_or)) + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(BitwiseFn_bitwise_or(s0, s1), 14)""" + ) + def test_stride(self): shape_env = ShapeEnv() x = create_symbolic_tensor("x", torch.randn(5, 5), shape_env) @@ -1273,6 +1306,9 @@ class TestSymNumberMagicMethods(TestCase): if second_type == "float" and fn in ["mod"]: self.skipTest(f"{fn} only handles int") + if fn in sym_node.bitwise_ops and (first_type != "int" or second_type != "int"): + self.skipTest(f"{fn} is a bitwise op, only handles int") + is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": diff --git a/test/test_fx_experimental.py b/test/test_fx_experimental.py index 86d081b7c18..8d842f101cd 100644 --- a/test/test_fx_experimental.py +++ b/test/test_fx_experimental.py @@ -1747,7 +1747,7 @@ if TEST_Z3: import torch._dynamo.config from torch.fx.experimental.validator import SympyToZ3, TranslationValidator, ValidationException, z3str - from torch.utils._sympy.functions import FloorDiv, Mod + from torch.utils._sympy.functions import FloorDiv, Mod, BitwiseFn_bitwise_and class TestTranslationValidation(TestCase): def _prepare_for_translation_validation(self): @@ -1801,6 +1801,8 @@ if TEST_Z3: (sympy.Ge, operator.ge), ) ], + # Bitwise operations. + (BitwiseFn_bitwise_and(s0, s1), z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64))), # Other operations. ( s0 - s1, @@ -1847,6 +1849,18 @@ if TEST_Z3: validator.validate() + def test_sat_bitwise(self): + ( + (s0, s1, s2), + (z0, z1, z2), + validator, + ) = self._prepare_for_translation_validation() + + validator.add_source_expr(z3.BV2Int(z3.Int2BV(z0, 64) & z3.Int2BV(z1, 64)) == 5) + validator.add_source_expr(z0 == 0b110101) + + validator.validate() + def test_unsat(self): ( (s0, s1, s2), diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index 06dd8671167..7771c60c052 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -3,9 +3,9 @@ import functools import itertools import math +import pickle import sys from typing import Callable, List, Tuple, Type -import pickle import sympy @@ -20,7 +20,11 @@ from torch.testing._internal.common_utils import ( TEST_Z3, TestCase, ) -from torch.utils._sympy.functions import FloorDiv, simple_floordiv_gcd +from torch.utils._sympy.functions import ( + FloorDiv, + OpaqueUnaryFn_cos, + simple_floordiv_gcd, +) from torch.utils._sympy.interp import sympy_interp from torch.utils._sympy.numbers import int_oo, IntInfinity, NegativeIntInfinity from torch.utils._sympy.reference import ( @@ -31,7 +35,6 @@ from torch.utils._sympy.reference import ( from torch.utils._sympy.singleton_int import SingletonInt from torch.utils._sympy.solve import INEQUALITY_TYPES, mirror_rel_op, try_solve from torch.utils._sympy.value_ranges import ValueRangeAnalysis, ValueRanges -from torch.utils._sympy.functions import OpaqueUnaryFn_cos UNARY_OPS = [ @@ -58,6 +61,12 @@ BINARY_OPS = [ "minimum", "maximum", "mod", + "bitwise_and", + "bitwise_or", +] +BITWISE_OPS = [ + "bitwise_and", + "bitwise_or", ] UNARY_BOOL_OPS = ["not_"] @@ -231,6 +240,10 @@ class TestValueRanges(TestCase): @parametrize("dtype", ("int", "float")) def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} + # Don't test bitwise methods since value range analysis on a singleton + # range may not return a singleton result. + if fn in BITWISE_OPS: + return # Don't test float on int only methods if dtype == "float" and fn in ["pow_by_natural", "mod"]: return @@ -280,7 +293,7 @@ class TestValueRanges(TestCase): else: self.assertEqual(len(unique), 2) - @parametrize("fn", BINARY_BOOL_OPS) + @parametrize("fn", BINARY_BOOL_OPS + BITWISE_OPS) def test_binary_bool_ref_range(self, fn): vals = [sympy.false, sympy.true] for a, b in itertools.product(generate_range(vals), repeat=2): @@ -338,6 +351,38 @@ class TestValueRanges(TestCase): if r.is_finite: self.assertIn(r, ref_r) + # stronger test specially for bitwise ops + @parametrize("fn", BITWISE_OPS) + def test_bitwise_ref_range(self, fn): + # N^4 complexity + vals = range(-4, 5) + for a, b in itertools.product(generate_range(vals), repeat=2): + with self.subTest(a=a, b=b): + for a0, b0 in itertools.product(vals, repeat=2): + if a0 not in a or b0 not in b: + continue + with self.subTest(a0=a0, b0=b0): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) + r = getattr(ReferenceAnalysis, fn)(a0, b0) + self.assertIn(r, ref_r) + + # test that bitwise ops can take bool arguments + bool_vals = [ + (3, sympy.true), + (3, sympy.false), + (sympy.true, 3), + (sympy.false, 3), + (sympy.true, sympy.true), + (sympy.true, sympy.false), + (sympy.false, sympy.true), + (sympy.false, sympy.false), + ] + for a, b in bool_vals: + with self.subTest(a=a, b=b): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) + r = getattr(ReferenceAnalysis, fn)(a, b) + self.assertIn(r, ref_r) + class TestSympyInterp(TestCase): @parametrize( @@ -358,6 +403,8 @@ class TestSympyInterp(TestCase): vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] + elif fn in BITWISE_OPS: + vals = vals + [True, False] arity = 1 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 @@ -395,6 +442,8 @@ class TestSympyInterp(TestCase): vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] + elif fn in BITWISE_OPS: + vals = vals + [True, False] arity = 1 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: @@ -472,6 +521,8 @@ class TestSympyInterp(TestCase): vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] + elif fn in BITWISE_OPS: + vals = vals + [True, False] arity = 1 if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: @@ -815,7 +866,7 @@ class TestSympySolve(TestCase): class TestSympyFunctions(TestCase): def test_pickle(self): - x = OpaqueUnaryFn_cos(sympy.Symbol('a')) + x = OpaqueUnaryFn_cos(sympy.Symbol("a")) r = pickle.loads(pickle.dumps(x)) self.assertEqual(x, r) diff --git a/torch/__init__.py b/torch/__init__.py index f341818395f..83b35515e2d 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -536,6 +536,12 @@ class SymInt: def __rsub__(self, other: "IntLikeType") -> "SymInt": raise TypeError("type stub not overridden") + def __and__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + + def __or__(self, other) -> "SymInt": + raise TypeError("type stub not overridden") + def __repr__(self): return self.node._graph_repr() @@ -922,6 +928,7 @@ for __name in ( __fn.__qualname__ = __fn.__name__ = __sym_name globals()[__sym_name] = __fn + del __fn, __name, __sym_name, _get_sym_math_fn # Adding temporary shortcut diff --git a/torch/_dynamo/variables/builtin.py b/torch/_dynamo/variables/builtin.py index 2e712663a19..129f4e55803 100644 --- a/torch/_dynamo/variables/builtin.py +++ b/torch/_dynamo/variables/builtin.py @@ -2026,6 +2026,8 @@ class BuiltinVariable(VariableTracker): return SetVariable(list(a.set_items & b.set_items)) # None no-ops this handler and lets the driving function proceed + call_iand = call_and_ + def call_or_(self, tx: "InstructionTranslator", a, b): # Rely on constant_handler if isinstance(a, ConstantVariable) and isinstance(b, ConstantVariable): @@ -2045,6 +2047,8 @@ class BuiltinVariable(VariableTracker): # None no-ops this handler and lets the driving function proceed return None + call_ior = call_or_ + def call_not_(self, tx: "InstructionTranslator", a): if isinstance(a, SymNodeVariable): return SymNodeVariable.create( diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index c041b113130..44739de2be3 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -409,6 +409,13 @@ class SymNode: def sym_and(self, other): return self.and_(other) + # Integer bitwise ops + def bitwise_and(self, other): + return self._bitwise_and(other) # type: ignore[attr-defined] + + def bitwise_or(self, other): + return self._bitwise_or(other) # type: ignore[attr-defined] + # There is no int_truediv available from C++ def truediv(self, other): return self.float_truediv(other) @@ -571,6 +578,7 @@ METHOD_TO_OPERATOR = { "abs": operator.abs, "add": operator.add, "and": operator.and_, + "bitwise_and": operator.and_, "ceil": math.ceil, "eq": operator.eq, "floor": math.floor, @@ -587,6 +595,7 @@ METHOD_TO_OPERATOR = { "ne": operator.ne, "neg": operator.neg, "or": operator.or_, + "bitwise_or": operator.or_, "float_pow": operator.pow, "pow_by_natural": operator.pow, "round": builtins.round, @@ -665,6 +674,11 @@ only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} +# remap necessary because an op name can have a bitwise and boolean implementation +bitwise_ops = { + "bitwise_and": "and", + "bitwise_or": "or", +} always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} @@ -755,6 +769,18 @@ def _sympy_rshift(a, b): return RShift(a, b) +def _bitwise_and(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_and + + return BitwiseFn_bitwise_and(a, b) + + +def _bitwise_or(a, b): + from torch.utils._sympy.functions import BitwiseFn_bitwise_or + + return BitwiseFn_bitwise_or(a, b) + + reflectable_magic_methods = { "add": operator.add, "sub": operator.sub, @@ -763,7 +789,9 @@ reflectable_magic_methods = { "pow_by_natural": _sympy_pow_by_natural, "float_pow": _sympy_float_pow, "and": _sympy_and, + "bitwise_and": _bitwise_and, "or": _sympy_or, + "bitwise_or": _bitwise_or, "float_truediv": _sympy_float_truediv, "int_truediv": _sympy_int_truediv, "int_floordiv": _sympy_floordiv, @@ -1570,9 +1598,12 @@ def _make_user_magic(method, user_type): setattr(user_type, f"__{method}__", round_magic_impl) else: - setattr(user_type, f"__{method}__", binary_magic_impl) + method_name = method + if method in bitwise_ops: + method_name = bitwise_ops[method] + setattr(user_type, f"__{method_name}__", binary_magic_impl) if method in reflectable_magic_methods: - setattr(user_type, f"__r{method}__", rbinary_magic_impl) + setattr(user_type, f"__r{method_name}__", rbinary_magic_impl) for method, func in magic_methods.items(): # type: ignore[assignment] @@ -1585,7 +1616,8 @@ for method, func in magic_methods.items(): # type: ignore[assignment] if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods: _make_user_magic(method, SymBool) _make_user_magic(method, SymInt) - _make_user_magic(method, SymFloat) + if method not in bitwise_ops: + _make_user_magic(method, SymFloat) del method del func diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 1dd6e27b9b3..503918bbc74 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -139,6 +139,22 @@ try: string = op + " " + " ".join(args) return f"({string.rstrip()})" + # We need to convert to/from BitVec in order to use z3 bitwise ops. + # We assume that integers are 64 bit. + # If all args are boolean, then use the boolean bitwise op implementation instead, if provided. + def _bitwise_op(bitwise_func, bool_func): + @functools.wraps(bitwise_func) + def wrapper(self, *args): + if bool_func is not None and all( + isinstance(arg, z3.BoolRef) for arg in args + ): + return bool_func(*args) + + wrapped_args = tuple(z3.Int2BV(a, 64) for a in args) + return z3.BV2Int(bitwise_func(*wrapped_args)) + + return wrapper + # Implementation of Python semantics as Z3 expressions. # # Z3 Real-Int theory has operators with semantics that differ that of @@ -234,6 +250,11 @@ try: self.floor(number + 0.5), ) + bitwise_and = _bitwise_op(operator.and_, z3.And) + bitwise_or = _bitwise_op(operator.or_, z3.Or) + lshift = _bitwise_op(operator.lshift, None) + rshift = _bitwise_op(operator.rshift, None) + # Lifts a callable to be used in Z3. # # This function replaces the given 'op' by a function that: @@ -247,7 +268,7 @@ try: # This is needed because the argument of some FX nodes were # literal integers, instead of booleans. So, whenever this flag # is set, we also convert ints to booleans. - boolean_ops = {operator.not_, operator.and_, operator.or_} + boolean_ops = {operator.not_} as_bool = op in boolean_ops # Lifts the function into 'z3.ExprRef' domain. @@ -281,8 +302,10 @@ try: replacement_map = { # Operator module. operator.not_: lift(z3.Not), - operator.and_: lift(z3.And), - operator.or_: lift(z3.Or), + operator.and_: lift(ops.bitwise_and), + operator.or_: lift(ops.bitwise_or), + operator.lshift: lift(ops.lshift), + operator.rshift: lift(ops.rshift), operator.floordiv: lift(ops.floordiv), operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), @@ -416,6 +439,10 @@ try: "and_": z3.And, "or_": z3.Or, "not_": z3.Not, + "bitwise_and": self._ops.bitwise_and, + "bitwise_or": self._ops.bitwise_or, + "lshift": self._ops.lshift, + "rshift": self._ops.rshift, "floor": self._ops.floor, "ceil": self._ops.ceil, "minimum": self._ops.min, diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 920b097b632..08807968bc6 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -845,6 +845,7 @@ class Max(MinMaxBase, Application): # type: ignore[misc] r""" Return, if possible, the maximum value of the list. """ + zero = S.Infinity identity = S.NegativeInfinity @@ -1224,3 +1225,29 @@ OpaqueUnaryFn_exp = make_opaque_unary_fn("exp") OpaqueUnaryFn_log = make_opaque_unary_fn("log") OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh") OpaqueUnaryFn_log2 = make_opaque_unary_fn("log2") + + +def make_opaque_bitwise_fn(name, real_op_name): + class BitwiseFn(sympy.Function): + _torch_handler_name = name + + @classmethod + def eval(cls, a, b): + if a.is_Boolean and b.is_Boolean: + return getattr(operator, real_op_name)(a, b) + if a.is_Boolean: + a = sympy.Integer(1 if a else 0) + if b.is_Boolean: + b = sympy.Integer(1 if b else 0) + if isinstance(a, (sympy.Integer, int)) and isinstance( + b, (sympy.Integer, int) + ): + return sympy.Integer(getattr(operator, real_op_name)(int(a), int(b))) + return None + + BitwiseFn.__name__ = "BitwiseFn_" + name + return BitwiseFn + + +BitwiseFn_bitwise_and = make_opaque_bitwise_fn("bitwise_and", "and_") +BitwiseFn_bitwise_or = make_opaque_bitwise_fn("bitwise_or", "or_") diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 3d26fa861b1..718a4938b40 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -18,6 +18,8 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom import torch from .functions import ( + BitwiseFn_bitwise_and, + BitwiseFn_bitwise_or, CeilToInt, CleanDiv, FloatPow, @@ -104,6 +106,8 @@ def handlers(): RoundDecimal: "round_decimal", # TODO: do the rest of the opaque unary functions... OpaqueUnaryFn_log2: "log2", + BitwiseFn_bitwise_and: "bitwise_and", + BitwiseFn_bitwise_or: "bitwise_or", } # TODO: This is kind of pointless, we shouldn't be generating sympy.sin # for these functions, they should be Opaque instead diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 3798fe3ea13..8c960e92f22 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -8,6 +8,8 @@ import sympy import torch from torch.utils._sympy.functions import ( _keep_float, + BitwiseFn_bitwise_and, + BitwiseFn_bitwise_or, FloatPow, FloatTrueDiv, FloorDiv, @@ -195,6 +197,14 @@ class ReferenceAnalysis: def round_decimal(a, b): return RoundDecimal(a, b) + @staticmethod + def bitwise_and(a, b): + return BitwiseFn_bitwise_and(a, b) + + @staticmethod + def bitwise_or(a, b): + return BitwiseFn_bitwise_or(a, b) + # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain # Python types and is FX traceable. Inheritance here is purely for code @@ -307,6 +317,14 @@ class PythonReferenceAnalysis(ReferenceAnalysis): def round_decimal(a, b): return round(a, ndigits=b) + @staticmethod + def bitwise_and(a, b): + return a & b + + @staticmethod + def bitwise_or(a, b): + return a | b + # Like PythonReferenceAnalysis, but some export-unfriendly choices of # operators to make things faster @@ -358,6 +376,14 @@ class TensorReferenceAnalysis: def and_(a, b): return torch.ops.aten.logical_and.default(a, b) + @staticmethod + def bitwise_and(a, b): + return torch.ops.aten.bitwise_and(a, b) + + @staticmethod + def bitwise_or(a, b): + return torch.ops.aten.bitwise_or(a, b) + @staticmethod def eq(a, b): return torch.ops.aten.eq.Tensor(a, b) diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index 38cb27ebd40..171ec73d93e 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -500,6 +500,53 @@ class SymPyValueRangeAnalysis: def and_(a, b): return ValueRanges.coordinatewise_increasing_map(a, b, sympy.And) + @staticmethod + def _bool_to_int(x): + if x.is_singleton(): + return ValueRanges.wrap(sympy.Integer(1 if x.lower else 0)) + else: + return ValueRanges(sympy.Integer(0), sympy.Integer(1)) + + @classmethod + def bitwise_and(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + return cls.and_(a, b) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + lower = min(a.lower, b.lower) + if lower < 0 and lower != -int_oo: + # If both lower bounds are negative, then bits start like + # 1...10..., so the smallest possible value is 1...101...1. + # Thus, we need to find the next smallest power of 2 (inclusive). + lower = -(1 << int(-lower - 1).bit_length()) + else: + lower = 0 + return ValueRanges(lower, max(a.upper, b.upper)) + + @classmethod + def bitwise_or(cls, a, b): + a, b = ValueRanges.wrap(a), ValueRanges.wrap(b) + if a.is_bool and b.is_bool: + return cls.or_(a, b) + if a.is_bool: + a = cls._bool_to_int(a) + if b.is_bool: + b = cls._bool_to_int(b) + upper = max(a.upper, b.upper) + if upper == 0: + upper = 0 + elif upper > 0 and upper != int_oo: + # If both upper bounds are positive, then the largest + # possible value is 01...1, so we need to find + # next largest power of 2 (exclusive), minus 1 + upper = (1 << int(upper).bit_length()) - 1 + elif upper < 0: + upper = -1 + return ValueRanges(min(a.lower, b.lower), upper) + @staticmethod def eq(a, b): a = ValueRanges.wrap(a) @@ -1061,12 +1108,14 @@ def bound_sympy( "bound_sympy(%s)%s", expr, LazyString( - lambda: "\n" - + "\n".join( - f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols + lambda: ( + "\n" + + "\n".join( + f" {k}: {r}" for k, r in ranges.items() if k in expr.free_symbols + ) + if ranges + else "" ) - if ranges - else "" ), ) if isinstance(expr, sympy.Number):