mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-26 22:35:43 +00:00
Tune fast Gelu to use exp(x) instead of tanh(x) on Rocm platform (#6174)
* tune fast gelu to use exp(x) instead of tanh(x) on rocm * update to use expression 2/(1+exp(-2x))-1 for stability
This commit is contained in:
parent
53307a5f2e
commit
67ac6ae4e0
1 changed files with 21 additions and 2 deletions
|
|
@ -39,14 +39,25 @@ constexpr float B = 0.7978845608028654; // sqrt(2.0/M_PI)
|
|||
|
||||
constexpr float C = 0.035677408136300125; // 0.044715 * sqrt(2.0/M_PI)
|
||||
|
||||
constexpr float one = 1.0;
|
||||
constexpr float two = 2.0;
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void FastGeluKernel(const T a, const T b, const T c, int input_length, int bias_length, const T* input, const T* bias, T* output) {
|
||||
const int idx = blockIdx.x * TPB + threadIdx.x;
|
||||
|
||||
const T twoT = T(two);
|
||||
const T oneT = T(one);
|
||||
|
||||
if (idx < input_length) {
|
||||
const T x = input[idx];
|
||||
const T in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
|
||||
const T cdf = a + a * _Tanh(in * (c * in * in + b));
|
||||
|
||||
// const T cdf = a + a * _Tanh(in * (c * in * in + b));
|
||||
const T u = twoT * in * (c * in * in + b);
|
||||
const T emu = __expf(-u);
|
||||
const T cdf = a + a * (twoT/(oneT + emu) - oneT);
|
||||
|
||||
output[idx] = in * cdf;
|
||||
}
|
||||
}
|
||||
|
|
@ -55,10 +66,18 @@ template <unsigned TPB>
|
|||
__global__ void FastGeluKernel2(const half2 a, const half2 b, const half2 c, int input_length, int bias_length, const half2* input, const half2* bias, half2* output) {
|
||||
const int idx = blockIdx.x * TPB + threadIdx.x;
|
||||
|
||||
const half2 two2 = __floats2half2_rn(two, two);
|
||||
const half2 one2 = __floats2half2_rn(one, one);
|
||||
|
||||
if (idx < input_length) {
|
||||
const half2 x = input[idx];
|
||||
const half2 in = (bias == nullptr) ? x : (x + bias[idx % bias_length]);
|
||||
const half2 cdf = a + a * _Tanh(in * (c * in * in + b));
|
||||
|
||||
// const half2 cdf = a + a * _Tanh(in * (c * in * in + b));
|
||||
const half2 u = two2 * in * (c * in * in + b);
|
||||
const half2 emu = h2exp(-u);
|
||||
const half2 cdf = a + a * (two2/(one2 + emu) - one2);
|
||||
|
||||
output[idx] = in * cdf;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue