[BE]: Inline special functions for MPS (#146627)

These header functions should be inlined for consistency and to avoid translation unit / symbol issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/146627
Approved by: https://github.com/dcci
This commit is contained in:
Aaron Gokaslan 2025-02-07 05:15:11 +00:00 committed by PyTorch MergeBot
parent ecf44d1002
commit bc40ccf6aa

View file

@ -8,7 +8,7 @@ namespace metal {
// Translated to metal from https://www.johndcook.com/cpp_erf.html
template <typename T>
T erf(T x) {
inline T erf(T x) {
T a1 = 0.254829592;
T a2 = -0.284496736;
T a3 = 1.421413741;
@ -86,7 +86,7 @@ inline T chbevl(T x, const float array[], const int len) {
// https://github.com/pytorch/pytorch/blob/58b661cda2c002a8e1ac3bee494bfe1f7420437c/aten/src/ATen/native/cuda/Math.cuh#L502
template <typename T>
T i0(T _x) {
inline T i0(T _x) {
auto x = ::metal::fabs(_x);
if (x <= 8.0) {
@ -145,7 +145,7 @@ T i0(T _x) {
// https://github.com/pytorch/pytorch/blob/58b661cda2c002a8e1ac3bee494bfe1f7420437c/aten/src/ATen/native/cuda/Math.cuh#L576
template <typename T>
T i1(T _x) {
inline T i1(T _x) {
const auto x = ::metal::fabs(_x);
if (x <= 8.0) {
@ -199,7 +199,7 @@ template <typename T>
inline float log_gamma(const T);
template <typename T>
float gamma(const T x) {
inline float gamma(const T x) {
if (x < 0.001) {
constexpr float EULER_MASCHERONI = 0.577215664901532860606512090;
// For small x, 1/gamma(x) has power series x + gamma x^2 - ...
@ -452,7 +452,7 @@ inline float digamma(T0 x) {
}
template <typename T>
T sinc(T a) {
inline T sinc(T a) {
if (a == static_cast<T>(0)) {
return static_cast<T>(1);
}