Tune fast Gelu to use exp(x) instead of tanh(x) on Rocm platform (#6174)

* tune fast gelu to use exp(x) instead of tanh(x) on rocm

* update to use expression 2/(1+exp(-2x))-1 for stability
This commit is contained in:
Suffian Khan 2020-12-21 19:25:21 -05:00 committed by GitHub
parent 53307a5f2e
commit 67ac6ae4e0
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -39,14 +39,25 @@ constexpr float B = 0.7978845608028654; // sqrt(2.0/M_PI)
constexpr float C = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)
constexpr float one = 1.0;
constexpr float two = 2.0;
template <typename T, unsigned TPB>
__global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length, const T* input, const T* bias, T* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
const T twoT = T(two);
const T oneT = T(one);
if (idx < input_length) {
const T x = input[idx];
const T in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
const T cdf = a + a * _Tanh(in * (c * in * in + b));
// const T cdf = a + a * _Tanh(in * (c * in * in + b));
const T u = twoT * in * (c * in * in + b);
const T emu = __expf(-u);
const T cdf = a + a * (twoT/(oneT + emu) - oneT);
output[idx] = in * cdf;
}
}
@ -55,10 +66,18 @@ template <unsigned TPB>
__global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int input_length, int bias_length, const half2* input, const half2* bias, half2* output) {
const int idx = blockIdx.x * TPB + threadIdx.x;
const half2 two2 = __floats2half2_rn(two, two);
const half2 one2 = __floats2half2_rn(one, one);
if (idx < input_length) {
const half2 x = input[idx];
const half2 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
const half2 cdf = a + a * _Tanh(in * (c * in * in + b));
// const half2 cdf = a + a * _Tanh(in * (c * in * in + b));
const half2 u = two2 * in * (c * in * in + b);
const half2 emu = h2exp(-u);
const half2 cdf = a + a * (two2/(one2 + emu) - one2);
output[idx] = in * cdf;
}
}