From bc40ccf6aab3093efb1f765f9c0599384d2a0fb6 Mon Sep 17 00:00:00 2001 From: Aaron Gokaslan Date: Fri, 7 Feb 2025 05:15:11 +0000 Subject: [PATCH] [BE]: Inline special functions for MPS (#146627) These header functions should be inlined for consistency and to avoid translation unit / symbol issues. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146627 Approved by: https://github.com/dcci --- c10/metal/special_math.h | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index c28640afa26..8bcb1f7a53e 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -8,7 +8,7 @@ namespace metal { // Translated to metal from https://www.johndcook.com/cpp_erf.html template <typename T> -T erf(T x) { +inline T erf(T x) { T a1 = 0.254829592; T a2 = -0.284496736; T a3 = 1.421413741; @@ -86,7 +86,7 @@ inline T chbevl(T x, const float array[], const int len) { // https://github.com/pytorch/pytorch/blob/58b661cda2c002a8e1ac3bee494bfe1f7420437c/aten/src/ATen/native/cuda/Math.cuh#L502 template <typename T> -T i0(T _x) { +inline T i0(T _x) { auto x = ::metal::fabs(_x); if (x <= 8.0) { @@ -145,7 +145,7 @@ T i0(T _x) { // https://github.com/pytorch/pytorch/blob/58b661cda2c002a8e1ac3bee494bfe1f7420437c/aten/src/ATen/native/cuda/Math.cuh#L576 template <typename T> -T i1(T _x) { +inline T i1(T _x) { const auto x = ::metal::fabs(_x); if (x <= 8.0) { @@ -199,7 +199,7 @@ template <typename T> inline float log_gamma(const T); template <typename T> -float gamma(const T x) { +inline float gamma(const T x) { if (x < 0.001) { constexpr float EULER_MASCHERONI = 0.577215664901532860606512090; // For small x, 1/gamma(x) has power series x + gamma x^2 - ... @@ -452,7 +452,7 @@ inline float digamma(T0 x) { } template <typename T> -T sinc(T a) { +inline T sinc(T a) { if (a == static_cast<T>(0)) { return static_cast<T>(1); }