Handle Rational divisors in FloorDiv. (#106644)

Follow-up: #101173

This PR fixes the bug presented in #101173 by creating a special case for `sympy.Rational`
divisors, inside `FloorDiv` evaluation. In summary:

```python
FloorDiv(a, Rational(1, b))
a * b
```

Besides that, this PR also does 2 other things:

- Replaces the use of the old `sympy.Mod` by the internal `Mod` (there were a few places
that were still looking for the SymPy one)

- Introduces debugging logs to the translation validator. These can be seen by setting the
environment variable: `TORCH_LOGS=+torch.fx.experimental.validator`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106644
Approved by: https://github.com/ezyang
ghstack dependencies: #106643
This commit is contained in:
Yukio Siraichi 2023-08-05 16:51:56 -03:00 committed by PyTorch MergeBot
parent 33e70e34a3
commit 070eb88a96
4 changed files with 34 additions and 7 deletions

View file

@ -795,6 +795,20 @@ class TestFloorDiv(TestCase):
self.assertEqual(shape_env.simplify(expr), result)
self.assertEqual(shape_env.evaluate_expr(expr), result)
def test_floordiv_simplify_rational(self):
result = 21
a = sympy.Symbol("a", integer=True)
b = sympy.Symbol("b")
cases = [
(FloorDiv(a, sympy.Rational(1, 8)), 8 * a),
(FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)),
]
for expr, expected in cases:
self.assertEqual(expr, expected)
def test_floordiv_assumptions(self):
# We define two Symbols (with different names) for each type to make
# sure the behavior is consistent regardless of whether both arguments

View file

@ -2974,8 +2974,8 @@ class ShapeEnv:
base, divisor = atom.args
if isinstance(divisor, FloorDiv):
base1, divisor1 = divisor.args
if self.replace(base % divisor) in self.divisible and \
base == base1 and self.replace(base1 % divisor1) in self.divisible:
if self.replace(Mod(base, divisor)) in self.divisible and \
base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible:
div_replacements[atom] = divisor1
expr = expr.xreplace(div_replacements)
expr = safe_expand(expr)
@ -2985,7 +2985,7 @@ class ShapeEnv:
rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
for fd in expr.atoms(FloorDiv):
base, divisor = fd.args
if self.replace(base % divisor) in self.divisible:
if self.replace(Mod(base, divisor)) in self.divisible:
div_replacements[fd] = base / divisor
new_expr = expr.xreplace(div_replacements)
new_expr = safe_expand(new_expr)

View file

@ -329,11 +329,11 @@ try:
def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
if dtype is torch.int64:
return z3.IntVal(value)
return z3.IntVal(int(value))
if dtype is torch.double:
return z3.RealVal(value)
return z3.RealVal(float(value))
if dtype is torch.bool:
return z3.BoolVal(value)
return z3.BoolVal(bool(value))
raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
@ -385,6 +385,8 @@ try:
# happens: target is TRUE, but source is FALSE.
class TranslationValidator:
def __init__(self) -> None:
log.debug("new instance")
# Mapping of SymPy symbols to Z3 variables.
self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
@ -412,6 +414,8 @@ try:
if symbol in self.symbols:
return self.symbols[symbol]
log.debug("new variable: %s (%s)", symbol.name, type.__name__)
if type is int:
var = z3.Int(symbol.name)
@ -444,11 +448,16 @@ try:
return z3expr
def add_source_expr(self, e: z3.BoolRef) -> None:
if e not in self._source_exprs:
log.debug("add source guard: %s", z3str(e))
self._source_exprs.add(e)
def add_target_expr(self, e: sympy.Expr) -> None:
self._check_freesymbols(e)
self._target_exprs.add(self.to_z3_boolean_expr(e))
z3expr = self.to_z3_boolean_expr(e)
if e not in self._target_exprs:
log.debug("add target guard: %s", z3str(z3expr))
self._target_exprs.add(z3expr)
def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
if isinstance(e, sympy.Basic):
@ -457,6 +466,8 @@ try:
else:
ref = e
assert isinstance(ref, z3.BoolRef)
if ref not in self._assertions:
log.debug("add assertion: %s", z3str(ref))
self._assertions.add(ref)
def validate(self) -> None:

View file

@ -69,6 +69,8 @@ class FloorDiv(sympy.Function):
return sympy.floor(base / divisor)
if isinstance(base, FloorDiv):
return FloorDiv(base.args[0], base.args[1] * divisor)
if isinstance(divisor, sympy.Rational) and divisor.p == 1:
return sympy.floor(base * divisor.q)
if isinstance(base, sympy.Add):
for a in base.args: