mirror of
https://github.com/saymrwulf/pytorch.git
synced 2026-05-14 20:57:59 +00:00
[mps] Add a shader for spherical_bessel_j0. (#146771)
In preparation for adding the operation to inductor/eager. Adapted from the CUDA version of the shader. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146771 Approved by: https://github.com/malfet
This commit is contained in:
parent
0e83e7d56e
commit
91c4bf39d3
1 changed files with 26 additions and 0 deletions
|
|
@ -477,5 +477,31 @@ inline float2 sinc(float2 inp) {
|
||||||
return float2(re, im) / a2;
|
return float2(re, im) / a2;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
inline T spherical_bessel_j0(T x) {
|
||||||
|
if (::metal::isinf(x))
|
||||||
|
return T(0.0);
|
||||||
|
T x2 = x * x;
|
||||||
|
T k1 = static_cast<T>(-1.0);
|
||||||
|
T k2 = static_cast<T>(1.0);
|
||||||
|
|
||||||
|
if (::metal::abs(x) < T(0.5)) {
|
||||||
|
return T(1.0) +
|
||||||
|
x2 *
|
||||||
|
(k1 / T(6.0) +
|
||||||
|
x2 *
|
||||||
|
(k2 / T(120.0) +
|
||||||
|
x2 *
|
||||||
|
(k1 / T(5040.0) +
|
||||||
|
x2 *
|
||||||
|
(k2 / T(362880.0) +
|
||||||
|
x2 *
|
||||||
|
(k1 / T(39916800.0) +
|
||||||
|
x2 * (k2 / T(6227020800.0)))))));
|
||||||
|
}
|
||||||
|
|
||||||
|
return ::metal::sin(x) / x;
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace metal
|
} // namespace metal
|
||||||
} // namespace c10
|
} // namespace c10
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue