Implement a more stable softmax (#2715)

* Implement a more stable SoftMax
 In single precision, e^x overflows to infinity once x is large enough (above roughly 88.7, e.g. 100.f). Infinity divided by infinity is NaN, so softmax returns NaN if one or more inputs are that large.
The following identity is used to compute a numerically stable softmax:
e^xi / (e^x1 + ... + e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))

For convenience, max is forced to 0.f when all xi are negative.
This commit is contained in:
Yufeng Li 2020-01-06 14:28:12 -08:00 committed by GitHub
parent 6f66260372
commit 72bdfc8cd4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
2 changed files with 79 additions and 28 deletions

View file

@ -238,8 +238,18 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
float* x = reinterpret_cast<T*>(scratch_data) + j * D;
float* y = x;
for (int i = 0; i < D; i++)
y[i] = expf(x[i]);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
// And for convenience, force max to 0.f if all xi are negative
float max = 0.f;
for (int i = 0; i < D; i++) {
if (max < x[i]) max = x[i];
}
for (int i = 0; i < D; i++) {
y[i] = expf(x[i] - max);
}
double sum = 0.0;

View file

@ -54,84 +54,125 @@ size_t GetAttentionWorkspaceSize(size_t element_size, int batch_size, int num_he
}
// Block-wide masked softmax over one row of `ld` elements.
//
// The row processed by a block starts at
// (blockIdx.y * gridDim.x + blockIdx.x) * ld. Elements [0, num_valid) are
// normalised with a max-shifted softmax; elements [num_valid, ld) are written
// as 0. Every loop strides by TPB, so the block is presumably launched with
// exactly TPB threads (launch site is outside this view - confirm there).
//
//   ld         row length (leading dimension) of input/output
//   num_valid  number of leading elements that take part in the softmax
//   input      source values
//   output     destination; all `ld` elements of the row are written
template <typename T, unsigned TPB>
__device__ inline void Softmax(const int ld, const int num_valid, const T* input, T* output) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  __shared__ float sum_reverse_block;
  __shared__ float max_block;

  const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;

  // e^x is represented as infinity if x is large enough, like 100.f.
  // Infinity divided by infinity is NaN, so softmax yields NaN if one or more
  // items are large enough. The following identity gives a stable softmax:
  //   e^xi / (e^x1 + ... + e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
  // For convenience max is forced to 0.f if all xi are negative: the running
  // maximum below starts at 0 rather than at -infinity.
  float thread_max(0.f);
  for (int i = threadIdx.x; i < num_valid; i += TPB) {
    const float val = float(input[offset + i]);
    if (thread_max < val) {
      thread_max = val;
    }
  }

  const auto max = BlockReduce(tmp_storage).Reduce(thread_max, cub::Max());

  // The reduction result is only defined in thread 0; publish it to the block.
  if (threadIdx.x == 0) {
    max_block = max;
  }
  // Also separates the two uses of tmp_storage.
  __syncthreads();

  // The sum needs its own accumulator starting at 0. Re-using the variable
  // that still holds the per-thread maximum would add that maximum into the
  // sum of exponentials and skew the normaliser.
  float thread_sum(0.f);
  for (int i = threadIdx.x; i < num_valid; i += TPB) {
    const float val = float(input[offset + i]);
    thread_sum += expf(val - max_block);
  }

  const auto sum = BlockReduce(tmp_storage).Reduce(thread_sum, cub::Sum());

  // Store the reciprocal so the final pass multiplies instead of dividing.
  if (threadIdx.x == 0) {
    sum_reverse_block = 1.f / sum;
  }
  __syncthreads();

  // Final pass covers the whole row; masked-out positions become 0.
  for (int i = threadIdx.x; i < ld; i += TPB) {
    const int index = offset + i;
    const float val = (i < num_valid) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f;
    output[index] = T(val);
  }
}
// Block-wide masked softmax for rows short enough that each thread handles at
// most one element (thread t handles element t of the row).
//
// The row processed by a block starts at
// (blockIdx.y * gridDim.x + blockIdx.x) * ld. Threads with
// threadIdx.x < num_valid contribute a value; threads in [num_valid, ld)
// write 0; threads at or beyond ld write nothing.
//
//   ld         row length (leading dimension) of input/output
//   num_valid  number of leading elements that take part in the softmax
//   input      source values
//   output     destination; the first `ld` elements of the row are written
//
// NOTE(review): num_valid == 0 is not special-cased here; both reductions are
// then asked to combine zero items. Confirm callers never pass an empty mask.
template <typename T, unsigned TPB>
__device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T* input, T* output) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  __shared__ float sum_reverse_block;
  __shared__ float max_block;

  const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
  const int index = offset + threadIdx.x;

  // Stays 0 for threads outside the valid range, which is what makes the
  // final store write 0 for masked-out positions.
  float thread_data(0);
  if (threadIdx.x < num_valid) {
    thread_data = input[index];
  }

  // e^x is represented as infinity if x is large enough, like 100.f.
  // Infinity divided by infinity is NaN, so softmax yields NaN if one or more
  // items are large enough. The following identity gives a stable softmax:
  //   e^xi / (e^x1 + ... + e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
  // The reduction is limited to the first num_valid threads, so the 0 held by
  // the remaining threads should not take part in the maximum.
  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), num_valid);

  // The reduction result is only defined in thread 0; publish it to the block.
  if (threadIdx.x == 0) {
    max_block = max;
  }
  // Also separates the two uses of tmp_storage.
  __syncthreads();

  if (threadIdx.x < num_valid) {
    const float val = input[index];
    thread_data = expf(val - max_block);
  }

  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum(), num_valid);

  // Store the reciprocal of the sum so the final step is a multiply.
  if (threadIdx.x == 0) {
    sum_reverse_block = (1.f) / sum;
  }
  __syncthreads();

  if (threadIdx.x < ld) {
    // this will be 0 for threadIdx.x >= num_valid
    output[index] = T(thread_data * sum_reverse_block);
  }
}
// Kernel entry point for the one-element-per-thread masked softmax.
//
// The number of valid elements for this block is mask_index[blockIdx.y],
// capped at sequence_length. Thread 0 reads it once into shared memory and the
// barrier makes it visible to the rest of the block before SoftmaxSmall runs.
//
//   sequence_length  row length passed to SoftmaxSmall as `ld`
//   mask_index       valid-element count, indexed by blockIdx.y
//   input / output   forwarded unchanged to SoftmaxSmall
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) {
  __shared__ int num_valid;

  if (threadIdx.x == 0) {
    num_valid = min(sequence_length, mask_index[blockIdx.y]);
  }
  __syncthreads();

  // Single invocation: computing the softmax once is sufficient, a second
  // call over the same row would only repeat the work.
  SoftmaxSmall<T, TPB>(sequence_length, num_valid, input, output);
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernel(const int sequence_length, const int* mask_index, const T* input, T* output) {
__shared__ int last_valid;
__shared__ int num_valid;
if (threadIdx.x == 0) {
last_valid = min(sequence_length, mask_index[blockIdx.y]);
num_valid = min(sequence_length, mask_index[blockIdx.y]);
}
__syncthreads();
Softmax<T, TPB>(sequence_length, last_valid, input, output);
Softmax<T, TPB>(sequence_length, num_valid, input, output);
}
template <typename T>