diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc b/onnxruntime/contrib_ops/cpu/bert/attention.cc index 01d17144ff..cf98fd51c4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -59,7 +59,10 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, // input : (batch_size, sequence_length, hidden_size) // weights : (hidden_size, 3 * hidden_size) // bias : (3 * hidden_size) - // mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length) + // mask_index : nullptr, (batch_size), (2 * batch_size), + // or (batch_size, 1), (1, 1) + // or (batch_size, past_sequence_length + sequence_length) + // or (batch_size, sequence_length, past_sequence_length + sequence_length) // past : (2, batch_size, num_heads, past_sequence_length, head_size) const auto& dims = input_shape.GetDims(); @@ -136,8 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape, return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)"); } } + } else if (mask_dims.size() == 3) { + if (static_cast(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast(mask_dims[2]) != past_sequence_length + sequence_length) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)"); + } } else { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1 or 2 dimensions, got ", + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ", mask_dims.size()); } } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_base.h index b61fe2597b..d7fe44f077 
100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_base.h @@ -26,8 +26,8 @@ class AttentionBase { int sequence_length, int& past_sequence_length) const; - int num_heads_; // number of attention heads - bool is_unidirectional_; // whether every token can only attend to previous tokens. + int num_heads_; // number of attention heads + bool is_unidirectional_; // whether every token can only attend to previous tokens. }; } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h index 4252098dbf..c2d0d03e59 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_cpu_base.h @@ -89,7 +89,7 @@ class AttentionCPUBase : public AttentionBase { const T* K, // k data. Its size is BxNxSxH const int32_t* mask_index, // mask index. nullptr if no mask or its size is B const std::vector* mask_index_dims, // mask index shape - T* mask_data, // buffer for mask data. Its size is: SxS* if is_unidirectional_; BxSxS* if mask_index; null otherwise + T* mask_data, // buffer for mask data. It is nullptr if mask_index is nullptr, otherwise its shape is BxSxS* int batch_size, // batch size of self-attention int sequence_length, // sequence length of self-attention int past_sequence_length, // sequence length of past state diff --git a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h index 5acb8c2a30..5afe24166a 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention_helper.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention_helper.h @@ -72,12 +72,31 @@ void PrepareMask(const int32_t* mask_index, // mask_data has been filled with 0, and its shape is BxSxS* T* p_mask = mask_data; + // For 3D mask, convert values 0 to -10000.0, and 1 to 0.0, then apply unidirectional mask if any. 
+ if (nullptr != mask_index_dims && mask_index_dims->size() == 3) { + for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) { + p_mask[i] = (mask_index[i] > 0) ? static_cast(0.0f) : static_cast(-10000.0f); + } + + if (is_unidirectional) { + for (int b_i = 0; b_i < batch_size; b_i++) { + for (int s_i = 0; s_i < sequence_length - 1; s_i++) { + for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) { + p_mask[s_i * all_sequence_length + m_i] += static_cast(-10000.0f); + } + } + p_mask += sequence_length * all_sequence_length; + } + } + + return; + } + bool is_raw_attention_mask = (nullptr != mask_index_dims && mask_index_dims->size() == 2); bool has_mask_start_position = (nullptr != mask_index_dims && mask_index_dims->size() == 1 && static_cast(mask_index_dims->at(0)) == 2 * batch_size); for (int b_i = 0; b_i < batch_size; b_i++) { // TODO: mask_index can be used in softmax to save some calculation. - if (nullptr != mask_index) { if (is_raw_attention_mask) { // Raw attention mask has value 0 or 1. Here we convert 0 to -10000.0, and 1 to 0.0. @@ -120,7 +139,6 @@ void PrepareMask(const int32_t* mask_index, p_mask += sequence_length * all_sequence_length; } - } // Concatenate a past state chunk S'xH with input state chunk SxH into present state chunk S*xH diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 72634fa6d4..00f92b4f1c 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -21,13 +21,11 @@ limitations under the License. // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. 
-#include -#include #include -#include #include "core/providers/cuda/cu_inc/common.cuh" #include "core/providers/cuda/cuda_common.h" #include "attention_impl.h" +#include "attention_softmax.h" using namespace onnxruntime::cuda; using namespace cub; @@ -40,7 +38,7 @@ static size_t AlignTo(size_t a, size_t b) { return CeilDiv(a, b) * b; } -size_t ScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) { +size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) { const size_t len = batch_size * num_heads * sequence_length * all_sequence_length; const size_t bytes = len * element_size; @@ -57,580 +55,7 @@ size_t GetAttentionWorkspaceSize( int sequence_length, int past_sequence_length) { size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size; - return qkv_size + 2 * ScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); -} - -template -__device__ inline void Softmax(const int all_sequence_length, - const int sequence_length, - const int valid_end, - const int valid_start, - const T* input, - T* output) { - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - float thread_data_max(-CUDART_INF_F); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... 
+ e^(xn - max)) - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - if (thread_data_max < float(input[index])) { - thread_data_max = float(input[index]); - } - } - } - - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max()); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_sum(0.f); - for (int i = threadIdx.x; i < valid_end; i += TPB) { - if (i >= valid_start) { - const int index = offset + i; - const float val = input[index]; - thread_data_sum += expf(val - max_block); - } - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum()); - if (threadIdx.x == 0) { - sum_reverse_block = 1.f / sum; - } - __syncthreads(); - - for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { - const int index = offset + i; - const float val = (i >= valid_start && i < valid_end) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f; - output[index] = T(val); - } -} - -template -__device__ inline void SoftmaxSmall(const int all_sequence_length, - const int sequence_length, - const int valid_end, - const int valid_start, - const T* input, - T* output, - bool is_unidirectional) { - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; - const int index = offset + threadIdx.x; - - bool is_valid = false; // whether it has attention mask == 1. - - // Update end position for unidirectional. 
- int end = valid_end; - if (is_unidirectional) { - int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; - if (end_unid <= valid_start) { - // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. - // So [0, end_unid) will also have value after softmax. - is_valid = threadIdx.x < end_unid; - } else { - end = min(valid_end, end_unid); - } - } - - is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end); - - // e^x is represented as infinity if x is large enough, like 100.f. - // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. - // a math transform as below is leveraged to get a stable softmax: - // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) - float thread_data_max = is_valid ? float(input[index]) : float(-CUDART_INF_F); - const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_exp(0.f); - if (is_valid) { - thread_data_exp = expf(float(input[index]) - max_block); - } - - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end); - - // Store value of 1.0/sum. - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. 
- if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data_exp * sum_reverse_block); - } -} - -template -__device__ inline void SoftmaxWithMask2DSmall(const int all_sequence_length, - const int sequence_length, - const int* attention_mask, // 2D attention mask - const T* input, - T* output, - const bool is_unidirectional, - const float scalar) { - using BlockReduce = cub::BlockReduce; - __shared__ typename BlockReduce::TempStorage tmp_storage; - - __shared__ float sum_reverse_block; - __shared__ float max_block; - - // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; - int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; - - float thread_data = -CUDART_INF_F; - if (threadIdx.x < all_sequence_length) { - const int& mask = attention_mask[blockIdx.y * all_sequence_length + threadIdx.x]; - float mask_value = mask > 0 ? 0.0f : -10000.0f; - - if (is_unidirectional) { - int from_index = all_sequence_length - sequence_length + (blockIdx.x % sequence_length); // offset of from token in all sequence length. - if (threadIdx.x > from_index) { - mask_value += -10000.0f; - } - } - - thread_data = float(input[index]) * scalar + mask_value; - } - - const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length); - - // Store max value - if (threadIdx.x == 0) { - max_block = max; - } - __syncthreads(); - - float thread_data_exp = threadIdx.x < all_sequence_length ? 
expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length); - - // Store value of 1.0/sum - if (threadIdx.x == 0) { - sum_reverse_block = (1.f) / sum; - } - __syncthreads(); - - if (threadIdx.x < all_sequence_length) { - output[index] = T(thread_data_exp * sum_reverse_block); - } -} - -template -__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) { - SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, input, output, is_unidirectional); -} - -template -__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* input, T* output) { - Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, input, output); -} - -template -bool ComputeSoftmax( - cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const T* input, T* output, bool is_unidirectional) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); - } else if 
(all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); - } else if (!is_unidirectional) { - const int blockSize = 1024; - SoftmaxKernel<<>>(all_sequence_length, sequence_length, input, output); - } else { - ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); - } - - return CUDA_CALL(cudaPeekAtLastError()); -} - -template -__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output, bool is_unidirectional) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. - if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, input, output, is_unidirectional); -} - -template -__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output) { - __shared__ int start_position; - __shared__ int end_position; - - if (threadIdx.x == 0) { - const int batch = blockIdx.y; - start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; - end_position = min(all_sequence_length, mask_end[batch]); - - // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
- if (start_position >= end_position) { - start_position = 0; - end_position = all_sequence_length; - } - } - __syncthreads(); - - Softmax(all_sequence_length, sequence_length, end_position, start_position, input, output); -} - -template -__global__ void SoftmaxWithMask2DSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar) { - SoftmaxWithMask2DSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); -} - -template -bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* mask_index, const int* mask_start, const T* input, T* output, const bool is_unidirectional) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - if (all_sequence_length <= 32) { - const int blockSize = 32; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - 
MaskedSoftmaxKernelSmall - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); - } else if (!is_unidirectional) { - const int blockSize = 1024; - MaskedSoftmaxKernel - <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output); - } else { - ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); - } - - return CUDA_CALL(cudaPeekAtLastError()); -} - -template -bool ComputeSoftmaxWithMask2D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, - const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar) { - const dim3 grid(sequence_length * num_heads, batch_size, 1); - - if (all_sequence_length <= 32) { - const int blockSize = 32; - SoftmaxWithMask2DSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); - } else if (all_sequence_length <= 64) { - const int blockSize = 64; - SoftmaxWithMask2DSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); - } else if (all_sequence_length <= 128) { - const int blockSize = 128; - SoftmaxWithMask2DSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); - } else if (all_sequence_length <= 256) { - const int blockSize = 256; - SoftmaxWithMask2DSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); - } else if (all_sequence_length <= 512) { - const int blockSize = 512; - SoftmaxWithMask2DSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar); - } else if (all_sequence_length <= 1024) { - const int blockSize = 1024; - SoftmaxWithMask2DSmallKernel - <<>>(all_sequence_length, sequence_length, attention_mask, 
input, output, is_unidirectional, scalar); - } else { - ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024."); - } - - return CUDA_CALL(cudaPeekAtLastError()); -} - -template -__global__ void TransposeCtx(const int H, const T* input, T* output) { - // Input: BxNxSxH - // Output: BxSxNxH - - int n = threadIdx.y; - int s = blockIdx.x; - int b = blockIdx.y; - - int num_heads = blockDim.y; - int sequence_length = gridDim.x; - - const int NH = num_heads * H; - const int NHS = NH * sequence_length; - const int in_offset = s * H + n * sequence_length * H + b * NHS; - const int out_offset = n * H + s * NH + b * NHS; - - const int i = threadIdx.x; - if (i < H) { - output[out_offset + i] = input[in_offset + i]; - } -} - -bool LaunchTransCtx(cudaStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const float* input, float* output) { - const dim3 grid(sequence_length, batch_size, 1); - if (0 == (head_size & 1)) { - const int H = head_size / 2; - const float2* input2 = reinterpret_cast(input); - float2* output2 = reinterpret_cast(output); - const dim3 block(H, num_heads, 1); - TransposeCtx<<>>(H, input2, output2); - } else { - const dim3 block(head_size, num_heads, 1); - TransposeCtx<<>>(head_size, input, output); - } - return CUDA_CALL(cudaPeekAtLastError()); -} - -bool LaunchTransCtx(cudaStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const half* input, half* output) { - const dim3 grid(sequence_length, batch_size, 1); - if (0 == (head_size % 4)) { - const int H = head_size / 4; - const dim3 block(H, num_heads, 1); - const float2* input2 = reinterpret_cast(input); - float2* output2 = reinterpret_cast(output); - TransposeCtx<<>>(H, input2, output2); - } else if (0 == (head_size & 1)) { - const int H = head_size / 2; - const dim3 block(H, num_heads, 1); - const half2* input2 = reinterpret_cast(input); 
- half2* output2 = reinterpret_cast(output); - TransposeCtx<<>>(H, input2, output2); - } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. - const dim3 block(head_size, num_heads, 1); - TransposeCtx<<>>(head_size, input, output); - } - - return CUDA_CALL(cudaPeekAtLastError()); -} - -template -__global__ void TransposeQKV(const int H, const T* input, T* output) { - // Input: BxSx3xNxH - // Output: 3xBxNxSxH - - int n = threadIdx.y; - int s = blockIdx.x; - int b = blockIdx.y; - int m = blockIdx.z; // matrix id - - const int num_heads = blockDim.y; - - const int sequence_length = gridDim.x; - const int batch_size = gridDim.y; - const int NH = num_heads * H; - const int NHS = NH * sequence_length; - const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3; - const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size; - - const int i = threadIdx.x; - if (i < H) { - output[out_offset + i] = input[in_offset + i]; - } -} - -bool LaunchTransQkv(cudaStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const float* input, float* output) { - const dim3 grid(sequence_length, batch_size, 3); - if (0 == (head_size & 1)) { - const int H = head_size / 2; - const float2* input2 = reinterpret_cast(input); - float2* output2 = reinterpret_cast(output); - const dim3 block(H, num_heads, 1); - TransposeQKV<<>>(H, input2, output2); - } else { - const dim3 block(head_size, num_heads, 1); - TransposeQKV<<>>(head_size, input, output); - } - return CUDA_CALL(cudaPeekAtLastError()); -} - -bool LaunchTransQkv(cudaStream_t stream, - const int sequence_length, const int batch_size, const int head_size, const int num_heads, - const half* input, half* output) { - const dim3 grid(sequence_length, batch_size, 3); - if (0 == (head_size % 4)) { - const int H = head_size / 4; - const dim3 block(H, num_heads, 1); - const float2* input2 = reinterpret_cast(input); - 
float2* output2 = reinterpret_cast(output); - TransposeQKV<<>>(H, input2, output2); - } else if (0 == (head_size & 1)) { - const int H = head_size / 2; - const dim3 block(H, num_heads, 1); - const half2* input2 = reinterpret_cast(input); - half2* output2 = reinterpret_cast(output); - TransposeQKV<<>>(H, input2, output2); - } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.. - const dim3 block(head_size, num_heads, 1); - TransposeQKV<<>>(head_size, input, output); - } - return CUDA_CALL(cudaPeekAtLastError()); -} - -template -__global__ void ConcatPastToPresent(const int sequence_length, - const T* past, - const T* k_v, - T* present) { - const int h = threadIdx.x; - const int n = threadIdx.y; - const int s = blockIdx.x; - const int b = blockIdx.y; - const int is_v = blockIdx.z; // 0 for k, 1 for v - - const int all_sequence_length = gridDim.x; - const int batch_size = gridDim.y; - const int num_heads = blockDim.y; - const int H = blockDim.x; - - // past: 2 x BxNxS'xH (past_k and past_v) - // k_v: 2 x BxNxSxH (k and v) - // present: 2 x BxNxS*xH (present_k and present_v) - const int past_sequence_length = all_sequence_length - sequence_length; - - const int present_SH = all_sequence_length * H; - const int present_NSH = num_heads * present_SH; - int out_offset = b * present_NSH + n * present_SH + s * H + h + is_v * (present_NSH * batch_size); - if (s < past_sequence_length) { - const int past_SH = past_sequence_length * H; - const int past_NSH = num_heads * past_SH; - const int in_offset = b * past_NSH + n * past_SH + s * H + h + is_v * (past_NSH * batch_size); - present[out_offset] = past[in_offset]; - } else if (s < all_sequence_length) { - const int SH = sequence_length * H; - const int NSH = num_heads * SH; - const int in_offset = b * NSH + n * SH + (s - past_sequence_length) * H + h + is_v * (NSH * batch_size); - present[out_offset] = k_v[in_offset]; - } -} - -bool LaunchConcatPastToPresent(cudaStream_t stream, - const 
int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const float* past, - const float* k_v, - float* present) { - const dim3 grid(all_sequence_length, batch_size, 2); - if (0 == (head_size & 1)) { - const dim3 block(head_size / 2, num_heads, 1); - ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); - } else { - const dim3 block(head_size, num_heads, 1); - ConcatPastToPresent<<>>(sequence_length, past, k_v, present); - } - return CUDA_CALL(cudaPeekAtLastError()); -} - -bool LaunchConcatPastToPresent(cudaStream_t stream, - const int all_sequence_length, - const int sequence_length, - const int batch_size, - const int head_size, - const int num_heads, - const half* past, - const half* k_v, - half* present) { - const dim3 grid(all_sequence_length, batch_size, 2); - if (0 == (head_size % 4)) { - const dim3 block(head_size / 4, num_heads, 1); - ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); - } else if (0 == (head_size & 1)) { - const dim3 block(head_size / 2, num_heads, 1); - ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); - } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. 
- const dim3 block(head_size, num_heads, 1); - ConcatPastToPresent<<>>(sequence_length, past, k_v, present); - } - return CUDA_CALL(cudaPeekAtLastError()); -} - -cublasStatus_t inline CublasGemmStridedBatched( - cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const float alpha, - const float* A, int lda, long long int strideA, const float* B, int ldb, long long int strideB, - const float beta, float* C, int ldc, long long int strideC, int batchCount) { - return cublasSgemmStridedBatched( - handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount); -} - -cublasStatus_t inline CublasGemmStridedBatched( - cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, - int m, int n, int k, const half alpha, - const half* A, int lda, long long int strideA, const half* B, int ldb, long long int strideB, - const half beta, half* C, int ldc, long long int strideC, int batchCount) { - return cublasHgemmStridedBatched( - handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount); + return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length); } template @@ -641,7 +66,7 @@ bool QkvToContext( const int* mask_index, const std::vector* mask_index_dims, bool is_unidirectional, int past_sequence_length, const T* past, T* present) { const int all_sequence_length = past_sequence_length + sequence_length; - const size_t bytes = ScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); + const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length); T* scratch1 = workspace; T* scratch2 = scratch1 + (bytes / element_size); T* scratch3 = scratch2 + (bytes / element_size); @@ -677,13 +102,15 @@ bool QkvToContext( v = present + batches * 
present_size_per_batch; } - bool use_2d_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() == 2); + // Raw attention mask could be 2D (BxS*) or 3D (BxSxS*) + bool use_raw_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() >= 2); // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS* // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS* const float rsqrt_head_size = 1.f / sqrt(static_cast(head_size)); const int temp_matrix_size = sequence_length * all_sequence_length; - T alpha = (T)(use_2d_attention_mask ? 1.0f : rsqrt_head_size); + // For raw attention mask, the scaling by 1/sqrt(H) is moved to softmax computation. + T alpha = (T)(use_raw_attention_mask ? 1.0f : rsqrt_head_size); if (!CUBLAS_CALL(CublasGemmStridedBatched( cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, alpha, k, head_size, present_size_per_batch, q, head_size, size_per_batch, 0.f, scratch1, all_sequence_length, temp_matrix_size, batches))) { @@ -691,8 +118,8 @@ bool QkvToContext( } // apply softmax and store result P to scratch2: BxNxSxS* - if (use_2d_attention_mask) { // 2d attention mask - if (!ComputeSoftmaxWithMask2D(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size)) { + if (use_raw_attention_mask) { // 2d or 3d attention mask + if (!ComputeSoftmaxWithRawMask(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size, static_cast(mask_index_dims->size()))) { return false; } } else if (nullptr != mask_index) { // 1d mask index diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h index 0ba287ee3c..c51c007290 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.h +++ 
b/onnxruntime/contrib_ops/cuda/bert/attention_impl.h @@ -3,10 +3,13 @@ #pragma once #include "core/providers/cuda/shared_inc/cuda_utils.h" +#include namespace onnxruntime { namespace contrib { namespace cuda { +size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length); + size_t GetAttentionWorkspaceSize( size_t element_size, int batchsize, @@ -34,6 +37,60 @@ bool LaunchAttentionKernel( void* present // Present state output ); +cublasStatus_t inline CublasGemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const float alpha, + const float* A, int lda, long long int strideA, const float* B, int ldb, long long int strideB, + const float beta, float* C, int ldc, long long int strideC, int batchCount) { + return cublasSgemmStridedBatched( + handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount); +} + +cublasStatus_t inline CublasGemmStridedBatched( + cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb, + int m, int n, int k, const half alpha, + const half* A, int lda, long long int strideA, const half* B, int ldb, long long int strideB, + const half beta, half* C, int ldc, long long int strideC, int batchCount) { + return cublasHgemmStridedBatched( + handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount); +} + +bool LaunchTransCtx(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const float* input, float* output); + +bool LaunchTransCtx(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const half* input, half* output); + +bool LaunchTransQkv(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const float* 
input, float* output); + +bool LaunchTransQkv(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const half* input, half* output); + +bool LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const float* past, + const float* k_v, + float* present); + +bool LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const half* past, + const half* k_v, + half* present); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_past.cu b/onnxruntime/contrib_ops/cuda/bert/attention_past.cu new file mode 100644 index 0000000000..d3811f1dfe --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_past.cu @@ -0,0 +1,95 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "core/providers/cuda/cuda_common.h" +#include "attention_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void ConcatPastToPresent(const int sequence_length, + const T* past, + const T* k_v, + T* present) { + const int h = threadIdx.x; + const int n = threadIdx.y; + const int s = blockIdx.x; + const int b = blockIdx.y; + const int is_v = blockIdx.z; // 0 for k, 1 for v + + const int all_sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int num_heads = blockDim.y; + const int H = blockDim.x; + + // past: 2 x BxNxS'xH (past_k and past_v) + // k_v: 2 x BxNxSxH (k and v) + // present: 2 x BxNxS*xH (present_k and present_v) + const int past_sequence_length = all_sequence_length - sequence_length; + + const int present_SH = all_sequence_length * H; + const int present_NSH = num_heads * present_SH; + int out_offset = b * present_NSH + n * present_SH + s * H + h + is_v * (present_NSH * batch_size); + if (s < past_sequence_length) { + const int past_SH = past_sequence_length * H; + const int past_NSH = num_heads * past_SH; + const int in_offset = b * past_NSH + n * past_SH + s * H + h + is_v * (past_NSH * batch_size); + present[out_offset] = past[in_offset]; + } else if (s < all_sequence_length) { + const int SH = sequence_length * H; + const int NSH = num_heads * SH; + const int in_offset = b * NSH + n * SH + (s - past_sequence_length) * H + h + is_v * (NSH * batch_size); + present[out_offset] = k_v[in_offset]; + } +} + +bool LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const float* past, + const float* k_v, + float* present) { + const dim3 grid(all_sequence_length, batch_size, 2); + if (0 == (head_size & 1)) { + const dim3 block(head_size / 2, num_heads, 1); + ConcatPastToPresent<<>>(sequence_length, 
reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); + } else { + const dim3 block(head_size, num_heads, 1); + ConcatPastToPresent<<>>(sequence_length, past, k_v, present); + } + return CUDA_CALL(cudaPeekAtLastError()); +} + +bool LaunchConcatPastToPresent(cudaStream_t stream, + const int all_sequence_length, + const int sequence_length, + const int batch_size, + const int head_size, + const int num_heads, + const half* past, + const half* k_v, + half* present) { + const dim3 grid(all_sequence_length, batch_size, 2); + if (0 == (head_size % 4)) { + const dim3 block(head_size / 4, num_heads, 1); + ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); + } else if (0 == (head_size & 1)) { + const dim3 block(head_size / 2, num_heads, 1); + ConcatPastToPresent<<>>(sequence_length, reinterpret_cast(past), reinterpret_cast(k_v), reinterpret_cast(present)); + } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. + const dim3 block(head_size, num_heads, 1); + ConcatPastToPresent<<>>(sequence_length, past, k_v, present); + } + return CUDA_CALL(cudaPeekAtLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h new file mode 100644 index 0000000000..e24d284559 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_softmax.h @@ -0,0 +1,389 @@ +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +#pragma once + +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__device__ inline void Softmax(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* input, + T* output) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + float thread_data_max(-CUDART_INF_F); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... 
+ e^(xn - max)) + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + if (thread_data_max < float(input[index])) { + thread_data_max = float(input[index]); + } + } + } + + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max()); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_sum(0.f); + for (int i = threadIdx.x; i < valid_end; i += TPB) { + if (i >= valid_start) { + const int index = offset + i; + const float val = input[index]; + thread_data_sum += expf(val - max_block); + } + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum()); + if (threadIdx.x == 0) { + sum_reverse_block = 1.f / sum; + } + __syncthreads(); + + for (int i = threadIdx.x; i < all_sequence_length; i += TPB) { + const int index = offset + i; + const float val = (i >= valid_start && i < valid_end) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f; + output[index] = T(val); + } +} + +template +__device__ inline void SoftmaxSmall(const int all_sequence_length, + const int sequence_length, + const int valid_end, + const int valid_start, + const T* input, + T* output, + bool is_unidirectional) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length; + const int index = offset + threadIdx.x; + + bool is_valid = false; // whether it has attention mask == 1. + + // Update end position for unidirectional. 
+ int end = valid_end; + if (is_unidirectional) { + int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1; + if (end_unid <= valid_start) { + // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000. + // So [0, end_unid) will also have value after softmax. + is_valid = threadIdx.x < end_unid; + } else { + end = min(valid_end, end_unid); + } + } + + is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end); + + // e^x is represented as infinity if x is large enough, like 100.f. + // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough. + // a math transform as below is leveraged to get a stable softmax: + // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max)) + float thread_data_max = is_valid ? float(input[index]) : float(-CUDART_INF_F); + const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp(0.f); + if (is_valid) { + thread_data_exp = expf(float(input[index]) - max_block); + } + + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end); + + // Store value of 1.0/sum. + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + // threadIdx.x might be larger than all_sequence_length due to alignment to 32x. 
+ if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } +} + +template +__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, + const int sequence_length, + const int* attention_mask, // 2D or 3D attention mask + const T* input, + T* output, + const bool is_unidirectional, + const float scalar, + const int mask_dimension) { + using BlockReduce = cub::BlockReduce; + __shared__ typename BlockReduce::TempStorage tmp_storage; + + __shared__ float sum_reverse_block; + __shared__ float max_block; + + // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; + int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + + float thread_data = -CUDART_INF_F; + if (threadIdx.x < all_sequence_length) { + const int batch_index = blockIdx.y; + const int sequence_index = blockIdx.x % sequence_length; + const int mask_offset = (mask_dimension == 2) ? batch_index * all_sequence_length + threadIdx.x : batch_index * sequence_length * all_sequence_length + sequence_index * all_sequence_length + threadIdx.x; + + const int& mask = attention_mask[mask_offset]; + float mask_value = mask > 0 ? 0.0f : -10000.0f; + + if (is_unidirectional) { + int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length. + if (threadIdx.x > from_index) { + mask_value += -10000.0f; + } + } + + thread_data = float(input[index]) * scalar + mask_value; + } + + const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length); + + // Store max value + if (threadIdx.x == 0) { + max_block = max; + } + __syncthreads(); + + float thread_data_exp = threadIdx.x < all_sequence_length ? 
expf(thread_data - max_block) : 0.0f; + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length); + + // Store value of 1.0/sum + if (threadIdx.x == 0) { + sum_reverse_block = (1.f) / sum; + } + __syncthreads(); + + if (threadIdx.x < all_sequence_length) { + output[index] = T(thread_data_exp * sum_reverse_block); + } +} + +template +__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) { + SoftmaxSmall(all_sequence_length, sequence_length, all_sequence_length, 0, input, output, is_unidirectional); +} + +template +__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* input, T* output) { + Softmax(all_sequence_length, sequence_length, all_sequence_length, 0, input, output); +} + +template +bool ComputeSoftmax( + cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const T* input, T* output, bool is_unidirectional) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + if (all_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if 
(all_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxKernelSmall<<>>(all_sequence_length, sequence_length, input, output, is_unidirectional); + } else if (!is_unidirectional) { + const int blockSize = 1024; + SoftmaxKernel<<>>(all_sequence_length, sequence_length, input, output); + } else { + ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + +template +__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output, bool is_unidirectional) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. + if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + SoftmaxSmall(all_sequence_length, sequence_length, end_position, start_position, input, output, is_unidirectional); +} + +template +__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output) { + __shared__ int start_position; + __shared__ int end_position; + + if (threadIdx.x == 0) { + const int batch = blockIdx.y; + start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0; + end_position = min(all_sequence_length, mask_end[batch]); + + // Attend to no word has same effect as attend to all words. This is added to get parity with CPU result. 
+ if (start_position >= end_position) { + start_position = 0; + end_position = all_sequence_length; + } + } + __syncthreads(); + + Softmax(all_sequence_length, sequence_length, end_position, start_position, input, output); +} + +template +__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar, const int mask_dimension) { + SoftmaxWithRawMaskSmall(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); +} + +template +bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* mask_index, const int* mask_start, const T* input, T* output, const bool is_unidirectional) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + if (all_sequence_length <= 32) { + const int blockSize = 32; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (all_sequence_length <= 1024) 
{ + const int blockSize = 1024; + MaskedSoftmaxKernelSmall + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional); + } else if (!is_unidirectional) { + const int blockSize = 1024; + MaskedSoftmaxKernel + <<>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output); + } else { + ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024."); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + +template +bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads, + const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar, + const int mask_dimension) { + const dim3 grid(sequence_length * num_heads, batch_size, 1); + + if (all_sequence_length <= 32) { + const int blockSize = 32; + SoftmaxWithRawMaskSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); + } else if (all_sequence_length <= 64) { + const int blockSize = 64; + SoftmaxWithRawMaskSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); + } else if (all_sequence_length <= 128) { + const int blockSize = 128; + SoftmaxWithRawMaskSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); + } else if (all_sequence_length <= 256) { + const int blockSize = 256; + SoftmaxWithRawMaskSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); + } else if (all_sequence_length <= 512) { + const int blockSize = 512; + SoftmaxWithRawMaskSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); + } else if 
(all_sequence_length <= 1024) { + const int blockSize = 1024; + SoftmaxWithRawMaskSmallKernel + <<>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension); + } else { + ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024."); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu new file mode 100644 index 0000000000..d6a760362c --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu @@ -0,0 +1,160 @@ +/* + The implementation of this file is based on qkvToContext plugin in TensorRT demo: + https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/ + +Copyright 2019 NVIDIA Corporation + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+*/ + +#include "core/providers/cuda/cuda_common.h" +#include "attention_impl.h" + +using namespace onnxruntime::cuda; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void TransposeCtx(const int H, const T* input, T* output) { + // Input: BxNxSxH + // Output: BxSxNxH + + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + + int num_heads = blockDim.y; + int sequence_length = gridDim.x; + + const int NH = num_heads * H; + const int NHS = NH * sequence_length; + const int in_offset = s * H + n * sequence_length * H + b * NHS; + const int out_offset = n * H + s * NH + b * NHS; + + const int i = threadIdx.x; + if (i < H) { + output[out_offset + i] = input[in_offset + i]; + } +} + +bool LaunchTransCtx(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const float* input, float* output) { + const dim3 grid(sequence_length, batch_size, 1); + if (0 == (head_size & 1)) { + const int H = head_size / 2; + const float2* input2 = reinterpret_cast(input); + float2* output2 = reinterpret_cast(output); + const dim3 block(H, num_heads, 1); + TransposeCtx<<>>(H, input2, output2); + } else { + const dim3 block(head_size, num_heads, 1); + TransposeCtx<<>>(head_size, input, output); + } + return CUDA_CALL(cudaPeekAtLastError()); +} + +bool LaunchTransCtx(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const half* input, half* output) { + const dim3 grid(sequence_length, batch_size, 1); + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const dim3 block(H, num_heads, 1); + const float2* input2 = reinterpret_cast(input); + float2* output2 = reinterpret_cast(output); + TransposeCtx<<>>(H, input2, output2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const dim3 block(H, num_heads, 1); + const half2* input2 = reinterpret_cast(input); + half2* output2 = 
reinterpret_cast(output); + TransposeCtx<<>>(H, input2, output2); + } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel. + const dim3 block(head_size, num_heads, 1); + TransposeCtx<<>>(head_size, input, output); + } + + return CUDA_CALL(cudaPeekAtLastError()); +} + +template +__global__ void TransposeQKV(const int H, const T* input, T* output) { + // Input: BxSx3xNxH + // Output: 3xBxNxSxH + + int n = threadIdx.y; + int s = blockIdx.x; + int b = blockIdx.y; + int m = blockIdx.z; // matrix id + + const int num_heads = blockDim.y; + + const int sequence_length = gridDim.x; + const int batch_size = gridDim.y; + const int NH = num_heads * H; + const int NHS = NH * sequence_length; + const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3; + const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size; + + const int i = threadIdx.x; + if (i < H) { + output[out_offset + i] = input[in_offset + i]; + } +} + +bool LaunchTransQkv(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const float* input, float* output) { + const dim3 grid(sequence_length, batch_size, 3); + if (0 == (head_size & 1)) { + const int H = head_size / 2; + const float2* input2 = reinterpret_cast(input); + float2* output2 = reinterpret_cast(output); + const dim3 block(H, num_heads, 1); + TransposeQKV<<>>(H, input2, output2); + } else { + const dim3 block(head_size, num_heads, 1); + TransposeQKV<<>>(head_size, input, output); + } + return CUDA_CALL(cudaPeekAtLastError()); +} + +bool LaunchTransQkv(cudaStream_t stream, + const int sequence_length, const int batch_size, const int head_size, const int num_heads, + const half* input, half* output) { + const dim3 grid(sequence_length, batch_size, 3); + if (0 == (head_size % 4)) { + const int H = head_size / 4; + const dim3 block(H, num_heads, 1); + const float2* input2 = reinterpret_cast(input); + float2* output2 = 
reinterpret_cast(output); + TransposeQKV<<>>(H, input2, output2); + } else if (0 == (head_size & 1)) { + const int H = head_size / 2; + const dim3 block(H, num_heads, 1); + const half2* input2 = reinterpret_cast(input); + half2* output2 = reinterpret_cast(output); + TransposeQKV<<>>(H, input2, output2); + } else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.. + const dim3 block(head_size, num_heads, 1); + TransposeQKV<<>>(head_size, input, output); + } + return CUDA_CALL(cudaPeekAtLastError()); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 084e145f5f..5d18fd6892 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -277,7 +277,8 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { void RegisterBertSchemas() { static const char* Attention_ver1_doc = R"DOC( Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT). -The mask_index input is optional. Besides raw attention mask with shape (batch_size, past_sequence_length + sequence_length), +The mask_index input is optional. Besides raw attention mask with shape (batch_size, past_sequence_length + sequence_length) +or (batch_size, sequence_length, past_sequence_length + sequence_length) with value 0 for masked and 1 otherwise, we also support other two formats: When input has right-side padding, mask_index is one dimension with shape (batch_size), where value of each element is the end position, or valid length of actual sequence excluding padding. When input has left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by @@ -297,7 +298,7 @@ and present state are optional. 
Present state could appear in output even when p .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T") .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T") .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T") - .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional) + .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional) .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional) .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T") .Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional) diff --git a/onnxruntime/test/contrib_ops/attention_op_test.cc b/onnxruntime/test/contrib_ops/attention_op_test.cc index 2728a8ad5f..6abe23599b 100644 --- a/onnxruntime/test/contrib_ops/attention_op_test.cc +++ b/onnxruntime/test/contrib_ops/attention_op_test.cc @@ -13,12 +13,13 @@ enum MaskIndexType { kMaskIndexEnd = 0, kMaskIndexEndAndStart, kMaskRaw, + kMask3D, kMaskDummy // Dummy mask with shape [1, 1] or [batch_size, 1] }; static void RunAttentionTest( - const std::vector& input_data, // input: [batch_size, sequence_length, hidden_size] - const std::vector& weights_data, // weights: [hidden_size, 3 * hidden_size] + const std::vector& input_data, // input: [batch_size, sequence_length, hidden_size] + const std::vector& weights_data, // weights: [hidden_size, 3 * hidden_size] bool 
is_weights_constant, const std::vector& bias_data, // bias: [3 * hidden_size] const std::vector& mask_index_data, // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty @@ -52,6 +53,7 @@ static void RunAttentionTest( std::vector mask_index_dims_2 = {2 * batch_size}; std::vector mask_index_dims_3 = {batch_size, past_sequence_length + sequence_length}; std::vector mask_index_dims_4 = {batch_size, 1}; + std::vector mask_index_dims_5 = {batch_size, sequence_length, past_sequence_length + sequence_length}; std::vector mask_index_dims; switch (mask_index_type) { case kMaskIndexEnd: @@ -66,8 +68,11 @@ static void RunAttentionTest( case kMaskDummy: mask_index_dims = mask_index_dims_4; break; + case kMask3D: + mask_index_dims = mask_index_dims_5; + break; default: - assert(0); // shall not reach here. + assert(0); // shall not reach here. break; } @@ -973,6 +978,51 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart); } +TEST(AttentionTest, Attention3DMask) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test 3D mask BxSxS* + std::vector mask_index_data = { + 0, 1, + 0, 1, + 1, 1, + 1, 1}; + + std::vector output_data = { + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 
3.14959716796875f, 0.10843672603368759f, 4.25f, 5.65f, + 3.9696791172027588f, 0.073143675923347473f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D); +} + TEST(AttentionTest, AttentionBatch2AttentionMask) { int batch_size = 2; int sequence_length = 2; @@ -1014,6 +1064,51 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); } +TEST(AttentionTest, AttentionUnidirectional3DMask) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test 3D mask BxSxS* + std::vector mask_index_data = { + 0, 1, + 0, 1, + 1, 1, + 1, 1}; + + std::vector output_data = { + 3.967245340f, 0.07324841f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 8.69f, -0.13f, 4.25f, 5.65f, + 3.96967912f, 0.07314367f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = true; + bool use_past_state = false; + int past_sequence_length = 0; + 
const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D); +} + TEST(AttentionTest, AttentionUnidirectionalAttentionMask) { int batch_size = 2; int sequence_length = 2; @@ -1181,6 +1276,47 @@ TEST(AttentionTest, AttentionMask2DNoWord) { use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw); } +TEST(AttentionTest, AttentionMask3DNoWord) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test that all attention masks are zero. 
+ std::vector mask_index_data = {0, 0, 0, 0, 0, 0, 0, 0}; + + std::vector output_data = { + 3.96724534f, 0.07324841f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.14984703f, 0.10842596f, 4.25f, 5.65f, + 3.96724534f, 0.07324841f, 4.25f, 5.65f}; + + bool use_float16 = false; + bool is_unidirectional = false; + bool use_past_state = false; + int past_sequence_length = 0; + const std::vector* past_data = nullptr; + const std::vector* present_data = nullptr; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, + use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D); +} TEST(AttentionTest, AttentionDummyMask2D) { int batch_size = 2; @@ -1294,4 +1430,4 @@ TEST(AttentionTest, AttentionPastState_dynamic) { } } // namespace test -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime