From 67ac6ae4e0a6e6d123f5d526e64d52bd56ef9ff0 Mon Sep 17 00:00:00 2001 From: Suffian Khan Date: Mon, 21 Dec 2020 19:25:21 -0500 Subject: [PATCH] Tune fast Gelu to use exp(x) instead of tanh(x) on Rocm platform (#6174) * tune fast gelu to use exp(x) instead of tanh(x) on rocm * update to use expression 2/(1+exp(-2x))-1 for stability --- .../contrib_ops/rocm/bert/fast_gelu_impl.cu | 23 +++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu index 01fb1d193e..b1ed18b418 100644 --- a/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu +++ b/onnxruntime/contrib_ops/rocm/bert/fast_gelu_impl.cu @@ -39,14 +39,25 @@ constexpr float B = 0.7978845608028654; // sqrt(2.0/M_PI) constexpr float C = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI) +constexpr float one = 1.0; +constexpr float two = 2.0; + template __global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length, const T* input, const T* bias, T* output) { const int idx = blockIdx.x * TPB + threadIdx.x; + const T twoT = T(two); + const T oneT = T(one); + if (idx < input_length) { const T x = input[idx]; const T in = (bias == nullptr) ? x : (x + bias[idx % bias_length]); - const T cdf = a + a * _Tanh(in * (c * in * in + b)); + + // const T cdf = a + a * _Tanh(in * (c * in * in + b)); + const T u = twoT * in * (c * in * in + b); + const T emu = __expf(-u); + const T cdf = a + a * (twoT/(oneT + emu) - oneT); + output[idx] = in * cdf; } } @@ -55,10 +66,18 @@ template __global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int input_length, int bias_length, const half2* input, const half2* bias, half2* output) { const int idx = blockIdx.x * TPB + threadIdx.x; + const half2 two2 = __floats2half2_rn(two, two); + const half2 one2 = __floats2half2_rn(one, one); + if (idx < input_length) { const half2 x = input[idx]; const half2 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]); - const half2 cdf = a + a * _Tanh(in * (c * in * in + b)); + + // const half2 cdf = a + a * _Tanh(in * (c * in * in + b)); + const half2 u = two2 * in * (c * in * in + b); + const half2 emu = h2exp(-u); + const half2 cdf = a + a * (two2/(one2 + emu) - one2); + output[idx] = in * cdf; } }