[mps] Add a shader for spherical_bessel_j0. (#146771)

In preparation for adding the operation to inductor/eager.
Adapted from the CUDA version of the shader.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/146771
Approved by: https://github.com/malfet
This commit is contained in:
Davide Italiano 2025-02-09 05:11:17 +00:00 committed by PyTorch MergeBot
parent 0e83e7d56e
commit 91c4bf39d3

View file

@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
return float2(re, im) / a2;
}
template <typename T>
inline T spherical_bessel_j0(T x) {
if (::metal::isinf(x))
return T(0.0);
T x2 = x * x;
T k1 = static_cast<T>(-1.0);
T k2 = static_cast<T>(1.0);
if (::metal::abs(x) < T(0.5)) {
return T(1.0) +
x2 *
(k1 / T(6.0) +
x2 *
(k2 / T(120.0) +
x2 *
(k1 / T(5040.0) +
x2 *
(k2 / T(362880.0) +
x2 *
(k1 / T(39916800.0) +
x2 * (k2 / T(6227020800.0)))))));
}
return ::metal::sin(x) / x;
}
} // namespace metal
} // namespace c10