diff --git a/c10/core/SymNodeImpl.h b/c10/core/SymNodeImpl.h index 9ffab506510..bb92b09775b 100644 --- a/c10/core/SymNodeImpl.h +++ b/c10/core/SymNodeImpl.h @@ -49,15 +49,33 @@ class C10_API SymNodeImpl : public c10::intrusive_ptr_target { virtual SymNode mul(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + // NB: legacy, prefer float_truediv or int_truediv virtual SymNode truediv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_truediv(const SymNode& other) { + return truediv(other); + } + virtual SymNode int_truediv(const SymNode& other) { + return truediv(other); + } + // NB: legacy, prefer float_pow or pow_by_natural virtual SymNode pow(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode float_pow(const SymNode& other) { + return pow(other); + } + virtual SymNode pow_by_natural(const SymNode& other) { + return pow(other); + } + // NB: legacy, prefer int_floordiv virtual SymNode floordiv(const SymNode& other) { TORCH_CHECK(false, "NYI"); } + virtual SymNode int_floordiv(const SymNode& other) { + return floordiv(other); + } virtual SymNode mod(const SymNode& other) { TORCH_CHECK(false, "NYI"); } diff --git a/test/dynamo/test_dynamic_shapes.py b/test/dynamo/test_dynamic_shapes.py index 0bead6e47e4..a3c63ef6615 100644 --- a/test/dynamo/test_dynamic_shapes.py +++ b/test/dynamo/test_dynamic_shapes.py @@ -78,13 +78,6 @@ for test in tests: del test if TEST_Z3: - # this only fails when z3 is available - unittest.expectedFailure( - # SymPy is incorrectly transforming 's0 / 6 == 0.5' into 'False'. - # Ref: https://github.com/sympy/sympy/issues/25146 - DynamicShapesReproTests.test_dynamic_shapes_float_guard_dynamic_shapes # noqa: F821 - ) - if not config.inline_inbuilt_nn_modules: # TODO model is somehow not being freed when z3 is available unittest.expectedFailure( diff --git a/test/dynamo/test_export.py b/test/dynamo/test_export.py index 9f1417e2324..7ae0f839f6f 100644 --- a/test/dynamo/test_export.py +++ b/test/dynamo/test_export.py @@ -2385,8 +2385,7 @@ def forward(self, x): with self.assertRaisesRegex( torch._dynamo.exc.UserError, "Constraints violated .*!(.*\n)*.*" - "by dim0 = 2\\*dim1(.*\n)*.*" - "Not all values of dim1 .* satisfy the generated guard 2 <= .* and .* <= 5(.*\n)*.*", + "Not all values of dim0 .* satisfy the generated guard 4 <= .* and .* <= 10(.*\n)*.*", ): torch.export.export(foo, (t,), dynamic_shapes=dynamic_shapes) diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index bcb0fd18818..dc2b9530f0d 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -9309,7 +9309,7 @@ ShapeEnv not equal: field values don't match: > Left: {0: 0, 1: 1, 2: s1, 3: s0} > Right: {0: 0, 1: 1} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} > Right: {} ==> var_to_sources: values don't match. > Left: {s0: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=0)], s1: [TensorPropertySource(base=ConstantSource(source_name='x'), prop=, idx=1)]} @@ -9343,7 +9343,7 @@ ShapeEnv not equal: field values don't match: > Left: 2 > Right: 0 ==> var_to_range: values don't match. - > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False), u1: ValueRanges(lower=0, upper=1, is_bool=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False)} + > Left: {u0: ValueRanges(lower=-9223372036854775808, upper=9223372036854775807, is_bool=False, is_int=True, is_float=False), u1: ValueRanges(lower=0, upper=1, is_bool=False, is_int=True, is_float=False), zuf0: ValueRanges(lower=-oo, upper=oo, is_bool=False, is_int=False, is_float=True)} > Right: {} """, ) @@ -9420,8 +9420,8 @@ ShapeEnv not equal: field values don't match: > Left: {s0: 3} > Right: {} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=3, upper=3, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9458,8 +9458,8 @@ ShapeEnv not equal: field values don't match: > Left: {_assert, ge, x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} > Right: {x_size_0_, x_size_1_, x_storage_offset, x_stride_0_, x_stride_1_} ==> var_to_range: values don't match. - > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} - > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False)} + > Left: {s0: ValueRanges(lower=3, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} + > Right: {s0: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False), s1: ValueRanges(lower=2, upper=9223372036854775806, is_bool=False, is_int=True, is_float=False)} """, ) self._replay_and_check(main) @@ -9484,10 +9484,7 @@ ShapeEnv not equal: field values don't match: ShapeEnv not equal: field values don't match: ==> deferred_runtime_asserts: values don't match. - > Left: {u0: [Eq(Mod(u0, 3), 0)]} - > Right: {} -==> divisible: values don't match. - > Left: {Mod(u0, 3)} + > Left: {u0: [Eq(PythonMod(u0, 3), 0)]} > Right: {} ==> name_to_node: values don't match. > Left: {_assert, eq, mod, u0} diff --git a/test/inductor/test_indexing.py b/test/inductor/test_indexing.py index 299a619f9cd..da527cfbb1d 100644 --- a/test/inductor/test_indexing.py +++ b/test/inductor/test_indexing.py @@ -11,7 +11,12 @@ from torch.testing._internal.common_utils import ( instantiate_parametrized_tests, parametrize, ) -from torch.utils._sympy.functions import FloorDiv, ModularIndexing, Round, RoundDecimal +from torch.utils._sympy.functions import ( + FloorDiv, + ModularIndexing, + RoundDecimal, + RoundToInt, +) class TestIndexingSimplification(InductorTestCase): @@ -168,21 +173,11 @@ class ExprPrinterTests(InductorTestCase): common_cases = [ # expr, result - # Test exprs. - ( - s1 / (2 * s1 - 1) - 1 / (2 * s1 - 1), - lambda c, L: f"((-1{L})*({c}/((-1{L}) + (2{L}*foo)))) + (foo*({c}/((-1{L}) + (2{L}*foo))))", - ), - (s1 / (s2 - s3), lambda c, L: f"foo*({c}/(bar + ((-1{L})*baz)))"), # Test Pow directly. ( sympy.Pow(s1 + s2, 0), lambda _, L: f"1{L}", ), # note: simplified before _print_Pow - ( - sympy.Pow(s1 + s2, -3), - lambda c, _: f"{c}/((bar + foo)*(bar + foo)*(bar + foo))", - ), ] gpu_cases = common_cases + [ @@ -231,12 +226,10 @@ class ExprPrinterTests(InductorTestCase): self.assertExpectedInline(cexpr(expr), """std::ceil((1.0/2.0)*s1)""") def test_print_round(self): - expr = Round(sympy.Symbol("x", integer=True) / 2) + expr = RoundToInt(sympy.Symbol("x", integer=True) / 2) self.assertExpectedInline(pexpr(expr), """round((1/2)*x)""") self.assertExpectedInline(cexpr(expr), """std::lrint((1.0/2.0)*x)""") - self.assertExpectedInline( - texpr(expr), """libdevice.llrint((1/2)*x).to(tl.int64)""" - ) + self.assertExpectedInline(texpr(expr), """libdevice.llrint((1/2)*x)""") @parametrize("ndigits", [-1, 0, 1]) def test_print_round_decimal(self, ndigits): @@ -251,45 +244,18 @@ class ExprPrinterTests(InductorTestCase): f"libdevice.nearbyint(1e{ndigits} * ((1/2)*x)) * 1e{-ndigits}", ) - expr = RoundDecimal(sympy.Symbol("x", integer=True), ndigits) - if ndigits >= 0: - for do_print in [pexpr, cexpr, texpr]: - self.assertEqual(do_print(expr), "x") - else: - self.assertEqual(pexpr(expr), f"round(x, {ndigits})") - for do_print in [cexpr, texpr]: - with self.assertRaisesRegex( - ValueError, "only non-negative ndigits are currently supported" - ): - do_print(expr) - def test_print_floor_div(self): - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.Symbol("s2", integer=integer) - expr = FloorDiv(s1, s2) - self.assertEqual(pexpr(expr), "(s1 // s2)") - if integer: - self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - else: - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast(s2))", - ) + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.Symbol("s2", integer=True) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(s1 // s2)") + self.assertEqual(cexpr(expr), "c10::div_floor_integer(s1, s2)") - for integer in [True, False]: - s1 = sympy.Symbol("s1", integer=integer) - s2 = sympy.S(-1) - expr = FloorDiv(s1, s2) - if integer: - self.assertEqual(pexpr(expr), "(-1)*s1") - self.assertEqual(cexpr(expr), "(-1L)*s1") - else: - self.assertEqual(pexpr(expr), "(s1 // (-1))") - self.assertEqual( - cexpr(expr), - "c10::div_floor_floating(static_cast(s1), static_cast((-1L)))", - ) + s1 = sympy.Symbol("s1", integer=True) + s2 = sympy.S(-1) + expr = FloorDiv(s1, s2) + self.assertEqual(pexpr(expr), "(-1)*s1") + self.assertEqual(cexpr(expr), "(-1L)*s1") def test_print_Min_Max(self): cases = ( diff --git a/test/inductor/test_torchinductor_dynamic_shapes.py b/test/inductor/test_torchinductor_dynamic_shapes.py index 8513e928c41..2f9506a9d56 100644 --- a/test/inductor/test_torchinductor_dynamic_shapes.py +++ b/test/inductor/test_torchinductor_dynamic_shapes.py @@ -3,6 +3,7 @@ import contextlib import importlib import math +import operator import os import sys import unittest @@ -649,6 +650,33 @@ class TestInductorDynamic(TestCase): actual = cfn(5) self.assertEqual(expect, actual) + def test_interpolate_ceil_eq(self, device): + ceiling = math.ceil + IntTrueDiv = operator.truediv + + def fn(t): + s0, s2, s3 = t.size() + x = torch.zeros( + ( + s0, + 2048, + ceiling(IntTrueDiv(2 * ((s2 - 1) // 8) + 2, 1)), + ceiling(IntTrueDiv(2 * ((s3 - 1) // 8) + 2, 1)), + ), + dtype=torch.bfloat16, + ) + return torch.nn.functional.interpolate( + x, + scale_factor=2, + mode="nearest", + ) + + cfn = self.compile_fn(fn) + arg = torch.randn(4, 16, 18) + expect = fn(arg) + actual = cfn(arg) + self.assertEqual(expect, actual) + def test_full_recompiles(self, device): def fn(x): _, L = x.shape diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index b70bfbf9c4a..0f0e01bc0dc 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -158,8 +158,12 @@ class TestFxToOnnxWithOnnxRuntime(onnx_test_common._TestONNXRuntime): torch.tensor([operator.sub(x.item(), y.item())]), torch.tensor([operator.mul(x.item(), y.item())]), torch.tensor([operator.truediv(x.item(), y.item())]), - torch.tensor([operator.floordiv(x.item(), y.item())]), - torch.tensor([operator.pow(x.item(), y.item())]), + # This requires torch.sym_float, probably easy to lower to + # ONNX but I don't know where to put it + # torch.tensor([operator.floordiv(x.item(), y.item())]), + # NB: abs so that the base and exponent are provably + # non-negative, so we don't generate runtime asserts + torch.tensor([operator.pow(abs(x.item()), abs(y.item()))]), torch.tensor([operator.abs(x.item())]), torch.tensor([operator.neg(x.item())]), torch.tensor([math.ceil(x.item())]), diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d548e9df070..3b47f12198d 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -205,15 +205,15 @@ def create_symtype(cls, pytype, shape_env, val, duck=True): # TODO: default duck to False -def create_symint(shape_env, i: int, duck=True): +def create_symint(shape_env, i: int, duck=True) -> SymInt: return create_symtype(SymInt, int, shape_env, i, duck=duck) -def create_symbool(shape_env, b: bool): +def create_symbool(shape_env, b: bool) -> SymBool: return create_symtype(SymBool, bool, shape_env, b) -def create_symfloat(shape_env, f: float): +def create_symfloat(shape_env, f: float) -> SymFloat: return create_symtype(SymFloat, float, shape_env, f) @@ -457,14 +457,16 @@ class TestPySymInt(TestCase): r = sym_int(a1 / 2) self.assertEqual(guard_int(r), 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(Trunc(s1/2), 3)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), """Eq(TruncToInt(IntTrueDiv(s1, 2)), 3)""" + ) a3 = create_symint(shape_env, 3) r = sym_int(2.0 * torch.sym_float(a3)) self.assertEqual(guard_int(r), 6) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[2][0]), """Eq(Trunc(2.0*s2), 6)""" + str(shape_env.guards[2][0]), """Eq(TruncToInt(2.0*ToFloat(s2)), 6)""" ) def test_sym_sqrt(self): @@ -474,7 +476,7 @@ class TestPySymInt(TestCase): self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymFloat, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2)""" + str(shape_env.guards[0][0]), """Eq(OpaqueUnaryFn_sqrt(s0), 2.0)""" ) def test_sym_floor(self): @@ -483,11 +485,17 @@ class TestPySymInt(TestCase): r = math.floor(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(floor(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), + """Eq(FloorToInt(IntTrueDiv(s0, 2)), 2)""", + ) r = math.floor(3.0 * a0) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_trunc(self): shape_env = ShapeEnv() @@ -495,12 +503,14 @@ class TestPySymInt(TestCase): r = math.trunc(a0 / 2) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[0][0]), """Eq(Trunc(s0/2), 2)""") + self.assertExpectedInline( + str(shape_env.guards[0][0]), """Eq(TruncToInt(IntTrueDiv(s0, 2)), 2)""" + ) r = torch.sym_int(torch.sym_sqrt(a0)) self.assertEqual(r, 2) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[1][0]), """Eq(Trunc(OpaqueUnaryFn_sqrt(s0)), 2)""" + str(shape_env.guards[1][0]), """Eq(TruncToInt(OpaqueUnaryFn_sqrt(s0)), 2)""" ) def test_sym_ceil(self): @@ -510,12 +520,17 @@ class TestPySymInt(TestCase): self.assertEqual(r, 3) self.assertIsInstance(r, torch.SymInt, msg=type(r)) self.assertExpectedInline( - str(shape_env.guards[0][0]), """Eq(ceiling(s0/2), 3)""" + str(shape_env.guards[0][0]), + """Eq(CeilToInt(IntTrueDiv(s0, 2)), 3)""", ) - r = math.floor(3.0 * a0) + r1 = 3.0 * a0 + r = math.floor(r1) self.assertEqual(r, 15) self.assertIsInstance(r, torch.SymInt, msg=type(r)) - self.assertExpectedInline(str(shape_env.guards[1][0]), """Eq(3*s0, 15)""") + self.assertExpectedInline( + str(shape_env.guards[1][0]), + """Eq(FloorToInt(3.0*ToFloat(s0)), 15)""", + ) def test_sym_ite(self): shape_env = ShapeEnv() @@ -962,8 +977,14 @@ class f(torch.nn.Module): ) class TestSymNumberMagicMethods(TestCase): def _do_test(self, fn, inp1, inp2, shape_env, is_unary_fn): + with self.subTest(fn=fn, inp1=inp1, inp2=inp2, is_unary_fn=is_unary_fn): + return self._do_test2(fn, inp1, inp2, shape_env, is_unary_fn) + + def _do_test2(self, fn, inp1, inp2, shape_env, is_unary_fn): # Helper function # NB: don't use one as that will get specialized + # TODO: We don't have to circuitously create the float, can just + # create a symfloat directly seed_node = (create_symint(shape_env, 2) / 2.0).node bool_seed_node = (create_symint(shape_env, 2) == 2).node @@ -976,27 +997,42 @@ class TestSymNumberMagicMethods(TestCase): else: return torch.SymFloat(to_node(seed_node, inp)) + if fn == "float_pow": + if inp1 < 0: + return + + if fn == "pow_by_natural": + if isinstance(inp1, float) or isinstance(inp2, float): + return + if inp2 < 0: + return + def maybe_xfail(inp1, inp2): if fn == "sym_sqrt" and inp1 < 0: # ValueError: math domain error return self.assertRaises((ValueError,)) - elif fn in ("truediv", "floordiv", "mod") and inp2 == 0: + elif ( + fn in ("float_truediv", "int_truediv", "int_floordiv", "mod") + and inp2 == 0 + ): # ZeroDivisionError: division by zero return self.assertRaises((ZeroDivisionError,)) - elif fn == "pow" and inp1 == 0 and inp2 < 0: + elif fn in ["float_pow", "pow_by_natural"] and inp1 == 0 and inp2 < 0: # ZeroDivisionError: 0.0 cannot be raised to a negative power return self.assertRaises((ZeroDivisionError,)) elif ( - fn == "pow" + # TODO: dear catastrophe waitress, + # this doesn't work + fn in ["float_pow", "pow_by_natural"] and inp1 < 0 - and inp2 in (2.5, -2.5) and ( - type(inp1) in (SymFloat, SymInt) or type(inp2) in (SymFloat, SymInt) + type(inp1) is (SymInt, SymFloat) or type(inp2) is (SymInt, SymFloat) ) + and (type(inp1) is (SymFloat, float) or type(inp2) is (SymFloat, float)) ): # Complex result, which we do not support: # TypeError: Cannot convert complex to float - return self.assertRaises((TypeError,)) + return self.assertRaises((RuntimeError,)) elif fn in ("lshift", "rshift") and not ( isinstance(inp1, (SymInt, int)) and isinstance(inp2, (SymInt, int)) ): @@ -1080,6 +1116,9 @@ class TestSymNumberMagicMethods(TestCase): ) and fn in sym_node.only_float_magic_methods: self.skipTest(f"{fn} is not an int method") + if second_type == "float" and fn in ["mod"]: + self.skipTest(f"{fn} only handles int") + is_unary_fn = fn in sym_node.unary_methods or fn == "round" # Second argument is ignored for unary function. So only run for one type if is_unary_fn and second_type == "float": @@ -1251,112 +1290,15 @@ class TestFloorDiv(TestCase): yield (-x, -y) def test_floordiv_float_int(self): - values = ( - (2.5, 2.1), - (2.1, 2.5), - (2.0, 2.1), - (7, 2.5), - (2.1, 7), - (7, 2), - ) + values = ((7, 2),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( TestFloorDiv.python_floordiv(x, y), TestFloorDiv.torch_floordiv(x, y) ) - def test_floordiv_bool(self): - values = ( - (False, True), - (True, 2.5), - (2.5, True), - (False, 7), - (7, True), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # Compares to int since our FloorDiv has no bool support - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(int(x), int(y)), - ) - # Tests that our impl throws - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_complex(self): - values = ( - (1.5 + 2.5j, 1.3 + 3.5j), - (1.5 + 2.5j, 2.5), - (2.5, 1.5 + 2.5j), - (1.5 + 2.5j, 7), - (7, 1.5 + 2.5j), - ) - - for x, y in TestFloorDiv.yield_test_cases(values): - # We don't test error messages to avoid depending on Python - # interpreter version - self.assertRaises(TypeError, lambda: TestFloorDiv.python_floordiv(x, y)) - self.assertRaisesRegex( - TypeError, - ( - rf"unsupported operand type\(s\) for //: " - rf"'{type(sympy.sympify(x)).__name__}' and '{type(sympy.sympify(y)).__name__}'" - rf", expected integer or real" - ), - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_div_by_zero(self): - values = ( - (2.5, 0), - (2.1, 0.0), - (2.3, sympy.Symbol("s", zero=True)), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - # We don't test error messages to avoid depending on Python - # interpreter version - if type(y) is not sympy.Symbol: - self.assertRaises( - ZeroDivisionError, lambda: TestFloorDiv.python_floordiv(x, y) - ) - self.assertRaisesRegex( - ZeroDivisionError, - "division by zero", - lambda: TestFloorDiv.torch_floordiv(x, y), - ) - - def test_floordiv_zero_base(self): - values = ( - (0, 2.5), - (0.0, 2.1), - (sympy.Symbol("s", zero=True), 2.3), - ) - - for x, y in TestFloorDiv.yield_test_cases(values, negate=False): - if type(x) is not sympy.Symbol: - self.assertEqual( - TestFloorDiv.python_floordiv(x, y), - TestFloorDiv.torch_floordiv(x, y), - ) - else: - self.assertEqual(0, TestFloorDiv.torch_floordiv(x, y)) - def test_floordiv_div_by_one(self): - values = ( - (2.5, 1), - (2.1, 1.0), - (2, 1.0), - (2, 1), - ) + values = ((2, 1),) for x, y in TestFloorDiv.yield_test_cases(values): self.assertEqual( @@ -1367,12 +1309,7 @@ class TestFloorDiv(TestCase): # Tests how we simplify or evaluate FloorDiv without free variables shape_env = ShapeEnv() result = 21 - exprs = ( - 7 * FloorDiv(6, 2), - 7 * FloorDiv(6.28, 2), - 7 * FloorDiv(6.28, 2.0), - 7 * FloorDiv(6.28, (FloorDiv(6.28, 3.14))), - ) + exprs = (7 * FloorDiv(6, 2),) for expr in exprs: self.assertEqual(expr, result) @@ -1382,33 +1319,10 @@ class TestFloorDiv(TestCase): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) - def test_floordiv_simplify_rational(self): - result = 21 - - a = sympy.Symbol("a", integer=True) - b = sympy.Symbol("b") - - cases = [ - (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), - (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), - ] - - for expr, expected in cases: - self.assertEqual(expr, expected) - def test_floordiv_assumptions(self): - # We define two Symbols (with different names) for each type to make - # sure the behavior is consistent regardless of whether both arguments - # are the same object or not. cases = ( sympy.Symbol("i1", integer=True), sympy.Symbol("i2", integer=True), - sympy.Symbol("r1", real=True), - sympy.Symbol("r2", real=True), - sympy.Symbol("c1", complex=True, real=False, integer=False), - sympy.Symbol("c2", complex=True, real=False, integer=False), - sympy.Symbol("s1"), - sympy.Symbol("s2"), ) for base, divisor in itertools.product(cases, repeat=2): diff --git a/test/test_proxy_tensor.py b/test/test_proxy_tensor.py index c7b2e51ced2..04483ffba0f 100644 --- a/test/test_proxy_tensor.py +++ b/test/test_proxy_tensor.py @@ -1618,7 +1618,8 @@ def forward(self, lengths_1, values_1): self.assertExpectedInline(r, """\ def forward(self, a_1): sym_size_int = torch.ops.aten.sym_size.int(a_1, 0) - pow_1 = sym_size_int ** 0.5; sym_size_int = None + sym_float = torch.sym_float(sym_size_int); sym_size_int = None + pow_1 = sym_float ** 0.5; sym_float = None div = torch.ops.aten.div.Tensor(a_1, pow_1); a_1 = pow_1 = None return div""") diff --git a/test/test_sympy_utils.py b/test/test_sympy_utils.py index c5da8f7fc0d..8b16b2c620f 100644 --- a/test/test_sympy_utils.py +++ b/test/test_sympy_utils.py @@ -36,7 +36,12 @@ UNARY_OPS = [ "floor", "ceil", ] -BINARY_OPS = ["truediv", "div", "floordiv", "truncdiv", "add", "mul", "sub", "pow", "minimum", "maximum", "mod"] +BINARY_OPS = [ + "truediv", "floordiv", + # "truncdiv", # TODO + # NB: pow is float_pow + "add", "mul", "sub", "pow", "pow_by_natural", "minimum", "maximum", "mod" +] UNARY_BOOL_OPS = ["not_"] BINARY_BOOL_OPS = ["or_", "and_"] @@ -81,16 +86,24 @@ def valid_unary(fn, v): def valid_binary(fn, a, b): if fn == "pow" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big b > 4 - or ( # sympy will expand to x*x*... for integral b; don't do it if it's big - a <= 0 and b == -1 - ) - or (a == b == 0) # no imaginary numbers # 0**0 is undefined + # no imaginary numbers + or a <= 0 + # 0**0 is undefined + or (a == b == 0) ): return False - elif fn == "mod" and b == 0: + elif fn == "pow_by_natural" and ( + # sympy will expand to x*x*... for integral b; don't do it if it's big + b > 4 + or b < 0 + or (a == b == 0) + ): return False - elif (fn == "div" or fn == "truediv") and b == 0: + elif fn == "mod" and (a < 0 or b <= 0): + return False + elif (fn in ["div", "truediv", "floordiv"]) and b == 0: return False return True @@ -130,27 +143,26 @@ class TestValueRanges(TestCase): ValueRangeAnalysis.pow(ValueRanges.unknown(), ValueRanges.wrap(0.5)) @parametrize("fn", BINARY_OPS) - @parametrize("dtype_a", ("int", "float")) - @parametrize("dtype_b", ("int", "float")) - def test_binary_ref(self, fn, dtype_a, dtype_b): + @parametrize("dtype", ("int", "float")) + def test_binary_ref(self, fn, dtype): to_dtype = {"int": sympy.Integer, "float": sympy.Float} - dtype_a = to_dtype[dtype_a] - dtype_b = to_dtype[dtype_b] + # Don't test float on int only methods + if dtype == "float" and fn in ["pow_by_natural", "mod"]: + return + dtype = to_dtype[dtype] for a, b in itertools.product(CONSTANTS, repeat=2): if not valid_binary(fn, a, b): continue - a = dtype_a(a) - b = dtype_b(b) + a = dtype(a) + b = dtype(b) with self.subTest(a=a, b=b): r = getattr(ValueRangeAnalysis, fn)(a, b) if r == ValueRanges.unknown(): continue ref_r = getattr(ReferenceAnalysis, fn)(a, b) - # sympy.floordiv does 1.0 // 1.0 == 1 rather than 1.0. wtf - if fn != "floordiv": - self.assertEqual(r.lower.is_integer, r.upper.is_integer) - self.assertEqual(ref_r.is_integer, r.upper.is_integer) + self.assertEqual(r.lower.is_integer, r.upper.is_integer) + self.assertEqual(ref_r.is_integer, r.upper.is_integer) self.assertEqual(r.lower, r.upper) self.assertEqual(ref_r, r.lower) @@ -200,7 +212,8 @@ class TestValueRanges(TestCase): @parametrize("fn", UNARY_OPS) def test_unary_ref_range(self, fn): - vals = [-sympy.oo, *CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = CONSTANTS for a in generate_range(vals): with self.subTest(a=a): ref_r = getattr(ValueRangeAnalysis, fn)(a) @@ -216,40 +229,26 @@ class TestValueRanges(TestCase): # This takes about 4s for all the variants @parametrize("fn", BINARY_OPS + COMPARE_OPS) def test_binary_ref_range(self, fn): - vals = [-sympy.oo, *LESS_CONSTANTS, sympy.oo] + # TODO: bring back sympy.oo testing for float unary fns + vals = LESS_CONSTANTS for a, b in itertools.product(generate_range(vals), repeat=2): # don't attempt pow on exponents that are too large (but oo is OK) if fn == "pow" and b.upper > 4 and b.upper != sympy.oo: continue with self.subTest(a=a, b=b): - ref_r = getattr(ValueRangeAnalysis, fn)(a, b) for a0, b0 in itertools.product(LESS_CONSTANTS, repeat=2): if a0 not in a or b0 not in b: continue if not valid_binary(fn, a0, b0): continue with self.subTest(a0=a0, b0=b0): + ref_r = getattr(ValueRangeAnalysis, fn)(a, b) r = getattr(ReferenceAnalysis, fn)( sympy.Integer(a0), sympy.Integer(b0) ) if r.is_finite: self.assertIn(r, ref_r) - def test_rational_bounds(self): - # Repro from https://github.com/pytorch/pytorch/issues/105097 - from sympy import floor, Eq - shape_0 = sympy.Symbol('shape_0', positive=True, integer=True) - new_expr = ( - Eq(30 * floor(4 * ((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647 + - 2584 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 647), - 2880 * floor(((shape_0 + 1) // 96) * - ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 15528 + - 323 * ((shape_0 + 62017) // (((shape_0 + 1) // 96) + 646)) / 7764))) - new_range_env = {shape_0: ValueRanges(lower=1, upper=190)} - self.assertTrue(new_expr.subs({shape_0: 95})) - self.assertIn(True, sympy_interp(ValueRangeAnalysis, new_range_env, new_expr)) - class TestSympyInterp(TestCase): @parametrize("fn", UNARY_OPS + BINARY_OPS + UNARY_BOOL_OPS + BINARY_BOOL_OPS + COMPARE_OPS) @@ -258,7 +257,13 @@ class TestSympyInterp(TestCase): if fn in ("div", "truncdiv", "minimum", "maximum", "mod"): return - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) + vals = CONSTANTS if fn in {*UNARY_BOOL_OPS, *BINARY_BOOL_OPS}: vals = [True, False] @@ -300,29 +305,17 @@ class TestSympyInterp(TestCase): if fn in {*BINARY_OPS, *BINARY_BOOL_OPS, *COMPARE_OPS}: arity = 2 - from sympy.abc import x, y + is_integer = None + if fn == "pow_by_natural": + is_integer = True + + x = sympy.Dummy('x', integer=is_integer) + y = sympy.Dummy('y', integer=is_integer) symbols = [x] if arity == 2: symbols = [x, y] - # Workaround mpf from symbol error - if fn == "minimum": - sympy_expr = sympy.Min(x, y) - elif fn == "maximum": - sympy_expr = sympy.Max(x, y) - else: - sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) - - if arity == 1: - def trace_f(px): - return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) - else: - def trace_f(px, py): - return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) - - gm = fx.symbolic_trace(trace_f) - for args in itertools.product(vals, repeat=arity): if arity == 1 and not valid_unary(fn, *args): continue @@ -330,11 +323,28 @@ class TestSympyInterp(TestCase): continue if fn == "truncdiv" and args[1] == 0: continue - elif fn == "pow" and (args[0] == 0 and args[1] <= 0): + elif fn in ("pow", "pow_by_natural") and (args[0] == 0 and args[1] <= 0): continue elif fn == "floordiv" and args[1] == 0: continue with self.subTest(args=args): + # Workaround mpf from symbol error + if fn == "minimum": + sympy_expr = sympy.Min(x, y) + elif fn == "maximum": + sympy_expr = sympy.Max(x, y) + else: + sympy_expr = getattr(ReferenceAnalysis, fn)(*symbols) + + if arity == 1: + def trace_f(px): + return sympy_interp(PythonReferenceAnalysis, {x: px}, sympy_expr) + else: + def trace_f(px, py): + return sympy_interp(PythonReferenceAnalysis, {x: px, y: py}, sympy_expr) + + gm = fx.symbolic_trace(trace_f) + self.assertEqual( sympy_interp(PythonReferenceAnalysis, dict(zip(symbols, args)), sympy_expr), gm(*args) diff --git a/torch/__init__.py b/torch/__init__.py index 18f1752019e..dfb1da76739 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -316,6 +316,75 @@ class SymInt: # Magic methods installed by torch.fx.experimental.sym_node + def __round__(self, ndigits=None): + return self + + def __truediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__float_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_truediv__(other) + + def __rtruediv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rfloat_truediv__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_truediv__(other) + + def __floordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(sym_float(self) / other)) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__int_floordiv__(other) + + def __rfloordiv__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return torch.sym_float(math.floor(other / sym_float(self))) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + return self.__rint_floordiv__(other) + + # nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + def __pow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__pow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + # Guards! This guard is necessary because we need to know it to + # determine the output type of this operation + if other >= 0: + return self.__pow_by_natural__(other) + else: + # Mercifully, when the exponent is negative, Python just promotes + # to doubles and does a float pow: + # + # if (Py_SIZE(b) < 0 && c == NULL) { + # /* if exponent is negative and there's no modulus: + # return a float. This works because we know + # that this calls float_pow() which converts its + # arguments to double. */ + # Py_DECREF(a); + # Py_DECREF(b); + # return PyFloat_Type.tp_as_number->nb_power(v, w, x); + # } + return sym_float(self).__pow__(sym_float(other)) + + def __rpow__(self, other): + if isinstance(other, (builtins.float, SymFloat)): + return sym_float(self).__rpow__(other) + if not isinstance(other, (builtins.int, SymInt)): + return NotImplemented + if self >= 0: # self is exponent + return self.__rpow_by_natural__(other) + else: + return sym_float(self).__rpow__(sym_float(other)) + def __eq__(self, other: object) -> builtins.bool: raise AssertionError("type stub not overridden") @@ -337,6 +406,24 @@ class SymInt: def __mul__(self, other) -> "SymInt": raise AssertionError("type stub not overridden") + def __pow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __rpow_by_natural__(self, other) -> "SymInt": + raise AssertionError("type stub not overridden") + + def __int_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __int_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rint_floordiv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __sym_max__(self, other): raise AssertionError("type stub not overridden") @@ -371,9 +458,43 @@ class SymFloat: # class has a field named node that stores SymNode self.node = node + def __truediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__float_truediv__(sym_float(other)) + + def __rtruediv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return self.__rfloat_truediv__(sym_float(other)) + + def __floordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(self / sym_float(other))) + + def __rfloordiv__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + return torch.sym_float(math.floor(sym_float(other) / self)) + def __bool__(self): return self.node.bool_() + # Symbolic power does NOT work with negative base, this is to avoid + # potential complex outputs + def __pow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(self >= 0) + return self.__float_pow__(other) + + def __rpow__(self, other): + if not isinstance(other, (builtins.int, builtins.float, SymInt, SymFloat)): + return NotImplemented + torch._check(other >= 0) + return self.__rfloat_pow__(other) + # Magic methods installed by torch.fx.experimental.sym_node def __eq__(self, other: object) -> builtins.bool: @@ -391,6 +512,18 @@ class SymFloat: def __ge__(self, other) -> builtins.bool: raise AssertionError("type stub not overridden") + def __float_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_pow__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __float_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + + def __rfloat_truediv__(self, other) -> "SymFloat": + raise AssertionError("type stub not overridden") + def __trunc__(self): raise AssertionError("type stub not overridden") @@ -524,7 +657,12 @@ def sym_int(a): return py_int(a) # type: ignore[operator] def sym_max(a, b): - """ SymInt-aware utility for max().""" + """ + SymInt-aware utility for max which avoids branching on a < b. + Unlike builtins.max(), this only works for int/float, and it always + promotes to float if any argument is float (unlike builtins.max, which + will faithfully preserve the type of the input argument). + """ from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -532,14 +670,19 @@ def sym_max(a, b): if isinstance(a, (SymInt, SymFloat)): return a.__sym_max__(b) elif isinstance(b, (SymInt, SymFloat)): - # NB: If you actually care about preserving output type exactly - # if you do something like max(0, 0.0), it is NOT sound to treat - # min/max as commutative + # Due to promotion semantics, this is operator is commutative: + # max(1, 1.0) === max(1.0, 1) === 1.0 return b.__sym_max__(a) - return builtins.max(a, b) # type: ignore[operator] + # TODO: Probably can make bool work too, just lazy + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.max(a, b)) + else: + return builtins.max(a, b) def sym_min(a, b): - """ SymInt-aware utility for max().""" + """ SymInt-aware utility for min().""" from .overrides import has_torch_function, handle_torch_function if has_torch_function((a, b)): @@ -548,7 +691,12 @@ def sym_min(a, b): return a.__sym_min__(b) elif isinstance(b, (SymInt, SymFloat)): return b.__sym_min__(a) - return builtins.min(a, b) # type: ignore[operator] + assert isinstance(a, (builtins.int, builtins.float)), type(a) + assert isinstance(b, (builtins.int, builtins.float)), type(b) + if isinstance(a, builtins.float) or isinstance(b, builtins.float): + return builtins.float(builtins.min(a, b)) + else: + return builtins.min(a, b) # Drop in replacement for math.sqrt, math.sin, math.cos etc current_module = sys.modules[__name__] diff --git a/torch/_export/serde/serialize.py b/torch/_export/serde/serialize.py index 8d6dc939fb5..9a92c238f95 100644 --- a/torch/_export/serde/serialize.py +++ b/torch/_export/serde/serialize.py @@ -1474,10 +1474,15 @@ class GraphModuleDeserializer(metaclass=Final): # Here we force symbols corresponding to SymInts to be at least integers. # Otherwise some expressions that the shape env would otherwise evaluate to False, # e.g., 2*s = 9, can have rational solutions, e.g., 9/2. + # TODO: This is HIGHLY SUSPICIOUS ezyang(May 2024) sym = sym.subs( {s: sympy.Symbol(s.name, integer=True) for s in sym.free_symbols} ) - if isinstance(sym, sympy.Symbol): + # We need to check if the symbol has already been allocated, + # self.symbol_name_to_symbol is not enough because the + # integer-ification of symbols can induce simplification; + # e.g., (2**s0 + 1) // 2 --> s0 when we know s0 is integral + if isinstance(sym, sympy.Symbol) and sym not in self.shape_env.var_to_val: self.symbol_name_to_symbol[val.expr_str] = sym if hint is not None: self.shape_env.add_var_to_val(sym, hint) @@ -1496,7 +1501,7 @@ class GraphModuleDeserializer(metaclass=Final): free_symbols = sym.free_symbols for s in free_symbols: if s.name not in self.symbol_name_to_symbol: - self.symbol_name_to_symbol[s.name] = s + self.symbol_name_to_symbol[s.name] = s # type: ignore[assignment] if vr := self.symbol_name_to_range.get(s.name): self.shape_env.constrain_symbol_range( s, diff --git a/torch/_inductor/bounds.py b/torch/_inductor/bounds.py index 4640ec4dce6..212b79e35bf 100644 --- a/torch/_inductor/bounds.py +++ b/torch/_inductor/bounds.py @@ -1,3 +1,4 @@ +import logging import operator from functools import partial from typing import Any, Callable, Dict @@ -11,6 +12,9 @@ from .utils import cache_on_self, dominated_nodes from .virtualized import V +log = logging.getLogger(__name__) + + class BoundVars: """ Performs Value Range Analysis on LoopBody's fx graph by calling BoundVars.run() @@ -55,6 +59,7 @@ class BoundVars: with V.set_ops_handler(ValueRangeAnalysis()): interpreter = InterpreterShim(self.loop_body.root_block.graph, submodules) + log.debug("get_bounds:\n%s", self.loop_body.root_block.graph) interpreter.run(V.get_ops_handler(), initial_env=self._bounds) return self._bounds diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py index f7b3e7a45d6..dae72186df0 100644 --- a/torch/_inductor/codegen/common.py +++ b/torch/_inductor/codegen/common.py @@ -340,6 +340,8 @@ class DataTypePropagation: DataTypePropagation.propagate_loopbody(node._body) +# This printer contains rules that are supposed to be generic for both C/C++ and +# Python class ExprPrinter(Printer): @staticmethod def paren(string): @@ -369,12 +371,6 @@ class ExprPrinter(Printer): return string return f"({string})" - def _print_Infinity(self, expr): - return "math.inf" - - def _print_NegativeInfinity(self, expr): - return "-math.inf" - def _print_Relational(self, expr): return f" {expr.rel_op} ".join(map(self.paren, map(self._print, expr.args))) @@ -384,11 +380,14 @@ class ExprPrinter(Printer): def _print_Add(self, expr): return " + ".join(map(self.paren, map(self._print, expr.args))) + # NB: this is OK to put here, because Mod is only defined for positive + # numbers, and so across C/Python its behavior is consistent def _print_Mod(self, expr): return " % ".join(map(self.paren, map(self._print, expr.args))) - def _print_FloorDiv(self, expr): - raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" def _print_CleanDiv(self, expr): return self._print_FloorDiv(expr) @@ -399,10 +398,84 @@ class ExprPrinter(Printer): # Go figure... return " >= ".join(map(self.paren, map(self._print, expr.args))) + # NB: The C implementation is injected into codegen at + # torch/_inductor/codegen/wrapper.py def _print_align(self, expr): assert len(expr.args) == 1 return f"align({self._print(expr.args[0])})" + # This must be implemented because sympy will collect x * x into Pow(x, 2), without + # any explicit intervention. We print it just like x * x, notably, we + # never generate sympy.Pow with floats. + # + # NB: this pow by natural, you should never have used builtin sympy.pow + # for FloatPow, and a symbolic exponent should be PowByNatural. These + # means exp is guaranteed to be integer. + def _print_Pow(self, expr): + base, exp = expr.args + base = self._print(base) + assert exp == int(exp), exp + exp = int(exp) + assert exp >= 0 + if exp > 0: + return "*".join([self.paren(base)] * exp) + else: # exp == 0 + return "1" + + # Explicit NotImplemented functions are to prevent default sympy printing + # behavior, which will just barf out ToFloat(...) to your IR. The error + # message is better here because it tells you which printer class it needs + # to go in. + + def _print_ToFloat(self, expr): + raise NotImplementedError(f"_print_ToFloat not implemented for {type(self)}") + + def _print_Infinity(self, expr): + raise NotImplementedError(f"_print_Infinity not implemented for {type(self)}") + + def _print_NegativeInfinity(self, expr): + raise NotImplementedError( + f"_print_NegativeInfinity not implemented for {type(self)}" + ) + + def _print_FloorDiv(self, expr): + raise NotImplementedError(f"_print_FloorDiv not implemented for {type(self)}") + + def _print_PythonMod(self, expr): + raise NotImplementedError(f"_print_PythonMod not implemented for {type(self)}") + + def _print_IntTrueDiv(self, expr): + raise NotImplementedError(f"_print_IntTrueDiv not implemented for {type(self)}") + + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatPow(self, expr): + raise NotImplementedError(f"_print_FloatPow not implemented for {type(self)}") + + def _print_TruncToInt(self, expr): + raise NotImplementedError(f"_print_TruncToInt not implemented for {type(self)}") + + def _print_RoundToInt(self, expr): + raise NotImplementedError(f"_print_RoundToInt not implemented for {type(self)}") + + def _print_RoundDecimal(self, expr): + raise NotImplementedError( + f"_print_RoundDecimal not implemented for {type(self)}" + ) + + # NB: Some float operations are INTENTIONALLY not implemented for + # printers. You can implement them as a quick unblock, but it is better + # to ask yourself why we haven't done this computation in the Tensor + # universe instead + + def _print_TruncToFloat(self, expr): + raise NotImplementedError( + f"_print_TruncToFloat not implemented for {type(self)}" + ) + def doprint(self, expr, *, simplify: bool = True): # TODO: why are people passing strings to the printer here :think: if simplify and isinstance(expr, sympy.Expr) and hasattr(V.graph, "sizevars"): @@ -411,6 +484,10 @@ class ExprPrinter(Printer): class PythonPrinter(ExprPrinter): + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"float({self._print(expr.args[0])})" + def _print_ModularIndexing(self, expr): x, div, mod = expr.args x = self.paren(self.doprint(x)) @@ -420,56 +497,72 @@ class PythonPrinter(ExprPrinter): x = f"({x} // {div})" return f"{x} % {mod}" + def _print_Infinity(self, expr): + return "math.inf" + + def _print_NegativeInfinity(self, expr): + return "-math.inf" + + # WARNING: this is dangerous for Triton, which has C-style modulus + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # WARNING: this is dangerous for Triton, which has C-style modulus def _print_FloorDiv(self, expr): x, div = expr.args x = self.paren(self.doprint(x)) div = self.paren(self.doprint(div)) return f"({x} // {div})" + # WARNING: this is dangerous for Triton, when lhs, rhs > 2**53, Python + # does a special algorithm + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + def _helper_sqrt(self, expr): return f"math.sqrt({self._print(expr)})" def _print_OpaqueUnaryFn_sqrt(self, expr): return self._helper_sqrt(expr.args[0]) - def _print_Pow(self, expr): - # Pow() confuses triton + def _print_FloatPow(self, expr): base, exp = expr.args - # NB: Remember this is sizevar computation! You don't typically - # expect to have to do floating point computation including exponents - # in sizevar compute. Instead of adding support for floating - # point pow, you should make upstream retranslate the Sympy expression - # into Tensor expressions earlier and do that instead. - if exp == 0.5: - return self._helper_sqrt(base) - elif exp == -0.5: - return "1/" + self._helper_sqrt(base) - base = self._print(base) - assert exp == int(exp), exp - exp = int(exp) - if exp > 0: - return "*".join([self.paren(base)] * exp) - elif exp < 0: - return "1/" + self.paren("*".join([self.paren(base)] * abs(exp))) - else: # exp == 0 - return "1" + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" + + # TODO: Not sure this works with Triton, even when base/exp are integral + def _print_PowByNatural(self, expr): + base, exp = expr.args + return f"{self.paren(self._print(base))} ** {self.paren(self._print(exp))}" def _print_floor(self, expr): assert len(expr.args) == 1 return f"math.floor({self._print(expr.args[0])})" - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 + return f"math.floor({self._print(expr.args[0])})" + + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + # This also could have been int(), they'll do the same thing for float return f"math.trunc({self._print(expr.args[0])})" def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"math.ceil({self._print(expr.args[0])})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"math.ceil({self._print(expr.args[0])})" + def _print_Abs(self, expr): assert len(expr.args) == 1 return f"abs({self._print(expr.args[0])})" + # NB: It's expected that we've made explicit any promotion in the sympy + # expression, so it doesn't matter that Python max/min doesn't perform + # promotion def _print_Max(self, expr): assert len(expr.args) >= 2 return f"max({', '.join(map(self._print, expr.args))})" @@ -514,7 +607,7 @@ class PythonPrinter(ExprPrinter): assert len(expr.args) == 1 return f"math.atan({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 return f"round({self._print(expr.args[0])})" @@ -653,6 +746,29 @@ class OpOverrides: ) return ops.where(cond, ops.add(r, b), r) + @staticmethod + def trunc_to_int(a, dtype): + return ops.to_dtype(ops.trunc(a), dtype) + + @staticmethod + def floor_to_int(a, dtype): + return ops.to_dtype(ops.floor(a), dtype) + + @staticmethod + def ceil_to_int(a, dtype): + return ops.to_dtype(ops.ceil(a), dtype) + + @staticmethod + def round_to_int(a, dtype): + return ops.to_dtype(ops.round(a), dtype) + + @staticmethod + def int_truediv(a, b): + # TODO: this is wrong + # TODO: an easy bandaid is to generate runtime asserts that it's + # <= 2**53, which is when this equation is correct + return ops.truediv(a, b) + @staticmethod def load_seed(name, offset): return ops.load(name, sympy.Integer(offset)) diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py index eabb5bbef47..311781102c3 100644 --- a/torch/_inductor/codegen/cpp.py +++ b/torch/_inductor/codegen/cpp.py @@ -275,11 +275,11 @@ def simplify_index_in_vec_range(index: sympy.Expr, var: sympy.Expr, vec_length: original_index = index - div = sympy.Wild("divisor") + div = sympy.Wild("divisor", integer=True) if index.has(FloorDiv): index = index.replace(FloorDiv(var, div), visit_indexing_div) - mod = sympy.Wild("modulus") + mod = sympy.Wild("modulus", integer=True) if index.has(ModularIndexing): index = index.replace(ModularIndexing(var, div, mod), visit_modular_indexing) diff --git a/torch/_inductor/codegen/cpp_utils.py b/torch/_inductor/codegen/cpp_utils.py index 4ab33a5e26d..aac0c20df0c 100644 --- a/torch/_inductor/codegen/cpp_utils.py +++ b/torch/_inductor/codegen/cpp_utils.py @@ -100,10 +100,53 @@ class CppPrinter(ExprPrinter): r = f"std::floor({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::floor({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + + def _print_TruncToInt(self, expr): assert len(expr.args) == 1 r = f"std::trunc({self._print(expr.args[0])})" - return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + return f"static_cast<{INDEX_TYPE}>({r})" + + def _print_TruncToFloat(self, expr): + assert len(expr.args) == 1 + return f"std::trunc({self._print(expr.args[0])})" + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"static_cast({self._print(expr.args[0])})" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_CMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + # TODO: This is only accurate up to 2**53 + return f"static_cast({self._print(lhs)}) / static_cast({self._print(rhs)})" + + # TODO: PowByNatural: we need to implement our own int-int pow. Do NOT + # use std::pow, that operates on floats + def _print_PowByNatural(self, expr): + raise NotImplementedError( + f"_print_PowByNatural not implemented for {type(self)}" + ) + + def _print_FloatTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + def _print_FloatPow(self, expr): + base, exp = expr.args + return f"std::pow({self._print(base)}, {self._print(exp)})" def _print_Pow(self, expr): # Uses float constants to perform FP div @@ -139,6 +182,11 @@ class CppPrinter(ExprPrinter): r = f"std::ceil({self._print(expr.args[0])})" return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + r = f"std::ceil({self._print(expr.args[0])})" + return f"static_cast<{INDEX_TYPE}>({r})" if expr.is_integer else r + def _print_Min(self, expr): args = [self._print(a) for a in expr.args] if len(args) == 2: @@ -200,8 +248,9 @@ class CppPrinter(ExprPrinter): def _print_OpaqueUnaryFn_sqrt(self, expr): return f"std::sqrt({self._print(expr.args[0])})" - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 + # TODO: dispatch to llrint depending on index type return f"std::lrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): diff --git a/torch/_inductor/codegen/triton.py b/torch/_inductor/codegen/triton.py index 4b0ea92f3bf..f74086615c6 100644 --- a/torch/_inductor/codegen/triton.py +++ b/torch/_inductor/codegen/triton.py @@ -272,23 +272,68 @@ def triton_reshape(value: str, old_shape: List[str], new_shape: List[str]): return f"{value}[{', '.join(expand)}]" +# NB: Inheriting from PythonPrinter is somewhat dangerous, because there are a +# number of operators which Triton "implements", but in a way that is +# inconsistent with Python semantics (and consistent with C semantics). We +# must override all of these, or it is potential silent correctness problem class TritonPrinter(PythonPrinter): + def _print_TruncToInt(self, expr): + assert len(expr.args) == 1 + return ( + f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + ) + + def _print_ToFloat(self, expr): + assert len(expr.args) == 1 + return f"{self.paren(self._print(expr.args[0]))}.to(tl.float64)" + + # TODO: This is wrong if one of the inputs is negative. This is hard to + # tickle though, as the inputs are typically positive (and if we can prove + # they are positive, we will have used Mod instead, for which this codegen + # is right). If you are trying to hit this, maybe try something like + # torch.arange(n, device="cuda") - 1 and then do a modulus on it + def _print_PythonMod(self, expr): + return " % ".join(map(self.paren, map(self._print, expr.args))) + + # TODO: This is wrong, see + # https://github.com/triton-lang/triton/issues/955 + # But for Sympy expressions, things will /mostly/ work out because we + # don't usually deal with negative numbers in the division + def _print_FloorDiv(self, expr): + assert expr.is_integer + x, div = expr.args + x = self.paren(self.doprint(x)) + div = self.paren(self.doprint(div)) + return f"({x} // {div})" + + # TODO: This is wrong, when lhs, rhs > 2**53, Python does a higher + # precision algorithm, which we would need to replicate here + def _print_IntTrueDiv(self, expr): + lhs, rhs = expr.args + return f"{self.paren(self._print(lhs))} / {self.paren(self._print(rhs))}" + + # NB: sympy.floor/ceiling produce integers, so we have to do the + # conversion to index dtype def _print_floor(self, expr): assert len(expr.args) == 1 return ( f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) - def _print_Trunc(self, expr): + def _print_FloorToInt(self, expr): assert len(expr.args) == 1 return ( - f"libdevice.trunc({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + f"libdevice.floor({self._print(expr.args[0])}).to({V.kernel.index_dtype})" ) def _print_ceiling(self, expr): assert len(expr.args) == 1 return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _print_CeilToInt(self, expr): + assert len(expr.args) == 1 + return f"libdevice.ceil({self._print(expr.args[0])}).to({V.kernel.index_dtype})" + def _helper_sqrt(self, expr): return f"libdevice.sqrt({self._print(expr)}.to(tl.float32))" @@ -359,20 +404,9 @@ class TritonPrinter(PythonPrinter): assert len(expr.args) == 1 return f"libdevice.atan(({self._print(expr.args[0])}).to(tl.float32))" - def _print_FloorDiv(self, expr): - if expr.is_integer: - return super()._print_FloorDiv(expr) - - x, div = expr.args - x = self.paren(self.doprint(x)) - div = self.paren(self.doprint(div)) - return f"libdevice.floor({x} / {div}).to({V.kernel.index_dtype})" - - def _print_Round(self, expr): + def _print_RoundToInt(self, expr): assert len(expr.args) == 1 - return ( - f"libdevice.llrint({self._print(expr.args[0])}).to({V.kernel.index_dtype})" - ) + return f"libdevice.llrint({self._print(expr.args[0])})" def _print_RoundDecimal(self, expr): assert len(expr.args) == 2 diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py index 337a7375afa..abe93686ac8 100644 --- a/torch/_inductor/graph.py +++ b/torch/_inductor/graph.py @@ -1196,8 +1196,11 @@ class GraphLowering(torch.fx.Interpreter): elif is_magic_method(n.target): # TODO: this is sus, it probably should be handled in the # lowerings themselves similarly to sym_size/sym-stride + # https://github.com/pytorch/pytorch/issues/127789 debug("is_magic_method") - if isinstance(n.meta["val"], torch.SymInt): + if isinstance( + n.meta["val"], (torch.SymInt, torch.SymFloat, torch.SymBool) + ): result = n.meta["val"].node.expr else: result = super().run_node(n) diff --git a/torch/_inductor/ir.py b/torch/_inductor/ir.py index c46cad5e41e..e9adfcd19a2 100644 --- a/torch/_inductor/ir.py +++ b/torch/_inductor/ir.py @@ -44,7 +44,6 @@ from torch._prims_common import ( is_boolean_dtype, is_float_dtype, make_channels_last_strides_for, - make_contiguous_strides_for, StrideType, ) from torch._subclasses.fake_tensor import get_schema_info @@ -236,7 +235,7 @@ def ir_node_to_tensor(x, guard_shape=True): if is_storage_and_layout(x): stride = [shape_fn(s) for s in x.get_layout().stride] # type: ignore[misc] else: - stride = make_contiguous_strides_for(size) # type: ignore[arg-type] + stride = FlexibleLayout.contiguous_strides(size) # type: ignore[arg-type] dtype = x.get_dtype() device = x.get_device() size = convert_shape_to_symint(size) @@ -2766,6 +2765,7 @@ class FlexibleLayout(Layout): allow_indexing = False + # WARNING! This doesn't handle zero size tensors correctly @staticmethod def contiguous_strides(sizes): if len(sizes) == 0: @@ -5915,7 +5915,7 @@ def _prepare_convolution_fusion_create( # To align the behavior of the Conv kernel, we set the output_stride in such case to be contiguous instead of channels last. dynamic_shapes = not all(isinstance(i, int) for i in (output_size)) if dynamic_shapes and is_contiguous_storage_and_layout(x): - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) else: output_stride = make_channels_last_strides_for(output_size) @@ -5967,7 +5967,7 @@ def _prepare_linear_fusion_create( assert x.get_device().type == "cpu" and weight.get_device().type == "cpu" inputs = [x, weight] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) kernel_layout = FixedLayout( x.get_device(), x.get_dtype(), @@ -6283,7 +6283,7 @@ class MKLPackedLinear(ExternKernelAlloc): *m, _ = x.get_size() oc, _ = orig_w.get_size() output_size = list(m) + [oc] - output_stride = make_contiguous_strides_for(output_size) + output_stride = FlexibleLayout.contiguous_strides(output_size) inputs = [x, packed_w, orig_w] constant_args = [batch_size] if B is not None: @@ -6601,13 +6601,13 @@ class MkldnnRnnLayer(ExternKernelAlloc): def get_strides_of_lstm_output(output_shape, batch_first): assert len(output_shape) == 3, "Expect output_shape to be 3D" - return make_contiguous_strides_for(output_shape) + return FlexibleLayout.contiguous_strides(output_shape) output_sizes = [output_shape, hy_shape, cy_shape] output_strides = [ get_strides_of_lstm_output(output_shape, batch_first), - make_contiguous_strides_for(hy_shape), - make_contiguous_strides_for(cy_shape), + FlexibleLayout.contiguous_strides(hy_shape), + FlexibleLayout.contiguous_strides(cy_shape), ] output_ir = [ MultiOutput( diff --git a/torch/_inductor/kernel/flex_attention.py b/torch/_inductor/kernel/flex_attention.py index 42fabf65591..f3492949a84 100644 --- a/torch/_inductor/kernel/flex_attention.py +++ b/torch/_inductor/kernel/flex_attention.py @@ -5,7 +5,6 @@ from enum import auto, Enum from typing import Any, List, Tuple import torch -from torch._prims_common import make_contiguous_strides_for from .. import config from ..ir import ( ComputedBuffer, @@ -389,7 +388,7 @@ def flex_attention(*args, **kwargs): query.get_device(), query.get_dtype(), query.get_size(), - make_contiguous_strides_for(query.get_size()), + FlexibleLayout.contiguous_strides(query.get_size()), ) # see NOTE:[TritonTemplates with multiple outputs] logsumexp_shape = query.get_size()[:-1] # [B, H, M] @@ -745,7 +744,7 @@ def flex_attention_backward(*args, **kwargs): key.get_device(), key.get_dtype(), key.get_size(), - make_contiguous_strides_for(key.get_size()), + FlexibleLayout.contiguous_strides(key.get_size()), ) # Create delta which will is needed for the bwd's kernel diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py index 0a1909890e6..deec9b13e56 100644 --- a/torch/_inductor/lowering.py +++ b/torch/_inductor/lowering.py @@ -34,7 +34,7 @@ from torch._prims_common import ( Number, ) from torch.fx.experimental.sym_node import magic_methods, method_to_operator -from torch.utils._sympy.functions import CeilDiv, FloorDiv, ModularIndexing +from torch.utils._sympy.functions import CeilDiv, FloorDiv, IntTrueDiv, ModularIndexing from .._dynamo.utils import import_submodule from . import config, inductor_prims, ir, test_operators # NOQA: F401 @@ -4262,7 +4262,7 @@ def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim): out_sz = out_sz[dim] in_sz = in_sz[dim] kernel_sz = kernel_sz[dim] - alpha = (in_sz - kernel_sz) / (out_sz - 1) + alpha = IntTrueDiv(in_sz - kernel_sz, out_sz - 1) samples_loader = samples.make_loader() def load(prefix, i): @@ -4372,7 +4372,7 @@ def upsample_nearest2d_backward( w_kernel_max = ceildiv(inp_w, out_w) def start_index(index, out_dim, inp_dim): - return CeilDiv(index * inp_dim, out_dim) + return CeilDiv(index * inp_dim, sympy.sympify(out_dim)) def end_index(index, out_dim, inp_dim): return start_index((index + 1), out_dim, inp_dim) diff --git a/torch/_inductor/ops_handler.py b/torch/_inductor/ops_handler.py index 5630061b442..f88cd948ca4 100644 --- a/torch/_inductor/ops_handler.py +++ b/torch/_inductor/ops_handler.py @@ -138,6 +138,38 @@ class OpsHandler(Protocol[T]): """ ... + def trunc_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with truncation semantics (similar to how the int + constructor works in Python). In Inductor codegen, this just decays + to trunc and then to_dtype, but this composite operation helps + roundtrips for Sympy evaluation. + + dtype is taken as an explicit parameter because the desired output + dtype is typically the index dtype, which may vary between int32 and + int64 depending on if we've shown that all the indexing operations can + be done in int32. + """ + ... + + def ceil_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def floor_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with ceiling semantics. See also trunc_to_int. + """ + ... + + def round_to_int(self, x: T, dtype: torch.dtype) -> T: + """ + Convert x to dtype with round-to-even semantics. See also trunc_to_int. + """ + ... + def to_dtype_bitcast(self, x: T, dtype: torch.dtype, src_dtype: torch.dtype) -> T: """ Reinterpret cast x to dtype (reinterpreting the bits in memory as another dtype.) @@ -398,21 +430,23 @@ class OpsHandler(Protocol[T]): def isnan(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation + # This rounds half to even to break ties def round(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def floor(self, x0: T) -> T: ... def sign(self, x0: T) -> T: ... - def to_int(self, x0: T) -> T: - ... - + # NB: this returns a float, like the torch operation def trunc(self, x0: T) -> T: ... + # NB: this returns a float, like the torch operation def ceil(self, x0: T) -> T: ... @@ -449,6 +483,7 @@ class OpsHandler(Protocol[T]): def mul(self, x0: T, x1: T) -> T: ... + # NB: this returns a float, like the torch operation def pow(self, x0: T, x1: T) -> T: ... @@ -617,14 +652,21 @@ class OpsHandler(Protocol[T]): def floordiv(self, x0: T, x1: T) -> T: """Python-style floor division between integers only. Computes the - true division of two numbers and floors the result. + true division of two numbers and floors the result. If you want + floor division for floats, do regular truediv and floor the result. """ ... def truediv(self, x0: T, x1: T) -> T: - """True division between floats. Integer inputs are NOT valid: to do - Python style (int, int) -> float division, promote the inputs to float - first.""" + """True division between floats. Integer inputs are NOT valid. To + do Python-style (int, int) -> float division, use int_truediv""" + ... + + def int_truediv(self, x0: T, x1: T) -> T: + """True division between integers. This is NOT the same as promoting + to float and doing integer division, there is a bespoke algorithm for + doing the division in higher precision than the above. + """ ... def div(self, x0: T, x1: T) -> T: @@ -640,6 +682,10 @@ class OpsHandler(Protocol[T]): """Python-style modulus, take sign from RHS (x1).""" ... + def round_decimal(self, x0: T, x1: T) -> T: + """Python-style round with decimal argument""" + ... + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # In CUDA, optimized implementations of other mathematical operations are # offered separately via libdevice for double precision computation (in diff --git a/torch/_inductor/select_algorithm.py b/torch/_inductor/select_algorithm.py index 5e5cbf35baf..a1b029aa288 100644 --- a/torch/_inductor/select_algorithm.py +++ b/torch/_inductor/select_algorithm.py @@ -386,7 +386,7 @@ class TritonTemplateKernel(TritonKernel): assert isinstance(mask, (str, type(None))) assert self.template_mask is None indices = list(map(TritonPrinter.paren, indices)) - index_symbols = [sympy.Symbol(x) for x in indices] + index_symbols = [sympy.Symbol(x, integer=True) for x in indices] lengths = [ V.graph.sizevars.simplify(s) for s in self.output_node.get_size() ] @@ -410,7 +410,7 @@ class TritonTemplateKernel(TritonKernel): output_index = self.output_node.get_layout().make_indexer()(index_symbols) output_index = self.rename_indexing(output_index) if output_index == contiguous_index: - output_index = sympy.Symbol("xindex") + output_index = sympy.Symbol("xindex", integer=True) epilogue_args = [val] for input_node in itertools.chain( diff --git a/torch/_inductor/sizevars.py b/torch/_inductor/sizevars.py index bc8803a5e71..fba9a66f923 100644 --- a/torch/_inductor/sizevars.py +++ b/torch/_inductor/sizevars.py @@ -161,9 +161,9 @@ class SizeVarAllocator: if expr.has(ModularIndexing): expr = expr.replace( ModularIndexing( - sympy.Wild("base"), - sympy.Wild("divisor"), - sympy.Wild("modulus"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), + sympy.Wild("modulus", integer=True), ), visit_modular_indexing, ) @@ -171,8 +171,8 @@ class SizeVarAllocator: if expr.has(FloorDiv): expr = expr.replace( FloorDiv( - sympy.Wild("base"), - sympy.Wild("divisor"), + sympy.Wild("base", integer=True), + sympy.Wild("divisor", integer=True), ), visit_indexing_div, ) @@ -604,11 +604,11 @@ def _join_dimensions_cached(expr: Expr) -> Expr: """ assert isinstance(expr, sympy.Add) - scale = sympy.Wild("scale", exclude=[0]) - base = sympy.Wild("base") - divisor = sympy.Wild("divisor") - mod1 = sympy.Wild("modulus") - mod2 = sympy.Wild("modulus2") + scale = sympy.Wild("scale", exclude=[0], integer=True) + base = sympy.Wild("base", integer=True) + divisor = sympy.Wild("divisor", integer=True) + mod1 = sympy.Wild("modulus", integer=True) + mod2 = sympy.Wild("modulus2", integer=True) for term1 in expr.args: m1 = term1.match(scale * ModularIndexing(base, divisor, mod1)) if m1: diff --git a/torch/_inductor/utils.py b/torch/_inductor/utils.py index 0915a8330c3..a635c2f509c 100644 --- a/torch/_inductor/utils.py +++ b/torch/_inductor/utils.py @@ -192,7 +192,7 @@ def ceildiv( numer: Union[int, sympy.Expr], denom: Union[int, sympy.Expr] ) -> Union[int, sympy.Expr]: if isinstance(numer, sympy.Expr) or isinstance(denom, sympy.Expr): - return CeilDiv(numer, denom) + return CeilDiv(sympy.sympify(numer), sympy.sympify(denom)) # TODO: There is a bug in a call to this function, to repro: # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # --amp --only YituTechConvBert --dynamic-shapes diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 47d4abcf77b..9343490de3e 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1727,7 +1727,7 @@ class FakeTensorMode(TorchDispatchMode): for run_impl_check, op_impl in op_implementations_checks: if run_impl_check(func): op_impl_out = op_impl(self, func, *args, **kwargs) - if op_impl_out != NotImplemented: + if op_impl_out is not NotImplemented: return maybe_propagate_real_tensors(op_impl_out) def maybe_run_unsafe_fallback(error=None): diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index a7ce337f9ac..2a3cb62c56d 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1200,8 +1200,13 @@ void initJITBindings(PyObject* module) { SYMNODE_BINARY(sub) SYMNODE_BINARY(mul) SYMNODE_BINARY(truediv) + SYMNODE_BINARY(int_truediv) + SYMNODE_BINARY(float_truediv) SYMNODE_BINARY(pow) + SYMNODE_BINARY(float_pow) + SYMNODE_BINARY(pow_by_natural) SYMNODE_BINARY(floordiv) + SYMNODE_BINARY(int_floordiv) SYMNODE_BINARY(mod) SYMNODE_BINARY(eq) SYMNODE_BINARY(ne) diff --git a/torch/csrc/utils/python_symnode.h b/torch/csrc/utils/python_symnode.h index f8c710cf657..15738b1a67e 100644 --- a/torch/csrc/utils/python_symnode.h +++ b/torch/csrc/utils/python_symnode.h @@ -198,14 +198,34 @@ class PythonSymNodeImpl : public c10::SymNodeImpl { return dispatch_common_(__func__, other); } + c10::SymNode float_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode int_truediv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode pow(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode float_pow(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + + c10::SymNode pow_by_natural(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode floordiv(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } + c10::SymNode int_floordiv(const c10::SymNode& other) override { + return dispatch_common_(__func__, other); + } + c10::SymNode mod(const c10::SymNode& other) override { return dispatch_common_(__func__, other); } diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index a4ed16e975b..ac2bdd60a55 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -1,7 +1,6 @@ import builtins import dataclasses import inspect -import math import sys import weakref from collections import defaultdict @@ -254,11 +253,14 @@ class _Constraint(_ConstraintTarget, metaclass=_ConstraintFactory): shared: Optional[_ConstraintTarget] = None debug_name: Optional[str] = None - def _clone_with_range(self, lower=0, upper=math.inf): + def _clone_with_range(self, lower=0, upper=None): # Import sympy locally from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges + if upper is None: + upper = sys.maxsize - 1 + constraint_range = StrictMinMaxConstraint( vr=self.constraint_range.vr & ValueRanges(lower=lower, upper=upper), warn_only=False, @@ -486,7 +488,6 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): ) # Import sympy locally - import sympy from torch.fx.experimental.symbolic_shapes import StrictMinMaxConstraint from torch.utils._sympy.value_ranges import ValueRanges @@ -496,7 +497,7 @@ def dynamic_dim(t: torch.Tensor, index: int, debug_name: Optional[str] = None): id(t), index, StrictMinMaxConstraint( - vr=ValueRanges(lower=0, upper=sympy.oo), warn_only=False + vr=ValueRanges(lower=0, upper=sys.maxsize - 1), warn_only=False ), debug_name=debug_name, ) diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index 4bf9ebab17b..28df3fddab0 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -277,7 +277,13 @@ def record_shapeenv_event(*, save_tracked_fakes: bool = False) -> Callable: raise except Exception: - log.error("failed while running %s(*%s, **%s)", name, args[1:], kwargs) + log.error( # noqa: G201 + "failed while running %s(*%s, **%s)", + name, + args[1:], + kwargs, + exc_info=log.isEnabledFor(logging.INFO), + ) raise return wrapper diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index 98cba67a73a..c7f0aba9fac 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -267,8 +267,11 @@ class SymNode: def mod(self, other) -> "SymNode": return self._mod(other) # type: ignore[attr-defined] - def pow(self, other) -> "SymNode": - return self._pow(other) # type: ignore[attr-defined] + def float_pow(self, other) -> "SymNode": + return self._float_pow(other) # type: ignore[attr-defined] + + def pow_by_natural(self, other) -> "SymNode": + return self._pow_by_natural(other) # type: ignore[attr-defined] def and_(self, other) -> "SymNode": return self._and_(other) # type: ignore[attr-defined] @@ -276,11 +279,14 @@ class SymNode: def or_(self, other) -> "SymNode": return self._or_(other) # type: ignore[attr-defined] - def truediv(self, other) -> "SymNode": - return self._truediv(other) # type: ignore[attr-defined] + def float_truediv(self, other) -> "SymNode": + return self._float_truediv(other) # type: ignore[attr-defined] - def floordiv(self, other) -> "SymNode": - return self._floordiv(other) # type: ignore[attr-defined] + def int_truediv(self, other) -> "SymNode": + return self._int_truediv(other) # type: ignore[attr-defined] + + def int_floordiv(self, other) -> "SymNode": + return self._int_floordiv(other) # type: ignore[attr-defined] def lshift(self, other) -> "SymNode": return self._lshift(other) # type: ignore[attr-defined] @@ -361,6 +367,17 @@ class SymNode: def sym_and(self, other): return self.and_(other) + # There is no int_truediv available from C++ + def truediv(self, other): + return self.float_truediv(other) + + def floordiv(self, other) -> "SymNode": + return self.int_floordiv(other) + + # We didn't bind integer pow in C++ + def pow(self, other): + return self.float_pow(other) + def is_non_overlapping_and_dense(self, sizes, strides): return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] @@ -477,7 +494,7 @@ METHOD_TO_OPERATOR = { "eq": operator.eq, "floor": math.floor, "trunc": math.trunc, - "floordiv": operator.floordiv, + "int_floordiv": operator.floordiv, "ge": operator.ge, "gt": operator.gt, "is_integer": lambda x: x.is_integer(), @@ -489,7 +506,8 @@ METHOD_TO_OPERATOR = { "ne": operator.ne, "neg": operator.neg, "or": operator.or_, - "pow": operator.pow, + "float_pow": operator.pow, + "pow_by_natural": operator.pow, "round": builtins.round, "rshift": operator.rshift, "sub": operator.sub, @@ -498,12 +516,14 @@ METHOD_TO_OPERATOR = { "sym_max": sym_max, "sym_min": sym_min, "sym_not": sym_not, - "truediv": operator.truediv, + "float_truediv": operator.truediv, + "int_truediv": operator.truediv, } unary_magic_methods = { "abs", "sym_float", + "sym_int", "ceil", "floor", "neg", @@ -559,20 +579,20 @@ also_bool_magic_methods = {"eq"} bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods # Methods that are only for float -only_float_magic_methods = {"is_integer"} +only_float_magic_methods = {"is_integer", "round", "sym_int"} magic_methods_on_operator_with_trailing_underscore = {"and", "or"} -always_float_magic_methods = {"truediv", "sym_float", "pow"} +always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"} for name in math_op_names: sym_name = f"sym_{name}" always_float_magic_methods.add(sym_name) -always_int_magic_methods = {"ceil", "floor", "trunc"} +always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"} always_bool_magic_methods = { "eq", "ne", @@ -590,10 +610,16 @@ always_bool_magic_methods = { # Methods that have a `__foo__` as well as `__rfoo__` -def _sympy_truediv(a, b): - from torch.utils._sympy.functions import TrueDiv +def _sympy_float_truediv(a, b): + from torch.utils._sympy.functions import FloatTrueDiv - return TrueDiv(a, b) + return FloatTrueDiv(a, b) + + +def _sympy_int_truediv(a, b): + from torch.utils._sympy.functions import IntTrueDiv + + return IntTrueDiv(a, b) def _sympy_floordiv(a, b): @@ -603,15 +629,24 @@ def _sympy_floordiv(a, b): def _sympy_mod(a, b): - from torch.utils._sympy.functions import Mod + from torch.utils._sympy.functions import Mod, PythonMod - return Mod(a, b) + if a.is_nonnegative and b.is_nonnegative: + return Mod(a, b) + else: + return PythonMod(a, b) -def _sympy_pow(a, b): - from torch.utils._sympy.functions import Pow +def _sympy_pow_by_natural(a, b): + from torch.utils._sympy.functions import PowByNatural - return Pow(a, b) + return PowByNatural(a, b) + + +def _sympy_float_pow(a, b): + from torch.utils._sympy.functions import FloatPow + + return FloatPow(a, b) def _sympy_and(a, b): @@ -643,11 +678,13 @@ reflectable_magic_methods = { "sub": operator.sub, "mul": operator.mul, "mod": _sympy_mod, - "pow": _sympy_pow, + "pow_by_natural": _sympy_pow_by_natural, + "float_pow": _sympy_float_pow, "and": _sympy_and, "or": _sympy_or, - "truediv": _sympy_truediv, - "floordiv": _sympy_floordiv, + "float_truediv": _sympy_float_truediv, + "int_truediv": _sympy_int_truediv, + "int_floordiv": _sympy_floordiv, "lshift": _sympy_lshift, "rshift": _sympy_rshift, } @@ -672,21 +709,23 @@ def _floor_ceil_helper(a, fn): def _sympy_floor(a): - import sympy + from torch.utils._sympy.functions import FloorToInt - return _floor_ceil_helper(a, sympy.floor) + return FloorToInt(a) +# NB: this is Python trunc semantics which returns an int. Do NOT use this to +# represent torch.trunc (which is float to float) def _sympy_trunc(a): - from torch.utils._sympy.functions import Trunc + from torch.utils._sympy.functions import TruncToInt - return Trunc(a) + return TruncToInt(a) def _sympy_ceil(a): - import sympy + from torch.utils._sympy.functions import CeilToInt - return _floor_ceil_helper(a, sympy.ceiling) + return CeilToInt(a) def _sympy_eq(a, b): @@ -771,26 +810,28 @@ def _sympy_abs(a): def _sympy_round(number, ndigits=None): - from torch.utils._sympy.functions import Round, RoundDecimal + from torch.utils._sympy.functions import RoundDecimal, RoundToInt if ndigits is None: - return Round(number) + return RoundToInt(number) else: return RoundDecimal(number, ndigits) def _sympy_sym_float(a): - # Cannot use sympy.Float(a) here, coz it expects python literals - # Multiply by 1.0 to cast to float. This is needed when the input - # is a SymInt which has the assumption that it is integer and - # SymPy will otherwise assume that return value cannot be a float. - return a * 1.0 + from torch.utils._sympy.functions import ToFloat + + # NB: Cannot use a * 1.0 here, because 0 * 1.0 is 0 which incorrectly + # reports that it is an integer + return ToFloat(a) def _sympy_is_integer(a): import sympy - return sympy.Eq(sympy.floor(a), a) + from torch.utils._sympy.functions import ToFloat + + return sympy.Eq(ToFloat(sympy.floor(a)), a) magic_methods = { @@ -989,9 +1030,26 @@ def _make_node_magic(method, func): self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {}) ) assert isinstance(other, SymNode) - # TODO: consider constant prop here try: - out = func(self.expr, other.expr) + if method == "mod": + from torch.utils._sympy.functions import Mod, PythonMod + + # Special handling for mod that requires access to the value + # ranges + shape_env = self.shape_env + if ( + self.expr.is_nonnegative + or shape_env.bound_sympy(self.expr).lower >= 0 + ) and ( + other.expr.is_nonnegative + or shape_env.bound_sympy(other.expr).lower >= 0 + ): + out = Mod(self.expr, other.expr) + else: + out = PythonMod(self.expr, other.expr) + else: + # TODO: consider constant prop here + out = func(self.expr, other.expr) except Exception: log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr) raise @@ -1122,9 +1180,13 @@ def _make_node_magic(method, func): except Exception: log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits) raise + out = safe_expand(out) - pytype = int if ndigits is None else self.pytype + if ndigits is None: + pytype = int + else: + pytype = self.pytype out_hint = None if self.hint is not None: @@ -1136,6 +1198,7 @@ def _make_node_magic(method, func): # hack down below works, because all round function down the line all take ndigits=None as default in their # signature. # TODO: Remove the args construction below if a different sentinel is used by FX. + # ezyang(May 2024): LOL args = [self.fx_node] if ndigits is not None: args.append(ndigits) @@ -1259,6 +1322,32 @@ def _make_user_magic(method, user_type): return x.node.is_constant() return False + # Promotion rules for binary operations. NB: we preserve PYTHON semantics + # - if args are same type, do nothing + # - if one arg is float, promote other arg to float + # - nb: this applies to floordiv, even though output is integral + # (it's still float) + # - pow is funny business + # - if both ints + # - trigger a guard on exponent >= 0 + # - if non-negative, output is int + # - otherwise, output is float + # - otherwise, promote other arg to float + # - nb: complex is impossible to handle correctly lol, with + # negative base and integral float need to diverge semantics and + # just always return complex. Neener neener pretend this problem + # doesn't exist + # - equality is pain: Python does the fancy thing where it unpacks the + # mantissa from the float and then compares that against the int. + # Which means it is able to tell that + # 9007199254740993 != 9007199254740992. (rather than if the LHS was + # promoted to float, in which case it would have truncated to the RHS + # and subsequently been equal). We'll model this exactly by having + # special mixed type equality operations. Unfortunately, we need to + # do this for all comparison operations (maybe I'll only implement + # compare) + # - sym_ite mumble mumble really shouldn't allow mixed but whatever + if method in bool_becomes_int_magic_methods: def promote(x): @@ -1272,6 +1361,41 @@ def _make_user_magic(method, user_type): def promote(x): return x + def promote2(self, other): + # TODO: Remove eq and other relations from this list. + # CPython has fancy implementations for these to get as much precision + # as possible instead of just promoting to float64 and praying, so we + # need to handle them specially too. + # Also, note that int_truediv doesn't go through this path: both + # arguments are "int" so there isn't any promotion + if method not in [ + "add", + "sub", + "mul", + "mod", + "float_pow", + "float_truediv", + "int_floordiv", + "sym_min", + "sym_max", + # TODO: remove these + "eq", + "ne", + "gt", + "lt", + "le", + "ge", + ]: + return self, other + f_self = isinstance(self, (float, torch.SymFloat)) + f_other = isinstance(other, (float, torch.SymFloat)) + if f_self or f_other: + if not f_self: + self = torch.sym_float(self) + if not f_other: + other = torch.sym_float(other) + return self, other + # Before and after performing the operation, check if any operands are constant. # If so, extract out the constant values first. If `self` itself is a # constant, then "redispatch" by calling back into the operator. Sometimes @@ -1286,9 +1410,12 @@ def _make_user_magic(method, user_type): return wrap_node(getattr(self.node, method_attr)()) def binary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented sym_node_log.debug("MAGIC %s %s %s", method, self, other) self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): @@ -1300,8 +1427,11 @@ def _make_user_magic(method, user_type): return get_constant(ret) if is_constant(ret) else ret def rbinary_magic_impl(self, other): + if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)): + return NotImplemented self = promote(self) other = promote(other) + self, other = promote2(self, other) if is_constant(self): return (method_to_operator(method))(get_constant(self), other) if is_constant(other): diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index a2abde3a861..687d2bcbd1e 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -61,7 +61,7 @@ from torch._logging import trace_structured, structured from torch import SymBool, SymFloat, SymInt from torch._guards import ShapeGuard, Source, TracingContext from torch.utils._python_dispatch import is_traceable_wrapper_subclass -from torch.utils._sympy.functions import FloorDiv, Mod, IsNonOverlappingAndDenseIndicator +from torch.utils._sympy.functions import FloorDiv, Mod, PythonMod, IsNonOverlappingAndDenseIndicator, CleanDiv from torch.utils._sympy.solve import try_solve from torch.utils._sympy.value_ranges import bound_sympy, SymPyValueRangeAnalysis, ValueRanges, ValueRangeError from torch.utils._sympy.singleton_int import SingletonInt @@ -869,9 +869,9 @@ def constrain_range(a, *, min: Optional[int], max: Optional[int] = None): for N=1. """ if min is None: - min = -sympy.oo + min = -sys.maxsize - 1 if max is None: - max = sympy.oo + max = sys.maxsize - 1 if max < min: raise ValueError( @@ -979,16 +979,6 @@ def eval_guards(gm, *args, ignore_static=True): def bind_symbols(gm, *args): return gm.shape_env.bind_symbols(fx_placeholder_vals(gm), args) -def _assert_bound_is_rational(expr: sympy.Expr, bound: ValueRanges): - """ - We assert that the bounds are either Boolean, or not finite, or can be computed - in exact prevision via rational arithmetic. - The only exception to this is the rare case when the user calls `sqrt(s0)` - sqrt is turned into sympy.Pow so we just match for that (it matches more things, but still) - """ - assert bound.lower.is_rational or bound.lower.is_Boolean or not bound.lower.is_finite or expr.has(sympy.Pow), (bound, expr) - assert bound.upper.is_rational or bound.upper.is_Boolean or not bound.upper.is_finite or expr.has(sympy.Pow), (bound, expr) - class DimDynamic(Enum): """ Controls how to perform symbol allocation for a dimension. It is always @@ -1387,14 +1377,19 @@ SYMPY_INTERP = { 'Min': min, 'Max': max, 'Mod': operator.mod, + 'PythonMod': operator.mod, 'FloorDiv': operator.floordiv, 'TrueDiv': operator.truediv, 'IsNonOverlappingAndDenseIndicator': eval_is_non_overlapping_and_dense, 'floor': math.floor, 'ceiling': math.ceil, + 'FloorToInt': math.floor, + 'CeilToInt': math.ceil, 'cast_symbool_to_symint_guardless': cast_symbool_to_symint_guardless, - 'Round': builtins.round, + 'RoundToInt': builtins.round, 'RoundDecimal': builtins.round, + 'TruncToInt': math.trunc, + 'IntTrueDiv': operator.truediv, } @@ -1642,10 +1637,17 @@ class DimConstraints: congruence = (base - mod_reduced) % divisor if congruence != 0: self._congruences[s].add(congruence) + # NB: Must not be CleanDiv, it needs to be regular sympy division + # so inequality solver works. This is sort of problematic for + # is_integer tests though haha return (base - mod_reduced) / divisor if expr.has(Mod): expr = expr.replace(Mod, mod_handler) + # 7 // -3 is -3, 7 % -3 is -2, and 7 - (-2) / -3 is -3.0 so negative + # arguments should be OK. + if expr.has(PythonMod): + expr = expr.replace(PythonMod, mod_handler) if expr.has(FloorDiv): expr = expr.replace(FloorDiv, floor_div_handler) return expr @@ -3330,6 +3332,7 @@ class ShapeEnv: self.pending_fresh_unbacked_symbols.append(symbol) self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges.unknown() + assert vr.is_float # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, float) @@ -3348,6 +3351,7 @@ class ShapeEnv: self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = self._default_unspecified_value_range() + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, int) @@ -3371,6 +3375,7 @@ class ShapeEnv: self.counter["create_unbacked_symbol"] += 1 self.var_to_stack[symbol] = CapturedTraceback.extract(skip=1) vr = self.var_to_range[symbol] = ValueRanges(0, 1) + assert vr.is_int # Create a new FX placeholder and Z3 variable for 'symbol'. fx_node = self._create_fx_placeholder_and_z3var(symbol, bool) @@ -3516,6 +3521,7 @@ class ShapeEnv: self.var_to_range[sympy_expr] &= constraint_dim.vr vr = self.var_to_range[sympy_expr] + assert vr.is_int if val not in vr: raise ConstraintViolationError(f"{val} not in range [{vr.lower}, {vr.upper}]") @@ -3524,6 +3530,7 @@ class ShapeEnv: elif isinstance(val, float): self.var_to_range[sympy_expr] = vr = ValueRanges(-sympy.oo, sympy.oo) range_str = f"[{vr.lower}, {vr.upper}]" + assert vr.is_float else: # Skip var_range logic for SingletonInt # Only used for jagged layout nested tensors @@ -3573,6 +3580,7 @@ class ShapeEnv: def add_var_to_val(self, expr: sympy.Symbol, val: int): """ Adds a new symbol to the symbolic environment. """ + log.debug("add_var_to_val %s %s", expr, val, stack_info=True) assert expr not in self.var_to_val, f"{expr} already exists" self.var_to_val[expr] = sympy.Integer(val) @@ -4301,7 +4309,8 @@ class ShapeEnv: # Clamp values of size-like variables for x in self.size_like & var_to_range.keys(): if var_to_range[x] is not None: - var_to_range[x] = ValueRanges(2, sympy.oo) + var_to_range[x] = ValueRanges(2, sys.maxsize - 1) + assert var_to_range[x].is_int return bound_sympy(expr, var_to_range) @_lru_cache @@ -4418,6 +4427,11 @@ class ShapeEnv: vr = self._default_unspecified_value_range() if size_oblivious and k in self.size_like: lower = max(2, vr.lower) + # This is a bit dodgy: what this means is that there was a + # size-like unbacked symbol whose upper bound < 2. This + # causes... problems. + if lower <= vr.upper: + vr = ValueRanges(lower, vr.upper) else: lower = vr.lower # Don't do anything if we don't have a nontrivial lower bound @@ -4425,10 +4439,17 @@ class ShapeEnv: # SymInt if ( lower < (-sys.maxsize - 1) // 2 or - (unbacked_only and k in self.var_to_val) + (unbacked_only and k in self.var_to_val) or + not vr.is_int ): new_range_env[k] = vr continue + # The goal is to take our symbols which have various lower bounds + # and reallocate them into new symbols which are exactly positive; + # e.g., if we have s0 in [2, inf], we want to turn it into ess0 in + # [1, inf], where s0 = ess0 + 1. This gives the most information + # to sympy for subsequent simplifications. + # # Positive means >= 1 # Positive - 1 means >= 0 # Positive + lower - 1 means >= lower @@ -4460,6 +4481,14 @@ class ShapeEnv: self.counter["sympy_recursion_error"] += 1 return None + new_expr = safe_expand(new_expr) + if new_expr.is_number: + return new_expr + + # This is bad to do, the replacement with division leaves us with + # rationals when atom.args[0] is addition, e.g., sympy will happily + # turn (s0 + s1) // 2 into s0 / 2 + s1 / 2. Needless complication! + """ floor_div_replace = {} for atom in new_expr.atoms(FloorDiv): floor_div_replace[atom] = sympy.floor(atom.args[0] / atom.args[1]) @@ -4468,13 +4497,12 @@ class ShapeEnv: # are still free symbols if new_expr.is_number: return new_expr + """ # Check if the range can solve it statically out = bound_sympy(new_expr, new_range_env) - if expect_rational: - _assert_bound_is_rational(new_expr, out) - if out.is_singleton(): - return out.lower + if out.is_singleton(): + return out.lower return new_expr if unbacked_only else None @@ -4526,7 +4554,7 @@ class ShapeEnv: for fd in expr.atoms(FloorDiv): base, divisor = fd.args if self.replace(Mod(base, divisor)) in self.divisible: - div_replacements[fd] = base / divisor + div_replacements[fd] = CleanDiv(base, divisor) new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) new_pows = new_expr.atoms(sympy.Pow) @@ -4670,7 +4698,10 @@ class ShapeEnv: int_range = ValueRanges(-sys.maxsize - 1, sys.maxsize - 1) def issubset(x, y): - return (x & int_range).issubset(y & int_range) + if x.is_int and y.is_int: + return (x & int_range).issubset(y & int_range) + else: + return x.issubset(y) # First, refine the value range of a based on the computed value range # of tgt. This is always OK to do, even if we decide not to do the @@ -4688,7 +4719,7 @@ class ShapeEnv: b = next(iter(tgt.free_symbols)) # Try to invert the equality r = try_solve(sympy.Eq(a, tgt), b, floordiv_inequality=False) - if r is not None: + if r is not None and all(t.is_integer for t in sympy.preorder_traversal(r[1])): b_bound = self.bound_sympy(r[1]) self.var_to_range[b] = b_bound & self.var_to_range[b] tgt_bound = self.bound_sympy(tgt) @@ -4899,12 +4930,12 @@ class ShapeEnv: ): # We have Mod(i0, q / c) == 0, which means we can # rewrite i0 as (q / gcd(q, c)) * i1 - d = q / sympy.gcd(q, c) + d = q / sympy.gcd(q, c) # TODO: CleanDiv? i1 = self.create_unbacked_symint().node.expr # Propagate the value ranges. It doesn't really # matter if we use truediv or floordiv, because we # have established divisibility. - self._update_var_to_range(i1, SymPyValueRangeAnalysis.truediv( + self._update_var_to_range(i1, SymPyValueRangeAnalysis.floordiv( self.var_to_range[i0], ValueRanges.wrap(d) )) # Propagate size-like-ness @@ -5341,7 +5372,6 @@ class ShapeEnv: lower, upper = vr.lower, vr.upper rhs_vr = bound_sympy(rhs, self.var_to_range) - _assert_bound_is_rational(rhs, rhs_vr) # Let's suppose that we have a preexisting range for x [0, 100]. # Now, we issue a guard x > y, where the range for y is [50, 150]. diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 6dcb59db797..d06b38d60c8 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -216,10 +216,7 @@ try: def abs(self, number: z3.ArithRef) -> z3.ArithRef: return z3.Abs(number) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - if ndigits is not None: - raise ValueError("round(..., ndigits=) is currently not supported by shape validations.") - + def round_to_int(self, number: z3.ArithRef) -> z3.ArithRef: # Pythons builtin 'round' implements the 'round half to even' strategy # See https://en.wikipedia.org/wiki/Rounding#Rounding_half_to_even # z3 has an equivalent z3.fpRoundToIntegral(z3.RoundNearestTiesToEven(), ...), but this only applies to @@ -284,7 +281,7 @@ try: operator.truediv: lift(ops.div), operator.mod: lift(ops.mod), operator.abs: lift(ops.abs), - builtins.round: lift(ops.round), + builtins.round: lift(ops.round_to_int), # Math module. math.ceil: lift(ops.ceil), @@ -350,6 +347,7 @@ try: self._ops = _Z3Ops(self._validator) def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: + # TODO: Probably OK to relax this and allow lower precision if dtype is torch.int64: return z3.IntVal(int(value)) if dtype is torch.double: @@ -358,6 +356,20 @@ try: return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") + def to_dtype(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + if dtype == torch.float64: + return z3.ToReal(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + + def trunc_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return z3.ToInt(x) + + def round_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.round_to_int(x) + + def int_truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: + return self._ops.div(numerator, denominator) + def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: return self._ops.div(numerator, denominator) @@ -370,11 +382,17 @@ try: def pow(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: return self._ops.pow(base, exp) + def pow_by_natural(self, base: z3.ArithRef, exp: z3.ArithRef) -> z3.ArithRef: + return self._ops.pow(base, exp) + def mod(self, p: z3.ArithRef, q: z3.ArithRef) -> z3.ArithRef: return self._ops.mod(p, q) - def round(self, number: z3.ArithRef, ndigits: Optional[z3.ArithRef] = None) -> z3.ArithRef: - return self._ops.round(number, ndigits) + def ceil_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.ceil(x) + + def floor_to_int(self, x: z3.ArithRef, dtype: torch.dtype) -> z3.ArithRef: + return self._ops.floor(x) def __getattr__(self, name: str) -> Any: REPLACEMENT = { diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 1384261b451..128ce537c01 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -1,43 +1,78 @@ +import functools import math +import sys import sympy from sympy import S -from sympy.core.logic import fuzzy_and, fuzzy_not, fuzzy_or __all__ = [ "FloorDiv", "ModularIndexing", "CleanDiv", "CeilDiv", - "Pow", - "TrueDiv", + "IntTrueDiv", + "FloatTrueDiv", "LShift", "RShift", "IsNonOverlappingAndDenseIndicator", - "Round", + "RoundToInt", "RoundDecimal", + "ToFloat", + "FloatPow", + "PowByNatural", ] +def _keep_float(f): + @functools.wraps(f) + def inner(*args): + r = f(*args) + if any(isinstance(a, sympy.Float) for a in args) and not isinstance( + r, sympy.Float + ): + r = sympy.Float(float(r)) + return r + + return inner + + def fuzzy_eq(x, y): if None in (x, y): return None return x == y +# It would be nice to have assertions on whether or not inputs is_integer +# However, with bugs like https://github.com/sympy/sympy/issues/26620 sympy +# sometimes inconsistently reports floats an integers. +# +# What we can assume from sympy is that if something is an int, it +# definitely is is_integer, but if it is a float it may or may not +# be is_integer. So we are unable to do strong asserts that things +# are NOT integers. + + +# TODO: In Triton, // rounds to zero, but in Python, it is floor division. +# When we can prove both arguments are non-negative, we should just have a +# GenericFloorDiv (name pending) which can codegen efficiently in Python/C, +# and then PythonFloorDiv and CIntDiv which have the appropriate rounding +# semantics. +# +# Right now, FloorDiv de facto changes behavior if arguments are negative or +# not, this can potentially cause correctness issues. class FloorDiv(sympy.Function): """ We maintain this so that: 1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b. 2. Printing out the expression is nicer (compared to say, representing a//b as (a - a % b) / b) + + NB: This is Python-style floor division, round to -Inf """ nargs = (2,) precedence = 50 # precedence of mul # noqa: F811 - # Default return type for SymPy assumptions. - # https://docs.sympy.org/latest/guides/assumptions.html#implementing-assumptions-handlers - is_real = True + is_integer = True @property def base(self): @@ -52,29 +87,14 @@ class FloorDiv(sympy.Function): divisor = printer.parenthesize(self.divisor, self.precedence) return f"({base}//{divisor})" - # SymPy assumptions based on argument types. - def _eval_is_real(self): - return fuzzy_or([self.base.is_real, self.divisor.is_real]) - - def _eval_is_integer(self): - return fuzzy_and([self.base.is_integer, self.divisor.is_integer]) - # Automatic evaluation. # https://docs.sympy.org/latest/guides/custom-functions.html#best-practices-for-eval @classmethod def eval(cls, base, divisor): - def check_supported_type(x): - if ( - x.is_integer is False and x.is_real is False and x.is_complex - ) or x.is_Boolean: - raise TypeError( - f"unsupported operand type(s) for //: " - f"'{type(base).__name__}' and '{type(divisor).__name__}'" - f", expected integer or real" - ) - - check_supported_type(base) - check_supported_type(divisor) + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # Assert triggered by inequality solver + # assert base.is_integer, base + # assert divisor.is_integer, divisor # We don't provide the same error message as in Python because SymPy # makes it difficult to check the types. @@ -85,26 +105,22 @@ class FloorDiv(sympy.Function): return sympy.S.Zero if base.is_integer and divisor == 1: return base - if base.is_real and divisor == 1: - return sympy.floor(base) if base.is_integer and divisor == -1: return sympy.Mul(base, -1) if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer): - return base // divisor - if isinstance(base, (sympy.Integer, sympy.Float)) and isinstance( - divisor, (sympy.Integer, sympy.Float) - ): - return sympy.floor(base / divisor) + return sympy.Integer(int(base) // int(divisor)) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) - if isinstance(divisor, sympy.Rational) and divisor.p == 1: - return sympy.floor(base * divisor.q) + # gcd in sympy is over polynomials, so you'll end up with rationals if + # you do this. Don't. + """ if isinstance(base, sympy.Add): for a in base.args: gcd = sympy.gcd(a, divisor) if gcd == divisor: return FloorDiv(base - a, divisor) + a / gcd + """ try: gcd = sympy.gcd(base, divisor) @@ -189,6 +205,19 @@ class Where(sympy.Function): nargs = (3,) + def _eval_is_integer(self): + return True if self.args[1].is_integer and self.args[2].is_integer else None # type: ignore[attr-defined] + + def _eval_is_nonnegative(self): + return ( + True + if self.args[1].is_nonnegative and self.args[2].is_nonnegative # type: ignore[attr-defined] + else None + ) + + def _eval_is_positive(self): + return True if self.args[1].is_positive and self.args[2].is_positive else None # type: ignore[attr-defined] + @classmethod def eval(cls, c, p, q): if c == sympy.true: @@ -197,28 +226,27 @@ class Where(sympy.Function): return q -class Mod(sympy.Function): - """ - We maintain this so that we avoid SymPy correctness issues, such as: - https://github.com/sympy/sympy/issues/25146 - """ - +# Python-style modulus: take sign from RHS +class PythonMod(sympy.Function): nargs = (2,) + is_integer = True + @classmethod def eval(cls, p, q): - # This was adapted from: sympy/core/mod.py + # python test/dynamo/test_export.py -k ExportTests.test_trivial_constraint + # Triggered by sympy.solvers.inequalities.reduce_inequalities + # assert p.is_integer, p + # assert q.is_integer, q if q.is_zero: raise ZeroDivisionError("Modulo by zero") - # If either of them is NaN or infinite. - if p is S.NaN or q is S.NaN or p.is_finite is False or q.is_finite is False: - return S.NaN + # Three cases: # 1. p == 0 # 2. p is either q or -q # 3. p is integer and q == 1 - if p is S.Zero or p in (q, -q) or (p.is_integer and q == 1): + if p is S.Zero or p in (q, -q) or q == 1: return S.Zero # Evaluate if they are both literals. @@ -247,10 +275,7 @@ class Mod(sympy.Function): if sympy.Mod(p, q) == 0: return S.Zero - def _eval_is_integer(self): - p, q = self.args - return fuzzy_and([p.is_integer, q.is_integer, fuzzy_not(q.is_zero)]) # type: ignore[attr-defined] - + # NB: args[1] for PythonMod def _eval_is_nonnegative(self): return True if self.args[1].is_positive else None # type: ignore[attr-defined] @@ -258,6 +283,58 @@ class Mod(sympy.Function): return True if self.args[1].is_negative else None # type: ignore[attr-defined] +# Generic modulus: only defined on non-negative arguments +class Mod(sympy.Function): + nargs = (2,) + + is_integer = True + is_nonnegative = True + + @classmethod + def eval(cls, p, q): + # This was adapted from: sympy/core/mod.py + + # Triggered by + # python test/test_dynamic_shapes.py -k TestDimConstraints.test_dim_constraints_solve_full + # assert p.is_integer, p + # assert q.is_integer, q + + if q.is_zero: + raise ZeroDivisionError("Modulo by zero") + + # Three cases: + # 1. p == 0 + # 2. p is either q or -q + # 3. p is integer and q == 1 + if p is S.Zero or p in (q, -q) or q == 1: + return S.Zero + + # Evaluate if they are both literals. + if q.is_Number and p.is_Number: + assert p >= 0, p + assert q >= 1, q + return p % q + + # If q == 2, it's a matter of whether p is odd or even. + if q.is_Number and q == 2: + if p.is_even: + return S.Zero + if p.is_odd: + return S.One + + # If p is a multiple of q. + r = p / q + if r.is_integer: + return S.Zero + + # If p < q and its ratio is positive, then: + # - floor(p / q) = 0 + # - p % q = p - floor(p / q) * q = p + less = p < q + if less.is_Boolean and bool(less) and r.is_positive: + return p + + class CleanDiv(FloorDiv): """ Div where we can assume no rounding. @@ -267,6 +344,36 @@ class CleanDiv(FloorDiv): pass +# Don't use sympy ceiling/floor as they will attempt simplifications involving +# frac +class CeilToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.ceil(float(number))) + + +class FloorToInt(sympy.Function): + is_integer = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): + return sympy.Integer(math.floor(float(number))) + + class CeilDiv(sympy.Function): """ Div used in indexing that rounds up. @@ -275,6 +382,8 @@ class CeilDiv(sympy.Function): is_integer = True def __new__(cls, base, divisor): + base = sympy.sympify(base) + divisor = sympy.sympify(divisor) if sympy.gcd(base, divisor) == divisor: return CleanDiv(base, divisor) else: @@ -282,6 +391,8 @@ class CeilDiv(sympy.Function): class LShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -290,6 +401,8 @@ class LShift(sympy.Function): class RShift(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, shift): if shift < 0: @@ -297,28 +410,107 @@ class RShift(sympy.Function): return base // 2**shift -# Overloaded to be compatible with regular Python. -# https://github.com/pytorch/pytorch/issues/90900 -class Pow(sympy.Function): +def safe_pow(base, exp): + sign = 1 + if base < 0: + base = -base + sign = 1 if exp % 2 == 0 else -1 + return sign * _safe_pow(base, exp) + + +def _safe_pow(base, exponent): + if exponent < 0: + raise ValueError("Exponent must be non-negative.") + + if exponent == 0: + return 1 + + half_exp = safe_pow(base, exponent // 2) + if half_exp > sys.maxsize - 1: + return sys.maxsize - 1 + + result = half_exp * half_exp + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + if exponent % 2 == 1: + result *= base + if result > sys.maxsize - 1: + return sys.maxsize - 1 + + return result + + +class PowByNatural(sympy.Function): + is_integer = True + @classmethod def eval(cls, base, exp): - if exp.is_zero: - return sympy.Integer(1) - elif base.is_zero and exp < 0: - raise ZeroDivisionError(f"{base} cannot be raised to a negative power") - else: - return base**exp + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Integer(safe_pow(base, exp)) + if isinstance(exp, sympy.Integer): + # Translate power into iterated multiplication + r = sympy.Integer(1) + for _ in range(int(exp)): + r *= base + return r + # NB: do NOT translate into sympy.Pow, we will lose knowledge that exp + # is a natural number if we do + + +# base is assumed to be nonnegative, thereby prevent complex numbers from +# occuring +class FloatPow(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, exp): + if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number): + return sympy.Float(float(base) ** float(exp)) + # NB: do not do any nontrivial reasoning # Overloaded to be compatible with regular Python. # https://github.com/pytorch/pytorch/issues/90900 -class TrueDiv(sympy.Function): +# +# In particular, sympy division is willing to simplify x/x == 1 +# where 1 is an integer, but this must be a float if x was float. +class FloatTrueDiv(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, base, divisor): + # assert base.is_integer is not True, base + # assert divisor.is_integer is not True, divisor + + if divisor.is_zero: + raise ZeroDivisionError("division by zero") + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(float(base) / float(divisor)) + + +# Overloaded to be compatible with regular Python. We distinguish this from +# FloatTrueDiv, because the code generation has to be different for this case: +# Python has a fancy algorithm for integer true division that isn't just +# "promote both arguments to float and use float division", so you need to +# codegen it differently. While technically you can work it out from the +# types of the input, this is often inconvenient to do in Inductor codegen, +# so just have a different operator +# NB: Right now, Inductor codegen doesn't implement this correctly lol +class IntTrueDiv(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, base, divisor): if divisor.is_zero: raise ZeroDivisionError("division by zero") - else: - return base / divisor + + if isinstance(base, sympy.Number) and isinstance(divisor, sympy.Number): + return sympy.Float(int(base) / int(divisor)) # TODO: As an indicator, this != 0 implies == 1 (and vice versa). @@ -353,45 +545,85 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function): return None -class Trunc(sympy.Function): +# NB: this is inconsistent with math.trunc in Python +class TruncToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + # assert number.is_integer is not True, number + if isinstance(number, sympy.Number): + # NB: It is safe to use truncation to integer, which is what + # math.trunc does, as Python integers are arbitrary precision and + # so we are guaranteed not to lose precision when we do this + return sympy.Float(math.trunc(float(number))) + + +class TruncToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): + # assert number.is_integer is not True, number + if number == sympy.oo: + return sympy.Integer(sys.maxsize - 1) + if number == -sympy.oo: + return sympy.Integer(-sys.maxsize - 1) + if isinstance(number, sympy.Number): return sympy.Integer(math.trunc(float(number))) -class Round(sympy.Function): +# This is float -> int +class RoundToInt(sympy.Function): is_integer = True @classmethod def eval(cls, number): - if number.is_integer: - return number - elif isinstance(number, sympy.Number): - return sympy.Integer(round(float(number))) + # assert number.is_integer is not True, number - def __int__(self): - # This will only ever be called when computing size hints. At that point, self.args[0] should be a number and - # no longer an expression. If it were, the float call would fail and the caller would handle this further. - return round(float(self.args[0])) # type: ignore[arg-type] + if isinstance(number, sympy.Float): + return sympy.Integer(round(float(number), 0)) +# To get float -> int, Python style round semantics. +# +# x = PyFloat_AsDouble(self); +# if (o_ndigits == Py_None) { +# /* single-argument round or with None ndigits: +# * round to nearest integer */ +# rounded = round(x); +# if (fabs(x-rounded) == 0.5) +# /* halfway case: round to even */ +# rounded = 2.0*round(x/2.0); +# return PyLong_FromDouble(rounded); +# } + + +# NB: Like Round, this only ever returns floats. ndigits cannot be None class RoundDecimal(sympy.Function): + is_integer = False + is_real = True + @classmethod def eval(cls, number, ndigits): - if number.is_integer and ndigits >= 0: + # assert number.is_integer is not True, number + + if isinstance(number, sympy.Float) and isinstance(ndigits, sympy.Integer): + return sympy.Float(round(float(number), int(ndigits))) + + +class ToFloat(sympy.Function): + is_integer = False + is_real = True + + @classmethod + def eval(cls, number): + if number in [sympy.oo, -sympy.oo]: return number - elif isinstance(number, sympy.Number) and isinstance(ndigits, sympy.Integer): - value_type, output_type = ( - (int, sympy.Integer) - if isinstance(number, sympy.Integer) - else (float, sympy.Float) - ) - return output_type(round(value_type(number), int(ndigits))) + + if isinstance(number, sympy.Integer): + return sympy.Float(int(number)) def make_opaque_unary_fn(name): diff --git a/torch/utils/_sympy/interp.py b/torch/utils/_sympy/interp.py index 806e91cfe28..09a4b838474 100644 --- a/torch/utils/_sympy/interp.py +++ b/torch/utils/_sympy/interp.py @@ -15,16 +15,23 @@ from sympy.logic.boolalg import Boolean as SympyBoolean, BooleanAtom import torch from .functions import ( + CeilToInt, CleanDiv, + FloatPow, + FloatTrueDiv, FloorDiv, + FloorToInt, + IntTrueDiv, IsNonOverlappingAndDenseIndicator, Mod, ModularIndexing, - Pow, - Round, + PowByNatural, + PythonMod, RoundDecimal, - TrueDiv, - Trunc, + RoundToInt, + ToFloat, + TruncToFloat, + TruncToInt, Where, ) @@ -49,30 +56,39 @@ def handlers(): sympy.Le: "le", sympy.Ge: "ge", sympy.Not: "not_", - TrueDiv: "truediv", + IntTrueDiv: "int_truediv", + FloatTrueDiv: "truediv", FloorDiv: "floordiv", - CleanDiv: "div", - Trunc: "trunc", + CleanDiv: "floordiv", # TODO: hmm? + TruncToFloat: "trunc", Where: "where", sympy.Add: "add", sympy.Mul: "mul", - Pow: "pow", - sympy.Pow: "pow", + FloatPow: "pow", + PowByNatural: "pow_by_natural", + # sympy simplifies x * x into Pow(x, 2), so we need to handle this. + # Do NOT use builtin Pow for floats + # TODO: There is a hazard here, if we have float * float it will + # also get turned into Pow(float, 2) but we don't want this because + # pow_by_natural is assumed to only be integers. Probably the fix is + # to add a FloatMul to impede this optimization + sympy.Pow: "pow_by_natural", Mod: "mod", + PythonMod: "mod", # TODO: this is wrong + # TODO: Inductor can generate these, but it's ill-specified which + # semantics were intended here. Needs to be cleaned up along with + # FloorDiv in a bigger cleanup sympy.Mod: "mod", sympy.Abs: "abs", sympy.log: "log", sympy.exp: "exp", - sympy.floor: "floor", - sympy.ceiling: "ceil", sympy.Min: "minimum", sympy.Max: "maximum", ModularIndexing: "modular_indexing", sympy.functions.elementary.piecewise.ExprCondPair: "expr_cond_pair", sympy.Piecewise: "piecewise", IsNonOverlappingAndDenseIndicator: "is_non_overlapping_and_dense_indicator", - Round: "round", - RoundDecimal: "round", + RoundDecimal: "round_decimal", } for name in ["cos", "sin", "tan", "sinh", "cosh", "tanh", "asin", "acos", "atan"]: HANDLERS[getattr(sympy, name)] = name @@ -84,7 +100,11 @@ ASSOCIATIVE_OPS = {"minimum", "maximum", "mul", "add", "and_", "or_"} def sympy_interp( - analysis, env: Dict[sympy.Symbol, Any], expr: Union[sympy.Expr, SympyBoolean] + analysis, + env: Dict[sympy.Symbol, Any], + expr: Union[sympy.Expr, SympyBoolean], + *, + index_dtype=torch.int64, ): # Handle base cases dtype = None @@ -105,9 +125,32 @@ def sympy_interp( expr.args[1], sympy.core.numbers.Half ): return analysis.sqrt(sympy_interp(analysis, env, expr.args[0])) + if isinstance(expr, ToFloat): + return analysis.to_dtype( + sympy_interp(analysis, env, expr.args[0]), torch.float64 + ) # Recursive case args = [sympy_interp(analysis, env, arg) for arg in expr.args] # type: ignore[arg-type] + + # These handlers are special because they take an extra dtype argument + # specifying what they should convert to, and we need to appropriately set + # this up when we convert from Sympy. A reasonable default when you + # are translating is to conservatively do int64, and then narrow these + # arguments later when you discover you can narrow the index range. But + # if you already know that 32-bit indexing is OK, you can directly do the + # sympy translation with index_dtype=torch.int32 + INDEX_DTYPE_HANDLERS = { + TruncToInt: "trunc_to_int", + sympy.floor: "floor_to_int", + sympy.ceiling: "ceil_to_int", + FloorToInt: "floor_to_int", + CeilToInt: "ceil_to_int", + RoundToInt: "round_to_int", + } + if (handler_name := INDEX_DTYPE_HANDLERS.get(expr.func)) is not None: + return getattr(analysis, handler_name)(*args, index_dtype) + if hasattr(expr.func, "_torch_handler_name"): handler_name = expr.func._torch_handler_name else: diff --git a/torch/utils/_sympy/reference.py b/torch/utils/_sympy/reference.py index 881b9d616eb..b54a0d0503a 100644 --- a/torch/utils/_sympy/reference.py +++ b/torch/utils/_sympy/reference.py @@ -1,12 +1,25 @@ import math +import operator + import sympy import torch from torch.utils._sympy.functions import ( + _keep_float, + FloatPow, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, + Mod, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, OpaqueUnaryFn_sqrt, + PowByNatural, + RoundDecimal, + RoundToInt, + ToFloat, + TruncToInt, ) @@ -62,20 +75,41 @@ class ReferenceAnalysis: @staticmethod def reciprocal(x): - return 1 / x + return FloatTrueDiv(1.0, x) @staticmethod def square(x): - return x * x + return PowByNatural(x, 2) + + @staticmethod + def trunc_to_int(x, dtype): + return TruncToInt(x) + + @staticmethod + def ceil_to_int(x, dtype): + return sympy.ceiling(x) + + @staticmethod + def floor_to_int(x, dtype): + return sympy.floor(x) + + @staticmethod + def floor(x): + return _keep_float(sympy.floor)(x) + + @staticmethod + def ceil(x): + return _keep_float(sympy.ceiling)(x) + + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return ToFloat(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") @staticmethod def mod(x, y): - ret = abs(x) % abs(y) - # without check: - # tracing will fail trying to go through control-flow if x is Proxy() - if isinstance(x, (int, sympy.Number)) and x < 0: - ret *= -1 - return ret + return Mod(x, y) @staticmethod def abs(x): @@ -87,37 +121,31 @@ class ReferenceAnalysis: @staticmethod def truediv(a, b): - return a / b + return FloatTrueDiv(a, b) @staticmethod - def div(a, b): - return ReferenceAnalysis.truediv(a, b) + def int_truediv(a, b): + return IntTrueDiv(a, b) @staticmethod def floordiv(a, b): - if b == 0: - return sympy.nan if a == 0 else sympy.zoo - return a // b + return FloorDiv(a, b) @staticmethod def truncdiv(a, b): - result = a / b - if result.is_finite: - result = sympy.Integer(result) - - return result + raise NotImplementedError("TODO: truncdiv") @staticmethod def add(a, b): - return a + b + return _keep_float(operator.add)(a, b) @staticmethod def mul(a, b): - return a * b + return _keep_float(operator.mul)(a, b) @staticmethod def sub(a, b): - return a - b + return _keep_float(operator.sub)(a, b) @staticmethod def exp(x): @@ -133,39 +161,27 @@ class ReferenceAnalysis: @staticmethod def pow(a, b): - return a**b + return _keep_float(FloatPow)(a, b) + + @staticmethod + def pow_by_natural(a, b): + return PowByNatural(a, b) @staticmethod def minimum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Min(result_type(a), result_type(b)) + return sympy.Min(a, b) @staticmethod def maximum(a, b): - # Poorman's version of upcasting in Sympy - # This won't do for sympy.Expr as the casting does nothing for those - if a.is_Float or not a.is_finite or b.is_Float or not b.is_finite: - result_type = sympy.Float - else: - assert a.is_Integer - assert b.is_Integer - result_type = sympy.Integer - return sympy.Max(result_type(a), result_type(b)) + return sympy.Max(a, b) @staticmethod - def floor(x): - return sympy.floor(x) + def round_to_int(a, dtype): + return RoundToInt(a) @staticmethod - def ceil(x): - return sympy.ceiling(x) + def round_decimal(a, b): + return RoundDecimal(a, b) # Unlike ReferenceAnalysis, does NOT sympyify, instead, works with plain @@ -191,10 +207,20 @@ class PythonReferenceAnalysis(ReferenceAnalysis): def floordiv(a, b): return a // b + @staticmethod + def mod(x, y): + return x % y + @staticmethod def truncdiv(a, b): return a / b + @staticmethod + def to_dtype(x, dtype): + if dtype == torch.float64: + return float(x) + raise NotImplementedError(f"to_dtype {dtype} NYI") + @staticmethod def exp(x): raise AssertionError("exp is not valid shape sympy expr") @@ -216,9 +242,40 @@ class PythonReferenceAnalysis(ReferenceAnalysis): return torch.sym_max(a, b) @staticmethod - def floor(x): + def floor_to_int(x, dtype): return math.floor(x) @staticmethod - def ceil(x): + def ceil_to_int(x, dtype): return math.ceil(x) + + @staticmethod + def floor(x): + return float(math.floor(x)) + + @staticmethod + def ceil(x): + return float(math.ceil(x)) + + @staticmethod + def truediv(a, b): + return a / b + + @staticmethod + def pow(a, b): + return a**b + + @staticmethod + def pow_by_natural(a, b): + # Pray that safe_pow is not needed here lol. In particular, this + # never participates in VR low/high ranges, so overflow should be + # unlikely + return a**b + + @staticmethod + def round_to_int(a, dtype): + return round(a) + + @staticmethod + def round_decimal(a, b): + return round(a, ndigits=b) diff --git a/torch/utils/_sympy/solve.py b/torch/utils/_sympy/solve.py index 6276c696293..02ddf7c3421 100644 --- a/torch/utils/_sympy/solve.py +++ b/torch/utils/_sympy/solve.py @@ -88,6 +88,7 @@ def try_solve( # Return if we were able to isolate 'thing' on the left-hand side. if isinstance(e, sympy.Rel) and e.lhs == thing: + log.debug("solved: %s ---> %s", expr, e) return e, e.rhs return None diff --git a/torch/utils/_sympy/value_ranges.py b/torch/utils/_sympy/value_ranges.py index c7cc96beb98..4d364d4981b 100644 --- a/torch/utils/_sympy/value_ranges.py +++ b/torch/utils/_sympy/value_ranges.py @@ -5,6 +5,7 @@ import itertools import logging import math import operator +import sys from typing import ( Callable, Dict, @@ -25,17 +26,20 @@ import torch from torch._prims_common import dtype_to_type from .functions import ( - OpaqueUnaryFn_acos, - OpaqueUnaryFn_asinh, - OpaqueUnaryFn_atan, - OpaqueUnaryFn_cosh, + _keep_float, + FloatTrueDiv, + FloorDiv, + IntTrueDiv, OpaqueUnaryFn_exp, OpaqueUnaryFn_log, - OpaqueUnaryFn_sinh, OpaqueUnaryFn_sqrt, - OpaqueUnaryFn_tanh, - Round, + PowByNatural, RoundDecimal, + RoundToInt, + safe_pow, + ToFloat, + TruncToFloat, + TruncToInt, ) from .interp import sympy_interp @@ -120,6 +124,8 @@ class ValueRanges(Generic[_T]): lower: _T upper: _T is_bool: bool + is_int: bool + is_float: bool @overload def __init__(self: ValueRanges[sympy.Expr], lower: ExprIn, upper: ExprIn) -> None: @@ -142,8 +148,39 @@ class ValueRanges(Generic[_T]): # Because this is a frozen class object.__setattr__(self, "lower", lower) object.__setattr__(self, "upper", upper) + # Unlike bool/int in Python, we don't report bools are ints object.__setattr__(self, "is_bool", isinstance(lower, SympyBoolean)) - assert isinstance(upper, SympyBoolean) == self.is_bool + if self.is_bool: + assert isinstance(upper, SympyBoolean), (lower, upper) + + # Warning: is_int/is_float is best effort. We do pretty well in + # Dynamo, but in Inductor these attributes are often wrong because we + # are not very rigorous in dtype analysis. This is also why we need + # the flexible analysis for is_int: sometimes a sympy.oo pops in for + # an integer bound. I would /like/ for us not to do this, but it's + # too hard to push the invariant through right now. + + object.__setattr__( + self, + "is_int", + not self.is_bool + and (isinstance(lower, sympy.Integer) or isinstance(upper, sympy.Integer)), + ) + """ + # This assert is just impossible right now, too many sympy bugs + if self.is_int: + # NB: sympy will sometimes randomly lose the float-ness of zero, + # so we also need to account for that in the assertion here. + # See also https://github.com/sympy/sympy/issues/26620 + assert isinstance(lower, sympy.Integer) or lower in [-sympy.oo, 0], ( + lower, + upper, + ) + assert isinstance(upper, sympy.Integer) or upper in [sympy.oo, 0], (lower, upper) + """ + # NB: [-oo, oo] always advertises as float! + object.__setattr__(self, "is_float", not self.is_bool and not self.is_int) + assert self.is_bool or self.is_int or self.is_float, (lower, upper) def boolify(self) -> ValueRanges[SympyBoolean]: if vr_is_bool(self): @@ -184,6 +221,8 @@ class ValueRanges(Generic[_T]): if self == ValueRanges.unknown(): return other assert self.is_bool == other.is_bool, (self, other) + assert self.is_int == other.is_int, (self, other) + assert self.is_float == other.is_float, (self, other) if self.is_bool: return ValueRanges( sympy.Or(self.lower, other.lower), sympy.And(self.upper, other.upper) @@ -353,7 +392,12 @@ class SymPyValueRangeAnalysis: # using nan makes subsequent computation throw, and for the purposes of optimization # returning -math.inf - math.inf is equivalent to giving up if isinstance(value, SupportsFloat) and math.isnan(value): - return ValueRanges.unknown() + if dtype == torch.bool: + return ValueRanges.unknown_bool() + elif dtype.is_floating_point: + return ValueRanges.unknown() + else: + return ValueRanges(-sys.maxsize - 1, sys.maxsize) if is_python: type_ = dtype_to_type(dtype) @@ -369,7 +413,18 @@ class SymPyValueRangeAnalysis: # dtype is intXX assert value.is_integer - return ValueRanges.wrap(value) + r = ValueRanges.wrap(value) + return r + + @staticmethod + def to_dtype(a, dtype, src_dtype=None): + if dtype == torch.float64: + return ValueRanges.increasing_map(a, ToFloat) + return ValueRanges.unknown() + + @staticmethod + def trunc_to_int(a, dtype): + return ValueRanges.increasing_map(a, TruncToInt) @staticmethod def not_(a): @@ -428,7 +483,9 @@ class SymPyValueRangeAnalysis: @staticmethod def add(a, b): - return ValueRanges.coordinatewise_increasing_map(a, b, operator.add) + return ValueRanges.coordinatewise_increasing_map( + a, b, _keep_float(operator.add) + ) @classmethod def mul(cls, a, b): @@ -448,11 +505,20 @@ class SymPyValueRangeAnalysis: else: return a * b - return ValueRanges.coordinatewise_monotone_map(a, b, safe_mul) + return ValueRanges.coordinatewise_monotone_map(a, b, _keep_float(safe_mul)) - @classmethod - def div(cls, a, b): - return cls.truediv(a, b) + @staticmethod + def int_truediv(a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if 0 in b or ( + (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + ): + return ValueRanges.unknown() + else: + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(IntTrueDiv) + ) @staticmethod def truediv(a, b): @@ -463,18 +529,22 @@ class SymPyValueRangeAnalysis: ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.truediv) + return ValueRanges.coordinatewise_monotone_map( + a, b, _keep_float(FloatTrueDiv) + ) @staticmethod def floordiv(a, b): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) if 0 in b or ( - (-sympy.oo in a or sympy.oo in a) and (-sympy.oo in b or sympy.oo in b) + # TODO: make this more precise + (-sympy.oo in a or sympy.oo in a) + or (-sympy.oo in b or sympy.oo in b) ): return ValueRanges.unknown() else: - return ValueRanges.coordinatewise_monotone_map(a, b, operator.floordiv) + return ValueRanges.coordinatewise_monotone_map(a, b, FloorDiv) @classmethod def mod(cls, x, y): @@ -523,17 +593,51 @@ class SymPyValueRangeAnalysis: @classmethod def is_non_overlapping_and_dense_indicator(cls, *args): - return ValueRanges.unknown() + return ValueRanges.unknown() # TODO: type here is wrong + + @classmethod + def pow_by_natural(cls, a, b): + a = ValueRanges.wrap(a) + b = ValueRanges.wrap(b) + if a.is_singleton() and b.is_singleton(): + return ValueRanges.wrap(safe_pow(a.lower, b.lower)) + # NB: Exclude zero, because zero is special + elif a.lower >= 1: + # We should know that b >= 0 but we may have forgotten this fact due + # to replacements, so don't assert it, but DO clamp it to prevent + # degenerate problems + return ValueRanges.coordinatewise_increasing_map( + a, b & ValueRanges(0, sys.maxsize - 1), PowByNatural + ) + elif b.is_singleton(): + if b.lower % 2 == 0: + # x^n where n is even + return ValueRanges.convex_min_zero_map( + a, lambda x: safe_pow(x, b.lower) + ) + else: + # x^n where n is odd + return ValueRanges.increasing_map(a, lambda x: safe_pow(x, b.lower)) + else: + # a is potentially negative, and we don't know if the exponent is + # even or odd. So just conservatively set the upper and lower + # bound based on what the maximum absolute value could be, in both + # directions + max_base = max(a.upper, -a.lower) + return ValueRanges( + -(safe_pow(max_base, b.upper)), safe_pow(max_base, b.upper) + ) @classmethod def pow(cls, a, b): - def is_integer(val): - return isinstance(val, int) or ( - hasattr(val, "is_integer") and val.is_integer - ) + return ValueRanges.unknown() + # We could implement all this, but for floating point pow, is there + # really a point? + """ a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) + # Not implemented yet. It's a bit tricky # If you want to implement it, compute the partial derivatives of a ** b # and check the ranges where the function is increasing / decreasing @@ -553,8 +657,7 @@ class SymPyValueRangeAnalysis: if b == 0: if not a.lower.is_finite: return ValueRanges.unknown() - type_ = sympy.Float if a.lower.is_real else sympy.Integer - return ValueRanges.wrap(type_(1)) + return ValueRanges.wrap(1.0) if b < 0: a = cls.reciprocal(a) @@ -563,21 +666,12 @@ class SymPyValueRangeAnalysis: if a == ValueRanges.unknown(): return ValueRanges.unknown() - # Here b > 0 - if not is_integer(b): - # If the base is positive, then we're good, otherwise nothing's defined - if a.lower >= 0: - return ValueRanges.increasing_map(a, lambda x: x**b) - else: - return ValueRanges.unknown() + # If the base is positive, then we're good, otherwise nothing's defined + if a.lower >= 0: + return ValueRanges.increasing_map(a, lambda x: x**b) else: - # b > 0 integer - if b % 2 == 0: - # x^n where n is even - return ValueRanges.convex_min_zero_map(a, lambda x: x**b) - else: - # x^n where n is odd - return ValueRanges.increasing_map(a, lambda x: x**b) + return ValueRanges.unknown() + """ @staticmethod def reciprocal(x): @@ -586,7 +680,7 @@ class SymPyValueRangeAnalysis: if 0 in x: return ValueRanges.unknown() else: - return ValueRanges.decreasing_map(x, lambda y: 1 / y) + return ValueRanges.decreasing_map(x, lambda y: FloatTrueDiv(1.0, y)) @staticmethod def abs(x): @@ -615,45 +709,64 @@ class SymPyValueRangeAnalysis: def min_or_max(a, b, fn): a = ValueRanges.wrap(a) b = ValueRanges.wrap(b) - - # Performs upcasting first - def fn_(x: sympy.Expr, y: sympy.Expr) -> sympy.Expr: - # Poorman's version of upcasting in Sympy - # Inf is not a float... - if x.is_Integer and y.is_Integer: - result_type = sympy.Integer - elif x.is_rational and y.is_rational: - result_type = sympy.Rational - else: - assert x.is_real or not x.is_finite or y.is_real or not y.is_finite - result_type = sympy.Float - return fn(result_type(x), result_type(y)) - - return ValueRanges.coordinatewise_increasing_map(a, b, fn_) + return ValueRanges.coordinatewise_increasing_map(a, b, fn) @classmethod - def floor(cls, x): + def floor_to_int(cls, x, dtype): return ValueRanges.increasing_map(x, sympy.functions.elementary.integers.floor) @classmethod - def ceil(cls, x): + def ceil_to_int(cls, x, dtype): return ValueRanges.increasing_map( x, sympy.functions.elementary.integers.ceiling ) + # I think these implementations are sound. The hazard here is that sympy + # will carry out the floor/ceil at too high precision and then something + # bad will happen when we convert it to float. + # + # For truncation, the implementation is clearly sound, because the desired + # target float is always exactly representable, since you're just chopping + # off bits the mantissa. But what about ceil/floor? + # + # The important constraint here is that we're not defining floor on + # arbitrary real numbers, only representable float numbers. So we can + # take advantage of the fact that before we reach the first + # unrepresentable integer in floating point space, we have the range of + # numbers corresponding to exponent zero: all integers, with no fractional + # amounts. floor/ceil is an identity operation in this case. In the + # range below here, representable floating point numbers are spaced + # exactly 1/2 apart, and notably, both the floor/ceil are defined floating + # point numbers. There is no "gap" as you step up to the next exponent. + @classmethod - def round(cls, number, ndigits=None): - if ndigits is None: - fn = Round - else: - assert ndigits.is_singleton() - ndigits = ndigits.lower - # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind - # the second parameter. - fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 + def floor(cls, x): + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.floor) + ) + + @classmethod + def ceil(cls, x): + return ValueRanges.increasing_map( + x, _keep_float(sympy.functions.elementary.integers.ceiling) + ) + + @classmethod + def round_decimal(cls, number, ndigits): + if not ndigits.is_singleton(): + return ValueRanges.unknown() + + ndigits = ndigits.lower + # We can't use functools.partial here since sympy doesn't support keyword arguments, but we have to bind + # the second parameter. + fn = lambda number: RoundDecimal(number, ndigits) # type: ignore[misc, assignment] # noqa: E731 return ValueRanges.increasing_map(number, fn) + @classmethod + def round_to_int(cls, number, dtype): + return ValueRanges.increasing_map(number, RoundToInt) + # It's used in some models on symints @staticmethod def sqrt(x): @@ -708,12 +821,15 @@ class SymPyValueRangeAnalysis: @staticmethod def cosh(x): + return ValueRanges(0.0, sympy.oo) + """ x = ValueRanges.wrap(x) if x.lower > 0: return ValueRanges.increasing_map(x, OpaqueUnaryFn_cosh) elif x.upper < 0: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_cosh) return ValueRanges(0.0, sympy.oo) + """ @staticmethod def sin(x): @@ -723,7 +839,8 @@ class SymPyValueRangeAnalysis: @staticmethod def sinh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_sinh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def tan(x): @@ -731,32 +848,37 @@ class SymPyValueRangeAnalysis: @staticmethod def tanh(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_tanh) + return ValueRanges(-sympy.oo, sympy.oo) @staticmethod def asin(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.increasing_map(x, OpaqueUnaryFn_asinh) return ValueRanges.unknown() + """ @staticmethod def acos(x): + return ValueRanges(-sympy.oo, sympy.oo) + """ x = ValueRanges.wrap(x) if -1 <= x.lower and x.upper <= 1: return ValueRanges.decreasing_map(x, OpaqueUnaryFn_acos) return ValueRanges.unknown() + """ @staticmethod def atan(x): - return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) + return ValueRanges(-sympy.oo, sympy.oo) + # return ValueRanges.increasing_map(x, OpaqueUnaryFn_atan) @staticmethod def trunc(x): - def trunc(x): - return sympy.Integer(x) if x.is_finite else x - - return ValueRanges.increasing_map(x, trunc) + return ValueRanges.increasing_map(x, TruncToFloat) class ValueRangeAnalysis(SymPyValueRangeAnalysis): @@ -791,9 +913,10 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis): def reduction(self, name, dtype, src_dtype, reduction_type, index, value): return ValueRanges.unknown() - def index_expr(self, index, dtype): + @classmethod + def index_expr(cls, index, dtype): assert isinstance(index, ValueRanges) - return index + return cls.to_dtype(index, dtype) @staticmethod def to_dtype(x, dtype: torch.dtype, src_dtype: Optional[torch.dtype] = None): @@ -830,12 +953,15 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis): @staticmethod def square(x): - return ValueRanges.convex_min_zero_map(x, lambda y: y * y) + return ValueRanges.convex_min_zero_map(x, lambda y: PowByNatural(y, 2)) @staticmethod def neg(x): return ValueRanges.decreasing_map(x, operator.neg) + # TODO: this is slightly inaccurate because truncdiv operates at integer + # precision, but we're going through float truediv which means we can + # potentially lose precision on the bounds @classmethod def truncdiv(cls, a, b): x = cls.truediv(a, b) @@ -856,6 +982,7 @@ class ValueRangeAnalysis(SymPyValueRangeAnalysis): def bound_sympy( expr: sympy.Expr, ranges: Optional[Dict[sympy.Symbol, ValueRanges]] = None ) -> ValueRanges: + log.debug("bound_sympy(%s, %s)", expr, ranges) if isinstance(expr, sympy.Number): return ValueRanges.wrap(expr)