mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Handle Rational divisors in FloorDiv. (#106644)
Follow-up: #101173 This PR fixes the bug presented in #101173 by creating a special case for `sympy.Rational` divisors, inside `FloorDiv` evaluation. In summary: ```python FloorDiv(a, Rational(1, b)) a * b ``` Besides that, this PR also does 2 other things: - Replaces the use of the old `sympy.Mod` by the internal `Mod` (there were a few places that were still looking for the SymPy one) - Introduces debugging logs to the translation validator. These can be seen by setting the environment variable: `TORCH_LOGS=+torch.fx.experimental.validator` Pull Request resolved: https://github.com/pytorch/pytorch/pull/106644 Approved by: https://github.com/ezyang ghstack dependencies: #106643
This commit is contained in:
parent
33e70e34a3
commit
070eb88a96
4 changed files with 34 additions and 7 deletions
|
|
@ -795,6 +795,20 @@ class TestFloorDiv(TestCase):
|
|||
self.assertEqual(shape_env.simplify(expr), result)
|
||||
self.assertEqual(shape_env.evaluate_expr(expr), result)
|
||||
|
||||
def test_floordiv_simplify_rational(self):
|
||||
result = 21
|
||||
|
||||
a = sympy.Symbol("a", integer=True)
|
||||
b = sympy.Symbol("b")
|
||||
|
||||
cases = [
|
||||
(FloorDiv(a, sympy.Rational(1, 8)), 8 * a),
|
||||
(FloorDiv(b, sympy.Rational(1, 8)), sympy.floor(8 * b)),
|
||||
]
|
||||
|
||||
for expr, expected in cases:
|
||||
self.assertEqual(expr, expected)
|
||||
|
||||
def test_floordiv_assumptions(self):
|
||||
# We define two Symbols (with different names) for each type to make
|
||||
# sure the behavior is consistent regardless of whether both arguments
|
||||
|
|
|
|||
|
|
@ -2974,8 +2974,8 @@ class ShapeEnv:
|
|||
base, divisor = atom.args
|
||||
if isinstance(divisor, FloorDiv):
|
||||
base1, divisor1 = divisor.args
|
||||
if self.replace(base % divisor) in self.divisible and \
|
||||
base == base1 and self.replace(base1 % divisor1) in self.divisible:
|
||||
if self.replace(Mod(base, divisor)) in self.divisible and \
|
||||
base == base1 and self.replace(Mod(base1, divisor1)) in self.divisible:
|
||||
div_replacements[atom] = divisor1
|
||||
expr = expr.xreplace(div_replacements)
|
||||
expr = safe_expand(expr)
|
||||
|
|
@ -2985,7 +2985,7 @@ class ShapeEnv:
|
|||
rationals = expr.atoms(sympy.Rational).difference(expr.atoms(sympy.Integer))
|
||||
for fd in expr.atoms(FloorDiv):
|
||||
base, divisor = fd.args
|
||||
if self.replace(base % divisor) in self.divisible:
|
||||
if self.replace(Mod(base, divisor)) in self.divisible:
|
||||
div_replacements[fd] = base / divisor
|
||||
new_expr = expr.xreplace(div_replacements)
|
||||
new_expr = safe_expand(new_expr)
|
||||
|
|
|
|||
|
|
@ -329,11 +329,11 @@ try:
|
|||
|
||||
def constant(self, value: Any, dtype: torch.dtype) -> z3.ExprRef:
|
||||
if dtype is torch.int64:
|
||||
return z3.IntVal(value)
|
||||
return z3.IntVal(int(value))
|
||||
if dtype is torch.double:
|
||||
return z3.RealVal(value)
|
||||
return z3.RealVal(float(value))
|
||||
if dtype is torch.bool:
|
||||
return z3.BoolVal(value)
|
||||
return z3.BoolVal(bool(value))
|
||||
raise ValueError(f"unsupported dtype (SympyToZ3): {dtype}")
|
||||
|
||||
def truediv(self, numerator: z3.ArithRef, denominator: z3.ArithRef) -> z3.ArithRef:
|
||||
|
|
@ -385,6 +385,8 @@ try:
|
|||
# happens: target is TRUE, but source is FALSE.
|
||||
class TranslationValidator:
|
||||
def __init__(self) -> None:
|
||||
log.debug("new instance")
|
||||
|
||||
# Mapping of SymPy symbols to Z3 variables.
|
||||
self.symbols: Dict[sympy.Symbol, z3.ExprRef] = {}
|
||||
|
||||
|
|
@ -412,6 +414,8 @@ try:
|
|||
if symbol in self.symbols:
|
||||
return self.symbols[symbol]
|
||||
|
||||
log.debug("new variable: %s (%s)", symbol.name, type.__name__)
|
||||
|
||||
if type is int:
|
||||
var = z3.Int(symbol.name)
|
||||
|
||||
|
|
@ -444,11 +448,16 @@ try:
|
|||
return z3expr
|
||||
|
||||
def add_source_expr(self, e: z3.BoolRef) -> None:
|
||||
if e not in self._source_exprs:
|
||||
log.debug("add source guard: %s", z3str(e))
|
||||
self._source_exprs.add(e)
|
||||
|
||||
def add_target_expr(self, e: sympy.Expr) -> None:
|
||||
self._check_freesymbols(e)
|
||||
self._target_exprs.add(self.to_z3_boolean_expr(e))
|
||||
z3expr = self.to_z3_boolean_expr(e)
|
||||
if e not in self._target_exprs:
|
||||
log.debug("add target guard: %s", z3str(z3expr))
|
||||
self._target_exprs.add(z3expr)
|
||||
|
||||
def add_assertion(self, e: Union[z3.BoolRef, sympy.Basic]) -> None:
|
||||
if isinstance(e, sympy.Basic):
|
||||
|
|
@ -457,6 +466,8 @@ try:
|
|||
else:
|
||||
ref = e
|
||||
assert isinstance(ref, z3.BoolRef)
|
||||
if ref not in self._assertions:
|
||||
log.debug("add assertion: %s", z3str(ref))
|
||||
self._assertions.add(ref)
|
||||
|
||||
def validate(self) -> None:
|
||||
|
|
|
|||
|
|
@ -69,6 +69,8 @@ class FloorDiv(sympy.Function):
|
|||
return sympy.floor(base / divisor)
|
||||
if isinstance(base, FloorDiv):
|
||||
return FloorDiv(base.args[0], base.args[1] * divisor)
|
||||
if isinstance(divisor, sympy.Rational) and divisor.p == 1:
|
||||
return sympy.floor(base * divisor.q)
|
||||
|
||||
if isinstance(base, sympy.Add):
|
||||
for a in base.args:
|
||||
|
|
|
|||
Loading…
Reference in a new issue