From dca5cc025585fadf1d7365a9f6819844ee558fce Mon Sep 17 00:00:00 2001 From: Davide Italiano Date: Sat, 1 Feb 2025 21:45:23 +0000 Subject: [PATCH] [mps] Move polygamma to special_math.h. (#146253) In preparation to implement it in inductor. Pull Request resolved: https://github.com/pytorch/pytorch/pull/146253 Approved by: https://github.com/Skylion007, https://github.com/malfet --- aten/src/ATen/native/mps/kernels/Gamma.metal | 6 +----- c10/metal/special_math.h | 8 ++++++++ 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mps/kernels/Gamma.metal b/aten/src/ATen/native/mps/kernels/Gamma.metal index a740cd44ee2..0e1b08238bc 100644 --- a/aten/src/ATen/native/mps/kernels/Gamma.metal +++ b/aten/src/ATen/native/mps/kernels/Gamma.metal @@ -136,11 +136,7 @@ kernel void polygamma( constant int64_t& order [[buffer(2)]], uint id [[thread_position_in_grid]]) { // already blocked if n <= 1 - float x = input[id]; - float n = order; - float sgn = ((order % 2) ? 1 : -1); - output[id] = static_cast( - sgn * c10::metal::gamma(n + 1) * c10::metal::zeta(n + 1, x)); + output[id] = static_cast(c10::metal::polygamma(input[id], order)); } #define INSTANTIATE_GAMMA_KERNELS(DTYPE0, DTYPE1) \ diff --git a/c10/metal/special_math.h b/c10/metal/special_math.h index 309f57907e6..fecb8cd3d87 100644 --- a/c10/metal/special_math.h +++ b/c10/metal/special_math.h @@ -382,5 +382,13 @@ float zeta(float x, float q) { return s; } +template +float polygamma(const T0 input, const T1 order) { + float x = input; + float n = order; + float sgn = ((order % 2) ? 1 : -1); + return sgn * gamma(n + 1) * zeta(n + 1, x); +} + } // namespace metal } // namespace c10