Support 3D attention mask (#5887)
Support 3D attention mask with shape (batch_size, sequence_length, all_sequence_length)
This commit is contained in:
parent cc6e8fb7cc
commit 910bbfe1ef

11 changed files with 886 additions and 596 deletions
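For orientation (not part of the commit): with the new 3D shape, each batch carries one 0/1 row per query position over all key positions (past plus current). A minimal illustration:

// Illustrative only (not part of this commit): a 3D mask for batch_size=1,
// sequence_length=2, past_sequence_length=1, so all_sequence_length=3.
// mask[b][s][m] == 1 means query position s may attend to key position m.
#include <cstdio>

int main() {
  int mask[1][2][3] = {{{1, 1, 0},    // first new token: sees the past token and itself
                        {1, 1, 1}}};  // second new token: sees everything
  std::printf("batch 0, query 1 may attend to key 2: %d\n", mask[0][1][2]);
  return 0;
}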
@ -59,7 +59,10 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
  // input       : (batch_size, sequence_length, hidden_size)
  // weights     : (hidden_size, 3 * hidden_size)
  // bias        : (3 * hidden_size)
  // mask_index  : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
  // mask_index  : nullptr, (batch_size), (2 * batch_size),
  //               or (batch_size, 1), (1, 1)
  //               or (batch_size, past_sequence_length + sequence_length)
  //               or (batch_size, sequence_length, past_sequence_length + sequence_length)
  // past        : (2, batch_size, num_heads, past_sequence_length, head_size)

  const auto& dims = input_shape.GetDims();

@ -136,8 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
          return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)");
        }
      }
    } else if (mask_dims.size() == 3) {
      if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
        return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' of 3d shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
      }
    } else {
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1 or 2 dimensions, got ",
      return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ",
                             mask_dims.size());
    }
  }
@ -26,8 +26,8 @@ class AttentionBase {
                       int sequence_length,
                       int& past_sequence_length) const;

  int num_heads_;           // number of attention heads
  bool is_unidirectional_;  // whether every token can only attend to previous tokens.
};

}  // namespace contrib
@ -89,7 +89,7 @@ class AttentionCPUBase : public AttentionBase {
                                  const T* K,                   // k data. Its size is BxNxSxH
                                  const int32_t* mask_index,    // mask index. nullptr if no mask or its size is B
                                  const std::vector<int64_t>* mask_index_dims,  // mask index shape
                                  T* mask_data,                 // buffer for mask data. Its size is: SxS* if is_unidirectional_; BxSxS* if mask_index; null otherwise
                                  T* mask_data,                 // buffer for mask data. It is nullptr if mask_index is nullptr, otherwise its shape is BxSxS*
                                  int batch_size,               // batch size of self-attention
                                  int sequence_length,          // sequence length of self-attention
                                  int past_sequence_length,     // sequence length of past state
@ -72,12 +72,31 @@ void PrepareMask(const int32_t* mask_index,
  // mask_data has been filled with 0, and its shape is BxSxS*
  T* p_mask = mask_data;

  // For 3D mask, convert values 0 to -10000.0, and 1 to 0.0, then apply unidirectional mask if any.
  if (nullptr != mask_index_dims && mask_index_dims->size() == 3) {
    for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
      p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(-10000.0f);
    }

    if (is_unidirectional) {
      for (int b_i = 0; b_i < batch_size; b_i++) {
        for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
          for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
            p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
          }
        }
        p_mask += sequence_length * all_sequence_length;
      }
    }

    return;
  }

  bool is_raw_attention_mask = (nullptr != mask_index_dims && mask_index_dims->size() == 2);
  bool has_mask_start_position = (nullptr != mask_index_dims && mask_index_dims->size() == 1 && static_cast<int>(mask_index_dims->at(0)) == 2 * batch_size);

  for (int b_i = 0; b_i < batch_size; b_i++) {
    // TODO: mask_index can be used in softmax to save some calculation.

    if (nullptr != mask_index) {
      if (is_raw_attention_mask) {
        // Raw attention mask has value 0 or 1. Here we convert 0 to -10000.0, and 1 to 0.0.

@ -120,7 +139,6 @@ void PrepareMask(const int32_t* mask_index,

    p_mask += sequence_length * all_sequence_length;
  }

}

// Concatenate a past state chunk S'xH with input state chunk SxH into present state chunk S*xH
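The CPU path above turns a 0/1 raw 3D mask into an additive bias before it is added to the attention scores. A simplified standalone sketch of the same conversion (illustrative; the function name is made up):

// Simplified host-side sketch (illustrative, not the operator code): convert a
// 0/1 raw 3D mask into the additive bias that PrepareMask writes into mask_data.
#include <cstddef>
#include <vector>

std::vector<float> MakeAdditiveMask(const std::vector<int>& mask,  // B*S*S* entries of 0/1
                                    int batch_size, int sequence_length,
                                    int past_sequence_length, bool is_unidirectional) {
  const int all_sequence_length = past_sequence_length + sequence_length;
  std::vector<float> bias(mask.size());
  for (std::size_t i = 0; i < mask.size(); ++i) {
    bias[i] = mask[i] > 0 ? 0.0f : -10000.0f;  // allowed -> 0, blocked -> large negative
  }
  if (is_unidirectional) {
    for (int b = 0; b < batch_size; ++b) {
      for (int s = 0; s < sequence_length; ++s) {
        for (int m = past_sequence_length + s + 1; m < all_sequence_length; ++m) {
          // query position s must not see key positions after its own position
          bias[(b * sequence_length + s) * all_sequence_length + m] += -10000.0f;
        }
      }
    }
  }
  return bias;
}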
@ -21,13 +21,11 @@ limitations under the License.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "attention_impl.h"
#include "attention_softmax.h"

using namespace onnxruntime::cuda;
using namespace cub;

@ -40,7 +38,7 @@ static size_t AlignTo(size_t a, size_t b) {
  return CeilDiv(a, b) * b;
}

size_t ScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
  const size_t len = batch_size * num_heads * sequence_length * all_sequence_length;
  const size_t bytes = len * element_size;
@ -57,580 +55,7 @@ size_t GetAttentionWorkspaceSize(
    int sequence_length,
    int past_sequence_length) {
  size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
  return qkv_size + 2 * ScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
  return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
}
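For reference (illustrative arithmetic only, constants made up), the workspace above is the fused QKV buffer plus two BxNxSxS* scratch buffers:

// Illustrative arithmetic only: workspace = QKV buffer + two scratch buffers,
// matching the GetAttentionWorkspaceSize body above.
#include <cstddef>
#include <cstdio>

int main() {
  const std::size_t element_size = 2;  // fp16
  const std::size_t batch_size = 1, num_heads = 12, head_size = 64;
  const std::size_t sequence_length = 8, past_sequence_length = 24;
  const std::size_t all_sequence_length = past_sequence_length + sequence_length;

  const std::size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
  const std::size_t scratch = batch_size * num_heads * sequence_length * all_sequence_length * element_size;
  std::printf("qkv=%zu bytes, scratch=%zu bytes, total=%zu bytes\n",
              qkv_size, scratch, qkv_size + 2 * scratch);
  return 0;
}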
[Removed from attention_impl.cu in this hunk: the inline Softmax, SoftmaxSmall and SoftmaxWithMask2DSmall device functions, the SoftmaxKernelSmall, SoftmaxKernel, MaskedSoftmaxKernelSmall, MaskedSoftmaxKernel and SoftmaxWithMask2DSmallKernel kernels, and the ComputeSoftmax, ComputeSoftmaxWithMask1D and ComputeSoftmaxWithMask2D dispatchers. This code moves into the new attention_softmax.h shown below, with the 2D-mask variants generalized into SoftmaxWithRawMaskSmall / ComputeSoftmaxWithRawMask, which accept 2D or 3D raw masks.]
template <typename T>
__global__ void TransposeCtx(const int H, const T* input, T* output) {
  // Input:  BxNxSxH
  // Output: BxSxNxH

  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;

  int num_heads = blockDim.y;
  int sequence_length = gridDim.x;

  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;
  const int in_offset = s * H + n * sequence_length * H + b * NHS;
  const int out_offset = n * H + s * NH + b * NHS;

  const int i = threadIdx.x;
  if (i < H) {
    output[out_offset + i] = input[in_offset + i];
  }
}

bool LaunchTransCtx(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const float* input, float* output) {
  const dim3 grid(sequence_length, batch_size, 1);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    const dim3 block(H, num_heads, 1);
    TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {
    const dim3 block(head_size, num_heads, 1);
    TransposeCtx<float><<<grid, block, 0, stream>>>(head_size, input, output);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

bool LaunchTransCtx(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const half* input, half* output) {
  const dim3 grid(sequence_length, batch_size, 1);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const dim3 block(H, num_heads, 1);
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const dim3 block(H, num_heads, 1);
    const half2* input2 = reinterpret_cast<const half2*>(input);
    half2* output2 = reinterpret_cast<half2*>(output);
    TransposeCtx<half2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {  // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    const dim3 block(head_size, num_heads, 1);
    TransposeCtx<half><<<grid, block, 0, stream>>>(head_size, input, output);
  }

  return CUDA_CALL(cudaPeekAtLastError());
}

template <typename T>
__global__ void TransposeQKV(const int H, const T* input, T* output) {
  // Input:  BxSx3xNxH
  // Output: 3xBxNxSxH

  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;
  const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
  const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  const int i = threadIdx.x;
  if (i < H) {
    output[out_offset + i] = input[in_offset + i];
  }
}

bool LaunchTransQkv(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const float* input, float* output) {
  const dim3 grid(sequence_length, batch_size, 3);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    const dim3 block(H, num_heads, 1);
    TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {
    const dim3 block(head_size, num_heads, 1);
    TransposeQKV<float><<<grid, block, 0, stream>>>(head_size, input, output);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

bool LaunchTransQkv(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const half* input, half* output) {
  const dim3 grid(sequence_length, batch_size, 3);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const dim3 block(H, num_heads, 1);
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const dim3 block(H, num_heads, 1);
    const half2* input2 = reinterpret_cast<const half2*>(input);
    half2* output2 = reinterpret_cast<half2*>(output);
    TransposeQKV<half2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {  // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    const dim3 block(head_size, num_heads, 1);
    TransposeQKV<half><<<grid, block, 0, stream>>>(head_size, input, output);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}
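The two transpose kernels above only move elements. As a plain host-side reference (illustrative, not part of the commit), the TransposeQKV index mapping from BxSx3xNxH to 3xBxNxSxH is:

// Host-side reference of the TransposeQKV index mapping (illustrative only):
// input is laid out as B x S x 3 x N x H, output as 3 x B x N x S x H.
#include <vector>

void TransposeQkvReference(const std::vector<float>& input, std::vector<float>& output,
                           int B, int S, int N, int H) {
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int m = 0; m < 3; ++m)      // 0 = Q, 1 = K, 2 = V
        for (int n = 0; n < N; ++n)
          for (int h = 0; h < H; ++h) {
            const int in_offset = (((b * S + s) * 3 + m) * N + n) * H + h;
            const int out_offset = (((m * B + b) * N + n) * S + s) * H + h;
            output[out_offset] = input[in_offset];
          }
}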
[Removed from attention_impl.cu in this hunk: the ConcatPastToPresent kernel and the float/half LaunchConcatPastToPresent overloads, moved unchanged into the new attention_past.cu shown below.]
[Removed from attention_impl.cu in this hunk: the inline float/half CublasGemmStridedBatched wrappers, moved unchanged into attention_impl.h shown below.]

template <typename T>
@ -641,7 +66,7 @@ bool QkvToContext(
    const int* mask_index, const std::vector<int64_t>* mask_index_dims,
    bool is_unidirectional, int past_sequence_length, const T* past, T* present) {
  const int all_sequence_length = past_sequence_length + sequence_length;
  const size_t bytes = ScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
  const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
  T* scratch1 = workspace;
  T* scratch2 = scratch1 + (bytes / element_size);
  T* scratch3 = scratch2 + (bytes / element_size);
@ -677,13 +102,15 @@ bool QkvToContext(
    v = present + batches * present_size_per_batch;
  }

  bool use_2d_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() == 2);
  // Raw attention mask could be 2D (BxS) or 3D (BxSxS*)
  bool use_raw_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() >= 2);

  // compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
  // Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
  const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
  const int temp_matrix_size = sequence_length * all_sequence_length;
  T alpha = (T)(use_2d_attention_mask ? 1.0f : rsqrt_head_size);
  // For raw attention mask, the scalar of 1/sqrt(H) is moved to the softmax computation.
  T alpha = (T)(use_raw_attention_mask ? 1.0f : rsqrt_head_size);
  if (!CUBLAS_CALL(CublasGemmStridedBatched(
          cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, alpha, k, head_size, present_size_per_batch,
          q, head_size, size_per_batch, 0.f, scratch1, all_sequence_length, temp_matrix_size, batches))) {
@ -691,8 +118,8 @@ bool QkvToContext(
  }

  // apply softmax and store result P to scratch2: BxNxSxS*
  if (use_2d_attention_mask) {  // 2d attention mask
    if (!ComputeSoftmaxWithMask2D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size)) {
  if (use_raw_attention_mask) {  // 2d or 3d attention mask
    if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size, static_cast<int>(mask_index_dims->size()))) {
      return false;
    }
  } else if (nullptr != mask_index) {  // 1d mask index
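The hunks above change the scaling convention: with a raw 2D/3D mask the GEMM runs with alpha = 1 and the 1/sqrt(head_size) scale is applied inside the softmax kernel together with the additive mask. A small host-side sketch (illustrative only) of that combined step:

// Illustrative only: with a raw mask, the softmax step computes
// softmax(scale * scores + mask_bias); scaling inside or before the softmax
// gives the same probabilities.
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> MaskedSoftmax(const std::vector<float>& scores,
                                 const std::vector<int>& mask, float scale) {
  std::vector<float> x(scores.size());
  float max_x = -INFINITY;
  for (std::size_t i = 0; i < scores.size(); ++i) {
    x[i] = scores[i] * scale + (mask[i] > 0 ? 0.0f : -10000.0f);
    max_x = std::max(max_x, x[i]);
  }
  float sum = 0.0f;
  for (float& v : x) { v = std::exp(v - max_x); sum += v; }
  for (float& v : x) { v /= sum; }
  return x;
}

int main() {
  std::vector<float> scores = {1.0f, 2.0f, 3.0f};  // one row of Q*K'
  std::vector<int> mask = {1, 1, 0};               // last key position is masked out
  const float scale = 1.0f / std::sqrt(64.0f);     // 1/sqrt(head_size)
  auto p = MaskedSoftmax(scores, mask, scale);
  std::printf("%f %f %f\n", p[0], p[1], p[2]);     // third value is ~0
  return 0;
}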
@ -3,10 +3,13 @@

#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include <cublas_v2.h>

namespace onnxruntime {
namespace contrib {
namespace cuda {
size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length);

size_t GetAttentionWorkspaceSize(
    size_t element_size,
    int batchsize,

@ -34,6 +37,60 @@ bool LaunchAttentionKernel(
    void* present  // Present state output
    );

cublasStatus_t inline CublasGemmStridedBatched(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const float alpha,
    const float* A, int lda, long long int strideA, const float* B, int ldb, long long int strideB,
    const float beta, float* C, int ldc, long long int strideC, int batchCount) {
  return cublasSgemmStridedBatched(
      handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount);
}

cublasStatus_t inline CublasGemmStridedBatched(
    cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
    int m, int n, int k, const half alpha,
    const half* A, int lda, long long int strideA, const half* B, int ldb, long long int strideB,
    const half beta, half* C, int ldc, long long int strideC, int batchCount) {
  return cublasHgemmStridedBatched(
      handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount);
}

bool LaunchTransCtx(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const float* input, float* output);

bool LaunchTransCtx(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const half* input, half* output);

bool LaunchTransQkv(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const float* input, float* output);

bool LaunchTransQkv(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const half* input, half* output);

bool LaunchConcatPastToPresent(cudaStream_t stream,
                               const int all_sequence_length,
                               const int sequence_length,
                               const int batch_size,
                               const int head_size,
                               const int num_heads,
                               const float* past,
                               const float* k_v,
                               float* present);

bool LaunchConcatPastToPresent(cudaStream_t stream,
                               const int all_sequence_length,
                               const int sequence_length,
                               const int batch_size,
                               const int head_size,
                               const int num_heads,
                               const half* past,
                               const half* k_v,
                               half* present);

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
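The CublasGemmStridedBatched overloads above wrap cublasSgemmStridedBatched / cublasHgemmStridedBatched so one GEMM runs per (batch, head) pair. As a plain reference (illustrative, not from the commit), the batched Q*K' product they compute for attention is:

// Host-side reference (illustrative only) of what the strided-batched GEMM in
// QkvToContext computes: for each of the B*N attention heads,
// scores = alpha * Q * K^T, with Q of shape SxH and K of shape S_all x H.
#include <vector>

void BatchedQKt(const std::vector<float>& q,   // B*N*S*H
                const std::vector<float>& k,   // B*N*S_all*H
                std::vector<float>& scores,    // B*N*S*S_all
                int batches, int S, int S_all, int H, float alpha) {
  for (int p = 0; p < batches; ++p) {          // p indexes a (batch, head) pair
    const float* qp = q.data() + p * S * H;
    const float* kp = k.data() + p * S_all * H;
    float* sp = scores.data() + p * S * S_all;
    for (int i = 0; i < S; ++i)
      for (int j = 0; j < S_all; ++j) {
        float dot = 0.0f;
        for (int h = 0; h < H; ++h) dot += qp[i * H + h] * kp[j * H + h];
        sp[i * S_all + j] = alpha * dot;
      }
  }
}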
onnxruntime/contrib_ops/cuda/bert/attention_past.cu (new file, 95 lines)
@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

#include "core/providers/cuda/cuda_common.h"
#include "attention_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T>
__global__ void ConcatPastToPresent(const int sequence_length,
                                    const T* past,
                                    const T* k_v,
                                    T* present) {
  const int h = threadIdx.x;
  const int n = threadIdx.y;
  const int s = blockIdx.x;
  const int b = blockIdx.y;
  const int is_v = blockIdx.z;  // 0 for k, 1 for v

  const int all_sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int num_heads = blockDim.y;
  const int H = blockDim.x;

  // past:    2 x BxNxS'xH  (past_k and past_v)
  // k_v:     2 x BxNxSxH   (k and v)
  // present: 2 x BxNxS*xH  (present_k and present_v)
  const int past_sequence_length = all_sequence_length - sequence_length;

  const int present_SH = all_sequence_length * H;
  const int present_NSH = num_heads * present_SH;
  int out_offset = b * present_NSH + n * present_SH + s * H + h + is_v * (present_NSH * batch_size);
  if (s < past_sequence_length) {
    const int past_SH = past_sequence_length * H;
    const int past_NSH = num_heads * past_SH;
    const int in_offset = b * past_NSH + n * past_SH + s * H + h + is_v * (past_NSH * batch_size);
    present[out_offset] = past[in_offset];
  } else if (s < all_sequence_length) {
    const int SH = sequence_length * H;
    const int NSH = num_heads * SH;
    const int in_offset = b * NSH + n * SH + (s - past_sequence_length) * H + h + is_v * (NSH * batch_size);
    present[out_offset] = k_v[in_offset];
  }
}

bool LaunchConcatPastToPresent(cudaStream_t stream,
                               const int all_sequence_length,
                               const int sequence_length,
                               const int batch_size,
                               const int head_size,
                               const int num_heads,
                               const float* past,
                               const float* k_v,
                               float* present) {
  const dim3 grid(all_sequence_length, batch_size, 2);
  if (0 == (head_size & 1)) {
    const dim3 block(head_size / 2, num_heads, 1);
    ConcatPastToPresent<float2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const float2*>(past), reinterpret_cast<const float2*>(k_v), reinterpret_cast<float2*>(present));
  } else {
    const dim3 block(head_size, num_heads, 1);
    ConcatPastToPresent<float><<<grid, block, 0, stream>>>(sequence_length, past, k_v, present);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

bool LaunchConcatPastToPresent(cudaStream_t stream,
                               const int all_sequence_length,
                               const int sequence_length,
                               const int batch_size,
                               const int head_size,
                               const int num_heads,
                               const half* past,
                               const half* k_v,
                               half* present) {
  const dim3 grid(all_sequence_length, batch_size, 2);
  if (0 == (head_size % 4)) {
    const dim3 block(head_size / 4, num_heads, 1);
    ConcatPastToPresent<float2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const float2*>(past), reinterpret_cast<const float2*>(k_v), reinterpret_cast<float2*>(present));
  } else if (0 == (head_size & 1)) {
    const dim3 block(head_size / 2, num_heads, 1);
    ConcatPastToPresent<half2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const half2*>(past), reinterpret_cast<const half2*>(k_v), reinterpret_cast<half2*>(present));
  } else {  // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    const dim3 block(head_size, num_heads, 1);
    ConcatPastToPresent<half><<<grid, block, 0, stream>>>(sequence_length, past, k_v, present);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

}  // namespace cuda
}  // namespace contrib
}  // namespace onnxruntime
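A host-side reference (illustrative only, not part of the file above) of the same past/present concatenation for a single K or V tensor:

// Host-side reference of ConcatPastToPresent for one K or V component:
// past is BxNxS'xH, new k_v is BxNxSxH, present is BxNxS*xH with S* = S' + S.
#include <vector>

void ConcatPastToPresentReference(const std::vector<float>& past,
                                  const std::vector<float>& k_v,
                                  std::vector<float>& present,
                                  int B, int N, int past_len, int seq_len, int H) {
  const int all_len = past_len + seq_len;
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < all_len; ++s)
        for (int h = 0; h < H; ++h) {
          const int out = ((b * N + n) * all_len + s) * H + h;
          if (s < past_len) {
            present[out] = past[((b * N + n) * past_len + s) * H + h];
          } else {
            present[out] = k_v[((b * N + n) * seq_len + (s - past_len)) * H + h];
          }
        }
}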
onnxruntime/contrib_ops/cuda/bert/attention_softmax.h (new file, 389 lines)
@ -0,0 +1,389 @@
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/

Copyright 2019 NVIDIA Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#pragma once

#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"

using namespace onnxruntime::cuda;
using namespace cub;

namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T, unsigned TPB>
__device__ inline void Softmax(const int all_sequence_length,
                               const int sequence_length,
                               const int valid_end,
                               const int valid_start,
                               const T* input,
                               T* output) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  __shared__ float sum_reverse_block;
  __shared__ float max_block;

  float thread_data_max(-CUDART_INF_F);

  // e^x is represented as infinity if x is large enough, like 100.f.
  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
  // a math transform as below is leveraged to get a stable softmax:
  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
  const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
  for (int i = threadIdx.x; i < valid_end; i += TPB) {
    if (i >= valid_start) {
      const int index = offset + i;
      if (thread_data_max < float(input[index])) {
        thread_data_max = float(input[index]);
      }
    }
  }

  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());

  // Store max value
  if (threadIdx.x == 0) {
    max_block = max;
  }
  __syncthreads();

  float thread_data_sum(0.f);
  for (int i = threadIdx.x; i < valid_end; i += TPB) {
    if (i >= valid_start) {
      const int index = offset + i;
      const float val = input[index];
      thread_data_sum += expf(val - max_block);
    }
  }

  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
  if (threadIdx.x == 0) {
    sum_reverse_block = 1.f / sum;
  }
  __syncthreads();

  for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
    const int index = offset + i;
    const float val = (i >= valid_start && i < valid_end) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f;
    output[index] = T(val);
  }
}

template <typename T, unsigned TPB>
__device__ inline void SoftmaxSmall(const int all_sequence_length,
                                    const int sequence_length,
                                    const int valid_end,
                                    const int valid_start,
                                    const T* input,
                                    T* output,
                                    bool is_unidirectional) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  __shared__ float sum_reverse_block;
  __shared__ float max_block;

  // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
  const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
  const int index = offset + threadIdx.x;

  bool is_valid = false;  // whether it has attention mask == 1.

  // Update end position for unidirectional.
  int end = valid_end;
  if (is_unidirectional) {
    int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
    if (end_unid <= valid_start) {
      // In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000.
      // So [0, end_unid) will also have value after softmax.
      is_valid = threadIdx.x < end_unid;
    } else {
      end = min(valid_end, end_unid);
    }
  }

  is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end);

  // e^x is represented as infinity if x is large enough, like 100.f.
  // Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more item are large enough.
  // a math transform as below is leveraged to get a stable softmax:
  // e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
  float thread_data_max = is_valid ? float(input[index]) : float(-CUDART_INF_F);
  const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);

  // Store max value
  if (threadIdx.x == 0) {
    max_block = max;
  }
  __syncthreads();

  float thread_data_exp(0.f);
  if (is_valid) {
    thread_data_exp = expf(float(input[index]) - max_block);
  }

  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);

  // Store value of 1.0/sum.
  if (threadIdx.x == 0) {
    sum_reverse_block = (1.f) / sum;
  }
  __syncthreads();

  // threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
  if (threadIdx.x < all_sequence_length) {
    output[index] = T(thread_data_exp * sum_reverse_block);
  }
}

template <typename T, unsigned TPB>
__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
                                               const int sequence_length,
                                               const int* attention_mask,  // 2D or 3D attention mask
                                               const T* input,
                                               T* output,
                                               const bool is_unidirectional,
                                               const float scalar,
                                               const int mask_dimension) {
  using BlockReduce = cub::BlockReduce<float, TPB>;
  __shared__ typename BlockReduce::TempStorage tmp_storage;

  __shared__ float sum_reverse_block;
  __shared__ float max_block;

  // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
  int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;

  float thread_data = -CUDART_INF_F;
  if (threadIdx.x < all_sequence_length) {
    const int batch_index = blockIdx.y;
    const int sequence_index = blockIdx.x % sequence_length;
    const int mask_offset = (mask_dimension == 2) ? batch_index * all_sequence_length + threadIdx.x : batch_index * sequence_length * all_sequence_length + sequence_index * all_sequence_length + threadIdx.x;

    const int& mask = attention_mask[mask_offset];
    float mask_value = mask > 0 ? 0.0f : -10000.0f;

    if (is_unidirectional) {
      int from_index = all_sequence_length - sequence_length + sequence_index;  // offset of from token in all sequence length.
      if (threadIdx.x > from_index) {
        mask_value += -10000.0f;
      }
    }

    thread_data = float(input[index]) * scalar + mask_value;
  }

  const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length);

  // Store max value
  if (threadIdx.x == 0) {
    max_block = max;
  }
  __syncthreads();

  float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
  const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length);

  // Store value of 1.0/sum
  if (threadIdx.x == 0) {
    sum_reverse_block = (1.f) / sum;
  }
  __syncthreads();

  if (threadIdx.x < all_sequence_length) {
    output[index] = T(thread_data_exp * sum_reverse_block);
  }
}
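The only difference between the 2D and 3D raw-mask paths above is how mask_offset is computed; spelled out as a small helper (illustrative, not from the file):

// Illustrative helper (not from the commit) restating the mask_offset logic:
// a 2D mask is indexed per batch, a 3D mask per (batch, query position) row.
inline int RawMaskOffset(int mask_dimension, int batch_index, int sequence_index,
                         int key_index, int sequence_length, int all_sequence_length) {
  if (mask_dimension == 2) {
    // mask shape: (batch_size, all_sequence_length)
    return batch_index * all_sequence_length + key_index;
  }
  // mask shape: (batch_size, sequence_length, all_sequence_length)
  return (batch_index * sequence_length + sequence_index) * all_sequence_length + key_index;
}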
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) {
|
||||
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, input, output, is_unidirectional);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* input, T* output) {
|
||||
Softmax<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, input, output);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool ComputeSoftmax(
|
||||
cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
|
||||
const T* input, T* output, bool is_unidirectional) {
|
||||
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
if (all_sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 64) {
|
||||
const int blockSize = 64;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 256) {
|
||||
const int blockSize = 256;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 512) {
|
||||
const int blockSize = 512;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
|
||||
} else if (all_sequence_length <= 1024) {
|
||||
const int blockSize = 1024;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
|
||||
} else if (!is_unidirectional) {
|
||||
const int blockSize = 1024;
|
||||
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output);
|
||||
} else {
|
||||
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output, bool is_unidirectional) {
|
||||
__shared__ int start_position;
|
||||
__shared__ int end_position;
|
||||
|
||||
if (threadIdx.x == 0) {
|
||||
const int batch = blockIdx.y;
|
||||
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
|
||||
end_position = min(all_sequence_length, mask_end[batch]);
|
||||
|
||||
// Attend to no word has same effect as attend to all words. This is added to get parity with CPU result.
|
||||
if (start_position >= end_position) {
|
||||
start_position = 0;
|
||||
end_position = all_sequence_length;
|
||||
}
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, input, output, is_unidirectional);
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output) {
  __shared__ int start_position;
  __shared__ int end_position;

  if (threadIdx.x == 0) {
    const int batch = blockIdx.y;
    start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
    end_position = min(all_sequence_length, mask_end[batch]);

    // Attending to no word has the same effect as attending to all words. This is added to get parity with the CPU result.
    if (start_position >= end_position) {
      start_position = 0;
      end_position = all_sequence_length;
    }
  }
  __syncthreads();

  Softmax<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, input, output);
}

template <typename T, unsigned TPB>
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar, const int mask_dimension) {
  SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
}

template <typename T>
bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
                              const int* mask_index, const int* mask_start, const T* input, T* output, const bool is_unidirectional) {
  const dim3 grid(sequence_length * num_heads, batch_size, 1);

  if (all_sequence_length <= 32) {
    const int blockSize = 32;
    MaskedSoftmaxKernelSmall<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
  } else if (all_sequence_length <= 64) {
    const int blockSize = 64;
    MaskedSoftmaxKernelSmall<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
  } else if (all_sequence_length <= 128) {
    const int blockSize = 128;
    MaskedSoftmaxKernelSmall<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
  } else if (all_sequence_length <= 256) {
    const int blockSize = 256;
    MaskedSoftmaxKernelSmall<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
  } else if (all_sequence_length <= 512) {
    const int blockSize = 512;
    MaskedSoftmaxKernelSmall<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
  } else if (all_sequence_length <= 1024) {
    const int blockSize = 1024;
    MaskedSoftmaxKernelSmall<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
  } else if (!is_unidirectional) {
    const int blockSize = 1024;
    MaskedSoftmaxKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output);
  } else {
    ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
  }

  return CUDA_CALL(cudaPeekAtLastError());
}

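For the 1D mask formats dispatched above, mask_index carries one exclusive end position per batch entry, and mask_start is non-null only for the (2 * batch_size) layout, whose second half holds the start positions. A hedged host-side illustration of the two layouts (made-up values; the pointer arithmetic reflects this reading of the layout rather than code from the patch):

// Illustrative only: 1D mask layouts consumed by ComputeSoftmaxWithMask1D.
// batch_size = 2, all_sequence_length = 4
int right_padded[2] = {3, 4};        // end positions: tokens [0, 3) and [0, 4) are valid
int left_padded[4]  = {4, 4, 1, 0};  // end positions {4, 4} followed by start positions {1, 0}
// right padding: mask_index = right_padded, mask_start = nullptr
// left padding:  mask_index = left_padded,  mask_start = left_padded + 2
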
template <typename T>
bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
                               const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar,
                               const int mask_dimension) {
  const dim3 grid(sequence_length * num_heads, batch_size, 1);

  if (all_sequence_length <= 32) {
    const int blockSize = 32;
    SoftmaxWithRawMaskSmallKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
  } else if (all_sequence_length <= 64) {
    const int blockSize = 64;
    SoftmaxWithRawMaskSmallKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
  } else if (all_sequence_length <= 128) {
    const int blockSize = 128;
    SoftmaxWithRawMaskSmallKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
  } else if (all_sequence_length <= 256) {
    const int blockSize = 256;
    SoftmaxWithRawMaskSmallKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
  } else if (all_sequence_length <= 512) {
    const int blockSize = 512;
    SoftmaxWithRawMaskSmallKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
  } else if (all_sequence_length <= 1024) {
    const int blockSize = 1024;
    SoftmaxWithRawMaskSmallKernel<T, blockSize>
        <<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
  } else {
    ORT_THROW("Attention CUDA operator does not support 2D or 3D attention mask with total sequence length > 1024.");
  }

  return CUDA_CALL(cudaPeekAtLastError());
}

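ComputeSoftmaxWithRawMask serves both raw layouts: with mask_dimension == 2 one row of length all_sequence_length per batch entry is reused for every query position, while with mask_dimension == 3 each query position carries its own row. A small illustration of the two layouts, values made up (1 = attend, 0 = masked), for batch_size = 1, sequence_length = 2, all_sequence_length = 2:

// Illustrative only: the same batch expressed in the 2D and 3D raw mask layouts.
int mask_2d[1 * 2]     = {1, 0};  // (batch, all_seq): key 1 is masked for every query
int mask_3d[1 * 2 * 2] = {1, 0,   // (batch, seq, all_seq): query 0 masks key 1
                          1, 1};  //                        query 1 attends to both keys
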
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

160  onnxruntime/contrib_ops/cuda/bert/attention_transpose.cu  Normal file

@ -0,0 +1,160 @@
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/

Copyright 2019 NVIDIA Corporation

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

#include "core/providers/cuda/cuda_common.h"
#include "attention_impl.h"

using namespace onnxruntime::cuda;

namespace onnxruntime {
namespace contrib {
namespace cuda {

template <typename T>
__global__ void TransposeCtx(const int H, const T* input, T* output) {
  // Input:  BxNxSxH
  // Output: BxSxNxH

  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;

  int num_heads = blockDim.y;
  int sequence_length = gridDim.x;

  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;
  const int in_offset = s * H + n * sequence_length * H + b * NHS;
  const int out_offset = n * H + s * NH + b * NHS;

  const int i = threadIdx.x;
  if (i < H) {
    output[out_offset + i] = input[in_offset + i];
  }
}

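The offset arithmetic above is the flattened-index form of the BxNxSxH -> BxSxNxH permutation. A host-side reference of the same mapping, illustrative only (the function name is not from the patch, but such a loop is handy for unit-testing the kernel):

// Illustrative only: CPU reference for the permutation done by TransposeCtx.
void TransposeCtxReference(int B, int N, int S, int H, const float* in, float* out) {
  for (int b = 0; b < B; ++b)
    for (int n = 0; n < N; ++n)
      for (int s = 0; s < S; ++s)
        for (int h = 0; h < H; ++h)
          out[((b * S + s) * N + n) * H + h] = in[((b * N + n) * S + s) * H + h];
}
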
bool LaunchTransCtx(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const float* input, float* output) {
  const dim3 grid(sequence_length, batch_size, 1);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    const dim3 block(H, num_heads, 1);
    TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {
    const dim3 block(head_size, num_heads, 1);
    TransposeCtx<float><<<grid, block, 0, stream>>>(head_size, input, output);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

bool LaunchTransCtx(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const half* input, half* output) {
  const dim3 grid(sequence_length, batch_size, 1);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const dim3 block(H, num_heads, 1);
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const dim3 block(H, num_heads, 1);
    const half2* input2 = reinterpret_cast<const half2*>(input);
    half2* output2 = reinterpret_cast<half2*>(output);
    TransposeCtx<half2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {  // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    const dim3 block(head_size, num_heads, 1);
    TransposeCtx<half><<<grid, block, 0, stream>>>(head_size, input, output);
  }

  return CUDA_CALL(cudaPeekAtLastError());
}

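The half-precision launcher widens each copy when alignment allows it: a head_size divisible by 4 moves four halves at a time as one float2, an even head_size uses half2, and only the odd case falls back to scalar half. The size relationships it relies on, spelled out as an illustrative compile-time check (not present in the source file):

// Illustrative only: the vectorized paths reinterpret groups of half values as wider types.
static_assert(4 * sizeof(half) == sizeof(float2), "four half values occupy one float2");
static_assert(2 * sizeof(half) == sizeof(half2), "two half values occupy one half2");
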
template <typename T>
__global__ void TransposeQKV(const int H, const T* input, T* output) {
  // Input:  BxSx3xNxH
  // Output: 3xBxNxSxH

  int n = threadIdx.y;
  int s = blockIdx.x;
  int b = blockIdx.y;
  int m = blockIdx.z;  // matrix id

  const int num_heads = blockDim.y;

  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;
  const int NH = num_heads * H;
  const int NHS = NH * sequence_length;
  const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
  const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;

  const int i = threadIdx.x;
  if (i < H) {
    output[out_offset + i] = input[in_offset + i];
  }
}

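Here the extra blockIdx.z dimension walks the three packed matrices of the fused QKV buffer, so a single launch rearranges it into the per-head layout the attention kernels consume. The same BxSx3xNxH -> 3xBxNxSxH mapping as a host-side reference loop, illustrative only (name not from the patch):

// Illustrative only: CPU reference for the permutation done by TransposeQKV (m indexes the three packed matrices).
void TransposeQkvReference(int B, int S, int N, int H, const float* in, float* out) {
  for (int b = 0; b < B; ++b)
    for (int s = 0; s < S; ++s)
      for (int m = 0; m < 3; ++m)
        for (int n = 0; n < N; ++n)
          for (int h = 0; h < H; ++h)
            out[(((m * B + b) * N + n) * S + s) * H + h] =
                in[(((b * S + s) * 3 + m) * N + n) * H + h];
}
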
bool LaunchTransQkv(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const float* input, float* output) {
  const dim3 grid(sequence_length, batch_size, 3);
  if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    const dim3 block(H, num_heads, 1);
    TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {
    const dim3 block(head_size, num_heads, 1);
    TransposeQKV<float><<<grid, block, 0, stream>>>(head_size, input, output);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

bool LaunchTransQkv(cudaStream_t stream,
                    const int sequence_length, const int batch_size, const int head_size, const int num_heads,
                    const half* input, half* output) {
  const dim3 grid(sequence_length, batch_size, 3);
  if (0 == (head_size % 4)) {
    const int H = head_size / 4;
    const dim3 block(H, num_heads, 1);
    const float2* input2 = reinterpret_cast<const float2*>(input);
    float2* output2 = reinterpret_cast<float2*>(output);
    TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else if (0 == (head_size & 1)) {
    const int H = head_size / 2;
    const dim3 block(H, num_heads, 1);
    const half2* input2 = reinterpret_cast<const half2*>(input);
    half2* output2 = reinterpret_cast<half2*>(output);
    TransposeQKV<half2><<<grid, block, 0, stream>>>(H, input2, output2);
  } else {  // this should be an "odd" case. probably not worth catching it in the half2 kernel.
    const dim3 block(head_size, num_heads, 1);
    TransposeQKV<half><<<grid, block, 0, stream>>>(head_size, input, output);
  }
  return CUDA_CALL(cudaPeekAtLastError());
}

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

@ -277,7 +277,8 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
void RegisterBertSchemas() {
  static const char* Attention_ver1_doc = R"DOC(
Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
The mask_index input is optional. Besides raw attention mask with shape (batch_size, past_sequence_length + sequence_length),
The mask_index input is optional. Besides raw attention mask with shape (batch_size, past_sequence_length + sequence_length)
or (batch_size, sequence_length, past_sequence_length + sequence_length) with value 0 for masked and 1 otherwise,
we also support two other formats: when input has right-side padding, mask_index is one dimension with shape (batch_size),
where the value of each element is the end position, or valid length, of the actual sequence excluding padding. When input has
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
@ -297,7 +298,7 @@ and present state are optional. Present state could appear in output even when p
        .Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
        .Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
        .Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
        .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
        .Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
        .Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
        .Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
        .Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)
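For concreteness, a shape walk-through of the inputs and outputs listed above, with illustrative numbers only: batch_size = 2, sequence_length = 128, num_heads = 12 and head_size = 64 give hidden_size = 12 * 64 = 768, so input is (2, 128, 768), weight is (768, 2304), bias is (2304), a raw 3D mask_index is (2, 128, past_sequence_length + 128), and present (when requested) is (2, 2, 12, past_sequence_length + 128, 64).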

@ -13,12 +13,13 @@ enum MaskIndexType {
  kMaskIndexEnd = 0,
  kMaskIndexEndAndStart,
  kMaskRaw,
  kMask3D,
  kMaskDummy  // Dummy mask with shape [1, 1] or [batch_size, 1]
};

static void RunAttentionTest(
    const std::vector<float>& input_data,         // input:      [batch_size, sequence_length, hidden_size]
    const std::vector<float>& weights_data,       // weights:    [hidden_size, 3 * hidden_size]
    bool is_weights_constant,
    const std::vector<float>& bias_data,          // bias:       [3 * hidden_size]
    const std::vector<int32_t>& mask_index_data,  // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty
@ -52,6 +53,7 @@ static void RunAttentionTest(
  std::vector<int64_t> mask_index_dims_2 = {2 * batch_size};
  std::vector<int64_t> mask_index_dims_3 = {batch_size, past_sequence_length + sequence_length};
  std::vector<int64_t> mask_index_dims_4 = {batch_size, 1};
  std::vector<int64_t> mask_index_dims_5 = {batch_size, sequence_length, past_sequence_length + sequence_length};
  std::vector<int64_t> mask_index_dims;
  switch (mask_index_type) {
    case kMaskIndexEnd:
@ -66,8 +68,11 @@ static void RunAttentionTest(
    case kMaskDummy:
      mask_index_dims = mask_index_dims_4;
      break;
    case kMask3D:
      mask_index_dims = mask_index_dims_5;
      break;
    default:
      assert(0);  // shall not reach here.
      break;
  }

@ -973,6 +978,51 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) {
                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
}

TEST(AttentionTest, Attention3DMask) {
  int batch_size = 2;
  int sequence_length = 2;
  int hidden_size = 4;
  int number_of_heads = 2;

  std::vector<float> input_data = {
      0.5f, 0.2f, 0.3f, -0.6f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f};

  std::vector<float> weight_data = {
      0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
      0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
      0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
      0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};

  std::vector<float> bias_data = {
      -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};

  // Test 3D mask BxSxS*
  std::vector<int32_t> mask_index_data = {
      0, 1,
      0, 1,
      1, 1,
      1, 1};
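  // 3D mask layout: batch_size x sequence_length x all_sequence_length = 2 x 2 x 2.
  // Batch 0 masks key 0 for both query positions (rows {0, 1}); batch 1 attends to every key (rows {1, 1}).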

  std::vector<float> output_data = {
      8.69f, -0.13f, 4.25f, 5.65f,
      8.69f, -0.13f, 4.25f, 5.65f,
      3.14959716796875f, 0.10843672603368759f, 4.25f, 5.65f,
      3.9696791172027588f, 0.073143675923347473f, 4.25f, 5.65f};

  bool use_float16 = false;
  bool is_unidirectional = false;
  bool use_past_state = false;
  int past_sequence_length = 0;
  const std::vector<float>* past_data = nullptr;
  const std::vector<float>* present_data = nullptr;
  RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                   batch_size, sequence_length, hidden_size, number_of_heads,
                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D);
}

TEST(AttentionTest, AttentionBatch2AttentionMask) {
  int batch_size = 2;
  int sequence_length = 2;

@ -1014,6 +1064,51 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) {
                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
}

TEST(AttentionTest, AttentionUnidirectional3DMask) {
  int batch_size = 2;
  int sequence_length = 2;
  int hidden_size = 4;
  int number_of_heads = 2;

  std::vector<float> input_data = {
      0.5f, 0.2f, 0.3f, -0.6f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f};

  std::vector<float> weight_data = {
      0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
      0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
      0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
      0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};

  std::vector<float> bias_data = {
      -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};

  // Test 3D mask BxSxS*
  std::vector<int32_t> mask_index_data = {
      0, 1,
      0, 1,
      1, 1,
      1, 1};

  std::vector<float> output_data = {
      3.967245340f, 0.07324841f, 4.25f, 5.65f,
      8.69f, -0.13f, 4.25f, 5.65f,
      8.69f, -0.13f, 4.25f, 5.65f,
      3.96967912f, 0.07314367f, 4.25f, 5.65f};

  bool use_float16 = false;
  bool is_unidirectional = true;
  bool use_past_state = false;
  int past_sequence_length = 0;
  const std::vector<float>* past_data = nullptr;
  const std::vector<float>* present_data = nullptr;
  RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                   batch_size, sequence_length, hidden_size, number_of_heads,
                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D);
}

TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
  int batch_size = 2;
  int sequence_length = 2;

@ -1181,6 +1276,47 @@ TEST(AttentionTest, AttentionMask2DNoWord) {
                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
}

TEST(AttentionTest, AttentionMask3DNoWord) {
  int batch_size = 2;
  int sequence_length = 2;
  int hidden_size = 4;
  int number_of_heads = 2;

  std::vector<float> input_data = {
      0.5f, 0.2f, 0.3f, -0.6f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.8f, -0.5f, 0.0f, 1.f,
      0.5f, 0.2f, 0.3f, -0.6f};

  std::vector<float> weight_data = {
      0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
      0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
      0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
      0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};

  std::vector<float> bias_data = {
      -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};

  // Test that all attention masks are zero.
  std::vector<int32_t> mask_index_data = {0, 0, 0, 0, 0, 0, 0, 0};

  std::vector<float> output_data = {
      3.96724534f, 0.07324841f, 4.25f, 5.65f,
      3.14984703f, 0.10842596f, 4.25f, 5.65f,
      3.14984703f, 0.10842596f, 4.25f, 5.65f,
      3.96724534f, 0.07324841f, 4.25f, 5.65f};

  bool use_float16 = false;
  bool is_unidirectional = false;
  bool use_past_state = false;
  int past_sequence_length = 0;
  const std::vector<float>* past_data = nullptr;
  const std::vector<float>* present_data = nullptr;

  RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
                   batch_size, sequence_length, hidden_size, number_of_heads,
                   use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D);
}

TEST(AttentionTest, AttentionDummyMask2D) {
  int batch_size = 2;

@ -1294,4 +1430,4 @@ TEST(AttentionTest, AttentionPastState_dynamic) {
}

} // namespace test
} // namespace onnxruntime