From 91c4bf39d39f0607833b2a226e48a9ab7262c906 Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Sun, 9 Feb 2025 05:11:17 +0000 Subject: [PATCH] [mps] Add a shader for spherical_bessel_j0. (#146771) In preparation for adding the operation to inductor/eager. Adapted from the CUDA version of the shader. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146771 Approved by: https://github.com/malfet --- c10/metal/special_math.h | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 8bcb1f7a53e..04fd7eee18f 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) { return float2(re, im) / a2; } +template +inline T spherical_bessel_j0(T x) { + if (::metal::isinf(x)) + return T(0.0); + T x2 = x * x; + T k1 = static_cast(-1.0); + T k2 = static_cast(1.0); + + if (::metal::abs(x) < T(0.5)) { + return T(1.0) + + x2 * + (k1 / T(6.0) + + x2 * + (k2 / T(120.0) + + x2 * + (k1 / T(5040.0) + + x2 * + (k2 / T(362880.0) + + x2 * + (k1 / T(39916800.0) + + x2 * (k2 / T(6227020800.0))))))); + } + + return ::metal::sin(x) / x; +} + } // namespace metal } // namespace c10