Fix PythonMod printing for C++ (#143385)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/143385
Approved by: https://github.com/leslie-fang-intel, https://github.com/anijain2305
This commit is contained in:
Isuru Fernando 2025-01-22 10:34:51 +00:00 committed by PyTorch MergeBot
parent 079a3e0f75
commit 4b77ff9784
3 changed files with 37 additions and 6 deletions

View file

@ -69,4 +69,34 @@ inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
return a / b;
}
template <
typename scalar_t,
std::enable_if_t<std::is_floating_point_v<scalar_t>, int> = 0>
inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b)
__ubsan_ignore_float_divide_by_zero__ {
if (C10_UNLIKELY(b == 0)) {
// Divide by zero: return standard IEEE result
return std::fmod(a, b);
}
auto mod = std::fmod(a, b);
if (mod == 0) {
mod = C10_COMPAT_COPYSIGN(scalar_t(0), b);
} else if ((b < 0) != (mod < 0)) {
mod += b;
}
return mod;
}
template <
typename scalar_t,
std::enable_if_t<std::is_integral_v<scalar_t>, int> = 0>
inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) {
auto mod = a % b;
if ((b < 0) != (mod < 0)) {
mod += b;
}
return mod;
}
} // namespace c10

View file

@ -349,7 +349,9 @@ class ExprPrinterTests(InductorTestCase):
x = sympy.Symbol("x", integer=True)
expr = PythonMod(x - 10, x)
self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
self.assertExpectedInline(cexpr(expr), f"""((-10{LONG_SUFFIX}) + x) % x""")
self.assertExpectedInline(
cexpr(expr), f"""c10::div_mod((-10{LONG_SUFFIX}) + x, x)"""
)
self.assertExpectedInline(
texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
)

View file

@ -319,12 +319,11 @@ class CppPrinter(ExprPrinter):
assert len(expr.args) == 1
return f"static_cast<double>({self._print(expr.args[0])})"
# TODO: This is wrong if one of the inputs is negative. This is hard to
# tickle though, as the inputs are typically positive (and if we can prove
# they are positive, we will have used Mod instead, for which this codegen
# is right).
def _print_PythonMod(self, expr: sympy.Expr) -> str:
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
x, div = expr.args
x = self.doprint(x)
div = self.doprint(div)
return f"c10::div_mod({x}, {div})"
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
lhs, rhs = expr.args