mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
Fix PythonMod printing for C++ (#143385)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/143385 Approved by: https://github.com/leslie-fang-intel, https://github.com/anijain2305
This commit is contained in:
parent
079a3e0f75
commit
4b77ff9784
3 changed files with 37 additions and 6 deletions
|
|
@ -69,4 +69,34 @@ inline C10_HOST_DEVICE scalar_t div_floor_integer(scalar_t a, scalar_t b) {
|
|||
return a / b;
|
||||
}
|
||||
|
||||
template <
|
||||
typename scalar_t,
|
||||
std::enable_if_t<std::is_floating_point_v<scalar_t>, int> = 0>
|
||||
inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b)
|
||||
__ubsan_ignore_float_divide_by_zero__ {
|
||||
if (C10_UNLIKELY(b == 0)) {
|
||||
// Divide by zero: return standard IEEE result
|
||||
return std::fmod(a, b);
|
||||
}
|
||||
|
||||
auto mod = std::fmod(a, b);
|
||||
if (mod == 0) {
|
||||
mod = C10_COMPAT_COPYSIGN(scalar_t(0), b);
|
||||
} else if ((b < 0) != (mod < 0)) {
|
||||
mod += b;
|
||||
}
|
||||
return mod;
|
||||
}
|
||||
|
||||
template <
|
||||
typename scalar_t,
|
||||
std::enable_if_t<std::is_integral_v<scalar_t>, int> = 0>
|
||||
inline C10_HOST_DEVICE scalar_t div_mod(scalar_t a, scalar_t b) {
|
||||
auto mod = a % b;
|
||||
if ((b < 0) != (mod < 0)) {
|
||||
mod += b;
|
||||
}
|
||||
return mod;
|
||||
}
|
||||
|
||||
} // namespace c10
|
||||
|
|
|
|||
|
|
@ -349,7 +349,9 @@ class ExprPrinterTests(InductorTestCase):
|
|||
x = sympy.Symbol("x", integer=True)
|
||||
expr = PythonMod(x - 10, x)
|
||||
self.assertExpectedInline(pexpr(expr), """((-10) + x) % x""")
|
||||
self.assertExpectedInline(cexpr(expr), f"""((-10{LONG_SUFFIX}) + x) % x""")
|
||||
self.assertExpectedInline(
|
||||
cexpr(expr), f"""c10::div_mod((-10{LONG_SUFFIX}) + x, x)"""
|
||||
)
|
||||
self.assertExpectedInline(
|
||||
texpr(expr), """triton_helpers.remainder_integer((-10) + x, x)"""
|
||||
)
|
||||
|
|
|
|||
|
|
@ -319,12 +319,11 @@ class CppPrinter(ExprPrinter):
|
|||
assert len(expr.args) == 1
|
||||
return f"static_cast<double>({self._print(expr.args[0])})"
|
||||
|
||||
# TODO: This is wrong if one of the inputs is negative. This is hard to
|
||||
# tickle though, as the inputs are typically positive (and if we can prove
|
||||
# they are positive, we will have used Mod instead, for which this codegen
|
||||
# is right).
|
||||
def _print_PythonMod(self, expr: sympy.Expr) -> str:
|
||||
return self.stringify(expr.args, " % ", PRECEDENCE["Atom"] - 0.5)
|
||||
x, div = expr.args
|
||||
x = self.doprint(x)
|
||||
div = self.doprint(div)
|
||||
return f"c10::div_mod({x}, {div})"
|
||||
|
||||
def _print_IntTrueDiv(self, expr: sympy.Expr) -> str:
|
||||
lhs, rhs = expr.args
|
||||
|
|
|
|||
Loading…
Reference in a new issue