diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc
index 2027dd1da0..9ec7e849c0 100644
--- a/onnxruntime/contrib_ops/cpu/bert/attention.cc
+++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc
@@ -238,8 +238,18 @@ Status Attention::Compute(OpKernelContext* context) const {
         float* x = reinterpret_cast(scratch_data) + j * D;
         float* y = x;
 
-        for (int i = 0; i < D; i++)
-          y[i] = expf(x[i]);
+        // e^x is represented as infinity if x is large enough, like 100.f.
+        // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
+        // a math transform as below is leveraged to get a stable softmax:
+        // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
+        // And for convenience, force max to 0.f if all xi are negative
+        float max = 0.f;
+        for (int i = 0; i < D; i++) {
+          if (max < x[i]) max = x[i];
+        }
+        for (int i = 0; i < D; i++) {
+          y[i] = expf(x[i] - max);
+        }
 
         double sum = 0.0;
 
diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
index 53c79598ab..5c72a1d53e 100644
--- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu
@@ -54,84 +54,128 @@ size_t GetAttentionWorkspaceSize(size_t element_size, int batch_size, int num_he
 }
 
 template
-__device__ inline void Softmax(const int ld, const int last_valid, const T* input, T* output) {
+__device__ inline void Softmax(const int ld, const int num_valid, const T* input, T* output) {
   using BlockReduce = cub::BlockReduce;
   __shared__ typename BlockReduce::TempStorage tmp_storage;
 
-  __shared__ float reverse_z;
+  __shared__ float sum_reverse_block;
+  __shared__ float max_block;
 
   float thread_data(0);
+
   const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
-  for (int i = threadIdx.x; i < last_valid; i += TPB) {
+  for (int i = threadIdx.x; i < num_valid; i += TPB) {
     const int index = offset + i;
-    const float val = input[index];
-    thread_data += expf(val);
+    if (thread_data < float(input[index])) {
+      thread_data = float(input[index]);
+    }
   }
 
-  cub::Sum sum;
-  const auto z = BlockReduce(tmp_storage).Reduce(thread_data, sum);
+  // e^x is represented as infinity if x is large enough, like 100.f.
+  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
+  // a math transform as below is leveraged to get a stable softmax:
+  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
+  // And for convenience, force max to 0.f if all xi are negative
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max());
+
+  // Store max value
   if (threadIdx.x == 0) {
-    reverse_z = 1.f / z;
+    max_block = max;
+  }
+  __syncthreads();
+
+  // thread_data still holds this thread's running max from the pass above; clear it so that
+  // the block-wide sum below contains only the exponentials.
+  thread_data = 0.f;
+  for (int i = threadIdx.x; i < num_valid; i += TPB) {
+    const int index = offset + i;
+    const float val = input[index];
+    thread_data += expf(val - max_block);
+  }
+
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum());
+  if (threadIdx.x == 0) {
+    sum_reverse_block = 1.f / sum;
   }
   __syncthreads();
 
   for (int i = threadIdx.x; i < ld; i += TPB) {
     const int index = offset + i;
-    const float val = (i < last_valid) ? expf(float(input[index])) * reverse_z : 0.f;
+    const float val = (i < num_valid) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f;
     output[index] = T(val);
   }
 }
 
 template
-__device__ inline void SoftmaxSmall(const int ld, const int last_valid, const T* input, T* output) {
+__device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T* input, T* output) {
   using BlockReduce = cub::BlockReduce;
   __shared__ typename BlockReduce::TempStorage tmp_storage;
 
-  __shared__ float reverse_z;
+  __shared__ float sum_reverse_block;
+  __shared__ float max_block;
 
-  float thread_data(0);
   const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * ld;
   const int index = offset + threadIdx.x;
-  if (threadIdx.x < last_valid) {
-    const float val = input[index];
-    thread_data = expf(val);
+
+  float thread_data(0);
+  if (threadIdx.x < num_valid) {
+    thread_data = input[index];
   }
 
-  cub::Sum sum;
-  const auto z = BlockReduce(tmp_storage).Reduce(thread_data, sum);
+  // e^x is represented as infinity if x is large enough, like 100.f.
+  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
+  // a math transform as below is leveraged to get a stable softmax:
+  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
+  // Only the first num_valid threads contribute to the reduction, so max is the largest valid input.
+  const auto max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), num_valid);
+
+  // Store max value
   if (threadIdx.x == 0) {
-    reverse_z = (1.f) / z;
+    max_block = max;
+  }
+  __syncthreads();
+
+  if (threadIdx.x < num_valid) {
+    const float val = input[index];
+    thread_data = expf(val - max_block);
+  }
+
+  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data, cub::Sum(), num_valid);
+
+  // Store the reciprocal of the sum
+  if (threadIdx.x == 0) {
+    sum_reverse_block = (1.f) / sum;
   }
   __syncthreads();
 
   if (threadIdx.x < ld) {
-    // this will be 0 for threadIdx.x >= last_valid
-    output[index] = T(thread_data * reverse_z);
+    // this will be 0 for threadIdx.x >= num_valid
+    output[index] = T(thread_data * sum_reverse_block);
   }
 }
 
 template
 __global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) {
-  __shared__ int last_valid;
+  __shared__ int num_valid;
 
   if (threadIdx.x == 0) {
-    last_valid = min(sequence_length, mask_index[blockIdx.y]);
+    num_valid = min(sequence_length, mask_index[blockIdx.y]);
   }
   __syncthreads();
 
-  SoftmaxSmall(sequence_length, last_valid, input, output);
+  SoftmaxSmall(sequence_length, num_valid, input, output);
 }
 
 template
 __global__ void MaskedSoftmaxKernel(const int sequence_length, const int* mask_index, const T* input, T* output) {
-  __shared__ int last_valid;
+  __shared__ int num_valid;
 
   if (threadIdx.x == 0) {
-    last_valid = min(sequence_length, mask_index[blockIdx.y]);
+    num_valid = min(sequence_length, mask_index[blockIdx.y]);
   }
   __syncthreads();
 
-  Softmax(sequence_length, last_valid, input, output);
+  Softmax(sequence_length, num_valid, input, output);
 }
 
 template