diff --git a/test/test_dynamic_shapes.py b/test/test_dynamic_shapes.py index d5dd3e8ee8a..bb54978a70f 100644 --- a/test/test_dynamic_shapes.py +++ b/test/test_dynamic_shapes.py @@ -795,6 +795,20 @@ class TestFloorDiv(TestCase): self.assertEqual(shape_env.simplify(expr), result) self.assertEqual(shape_env.evaluate_expr(expr), result) + def test_floordiv_simplify_rational(self): + result = 21 + + a = sympy.Symbol("a", integer=True) + b = sympy.Symbol("b") + + cases = [ + (FloorDiv(a, sympy.Rational(1, 8)), 8 * a), + (FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)), + ] + + for expr, expected in cases: + self.assertEqual(expr, expected) + def test_floordiv_assumptions(self): # We define two Symbols (with different names) for each type to make # sure the behavior is consistent regardless of whether both arguments diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 55f1e64a0b7..e0af2400727 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -2974,8 +2974,8 @@ class ShapeEnv: base, divisor = atom.args if isinstance(divisor, FloorDiv): base1, divisor1 = divisor.args - if self.replace(base % divisor) in self.divisible and \ - base == base1 and self.replace(base1 % divisor1) in self.divisible: + if self.replace(Mod(base, divisor)) in self.divisible and \ + base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible: div_replacements[atom] = divisor1 expr = expr.xreplace(div_replacements) expr = safe_expand(expr) @@ -2985,7 +2985,7 @@ class ShapeEnv: rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer)) for fd in expr.atoms(FloorDiv): base, divisor = fd.args - if self.replace(base % divisor) in self.divisible: + if self.replace(Mod(base, divisor)) in self.divisible: div_replacements[fd] = base / divisor new_expr = expr.xreplace(div_replacements) new_expr = safe_expand(new_expr) diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index df993fbba2a..4d3c37755b9 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -329,11 +329,11 @@ try: def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef: if dtype is torch.int64: - return z3.IntVal(value) + return z3.IntVal(int(value)) if dtype is torch.double: - return z3.RealVal(value) + return z3.RealVal(float(value)) if dtype is torch.bool: - return z3.BoolVal(value) + return z3.BoolVal(bool(value)) raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}") def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef: @@ -385,6 +385,8 @@ try: # happens: target is TRUE, but source is FALSE. class TranslationValidator: def __init__(self) -> None: + log.debug("new instance") + # Mapping of SymPy symbols to Z3 variables. self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {} @@ -412,6 +414,8 @@ try: if symbol in self.symbols: return self.symbols[symbol] + log.debug("new variable: %s (%s)", symbol.name, type.__name__) + if type is int: var = z3.Int(symbol.name) @@ -444,11 +448,16 @@ try: return z3expr def add_source_expr(self, e: z3.BoolRef) -> None: + if e not in self._source_exprs: + log.debug("add source guard: %s", z3str(e)) self._source_exprs.add(e) def add_target_expr(self, e: sympy.Expr) -> None: self._check_freesymbols(e) - self._target_exprs.add(self.to_z3_boolean_expr(e)) + z3expr = self.to_z3_boolean_expr(e) + if e not in self._target_exprs: + log.debug("add target guard: %s", z3str(z3expr)) + self._target_exprs.add(z3expr) def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None: if isinstance(e, sympy.Basic): @@ -457,6 +466,8 @@ try: else: ref = e assert isinstance(ref, z3.BoolRef) + if ref not in self._assertions: + log.debug("add assertion: %s", z3str(ref)) self._assertions.add(ref) def validate(self) -> None: diff --git a/torch/utils/_sympy/functions.py b/torch/utils/_sympy/functions.py index 57452553463..30b8b869cfc 100644 --- a/torch/utils/_sympy/functions.py +++ b/torch/utils/_sympy/functions.py @@ -69,6 +69,8 @@ class FloorDiv(sympy.Function): return sympy.floor(base / divisor) if isinstance(base, FloorDiv): return FloorDiv(base.args[0], base.args[1] * divisor) + if isinstance(divisor, sympy.Rational) and divisor.p == 1: + return sympy.floor(base * divisor.q) if isinstance(base, sympy.Add): for a in base.args: