Support 3D attention mask (#5887)

Support 3D attention mask with shape (batch_size, sequence_length, all_sequence_length), where all_sequence_length = past_sequence_length + sequence_length.
Tianlei Wu 2020-11-20 22:48:01 -08:00 committed by GitHub
parent cc6e8fb7cc
commit 910bbfe1ef
11 changed files with 886 additions and 596 deletions
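
Aside (illustration only, not part of the changed files): a minimal C++ sketch of a causal 3D raw attention mask in the new layout. The helper name MakeCausal3DMask and the causal fill pattern are assumptions for this example; values follow the raw-mask convention of 1 = attend, 0 = masked.

#include <cstdint>
#include <vector>

// Builds a (batch_size, sequence_length, all_sequence_length) mask where
// all_sequence_length = past_sequence_length + sequence_length.
std::vector<int32_t> MakeCausal3DMask(int batch_size, int sequence_length, int past_sequence_length) {
  const int all_sequence_length = past_sequence_length + sequence_length;
  std::vector<int32_t> mask(static_cast<size_t>(batch_size) * sequence_length * all_sequence_length, 0);
  for (int b = 0; b < batch_size; b++) {
    for (int s = 0; s < sequence_length; s++) {
      // Query position s may attend to every past token and to tokens up to its own position.
      for (int m = 0; m <= past_sequence_length + s; m++) {
        mask[(static_cast<size_t>(b) * sequence_length + s) * all_sequence_length + m] = 1;
      }
    }
  }
  return mask;
}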


@ -59,7 +59,10 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// input : (batch_size, sequence_length, hidden_size)
// weights : (hidden_size, 3 * hidden_size)
// bias : (3 * hidden_size)
// mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1) or (batch_size, past_sequence_length + sequence_length)
// mask_index : nullptr, (batch_size), (2 * batch_size),
// or (batch_size, 1), (1, 1)
// or (batch_size, past_sequence_length + sequence_length)
// or (batch_size, sequence_length, past_sequence_length + sequence_length)
// past : (2, batch_size, num_heads, past_sequence_length, head_size)
const auto& dims = input_shape.GetDims();
@ -136,8 +139,12 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with raw attention mask shall have shape batch_size x (past_sequence_length + sequence_length)");
}
}
} else if (mask_dims.size() == 3) {
if (static_cast<int>(mask_dims[0]) != batch_size || mask_dims[1] != sequence_length || static_cast<int>(mask_dims[2]) != past_sequence_length + sequence_length) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 'mask_index' with 3 dimensions shall have shape batch_size x sequence_length x (past_sequence_length + sequence_length)");
}
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1 or 2 dimensions, got ",
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'mask_index' is expected to have 1, 2 or 3 dimensions, got ",
mask_dims.size());
}
}


@ -26,8 +26,8 @@ class AttentionBase {
int sequence_length,
int& past_sequence_length) const;
int num_heads_; // number of attention heads
bool is_unidirectional_; // whether every token can only attend to previous tokens.
int num_heads_; // number of attention heads
bool is_unidirectional_; // whether every token can only attend to previous tokens.
};
} // namespace contrib


@ -89,7 +89,7 @@ class AttentionCPUBase : public AttentionBase {
const T* K, // k data. Its size is BxNxSxH
const int32_t* mask_index, // mask index. nullptr if no mask or its size is B
const std::vector<int64_t>* mask_index_dims, // mask index shape
T* mask_data, // buffer for mask data. Its size is: SxS* if is_unidirectional_; BxSxS* if mask_index; null otherwise
T* mask_data, // buffer for mask data. It is nullptr if mask_index is nullptr, otherwise its shape is BxSxS*
int batch_size, // batch size of self-attention
int sequence_length, // sequence length of self-attention
int past_sequence_length, // sequence length of past state


@ -72,12 +72,31 @@ void PrepareMask(const int32_t* mask_index,
// mask_data has been filled with 0, and its shape is BxSxS*
T* p_mask = mask_data;
// For 3D mask, convert values 0 to -10000.0, and 1 to 0.0, then apply unidirectional mask if any.
if (nullptr != mask_index_dims && mask_index_dims->size() == 3) {
for (int i = 0; i < batch_size * sequence_length * all_sequence_length; i++) {
p_mask[i] = (mask_index[i] > 0) ? static_cast<T>(0.0f) : static_cast<T>(-10000.0f);
}
if (is_unidirectional) {
for (int b_i = 0; b_i < batch_size; b_i++) {
for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
p_mask[s_i * all_sequence_length + m_i] += static_cast<T>(-10000.0f);
}
}
p_mask += sequence_length * all_sequence_length;
}
}
return;
}
bool is_raw_attention_mask = (nullptr != mask_index_dims && mask_index_dims->size() == 2);
bool has_mask_start_position = (nullptr != mask_index_dims && mask_index_dims->size() == 1 && static_cast<int>(mask_index_dims->at(0)) == 2 * batch_size);
for (int b_i = 0; b_i < batch_size; b_i++) {
// TODO: mask_index can be used in softmax to save some calculation.
if (nullptr != mask_index) {
if (is_raw_attention_mask) {
// Raw attention mask has value 0 or 1. Here we convert 0 to -10000.0, and 1 to 0.0.
@ -120,7 +139,6 @@ void PrepareMask(const int32_t* mask_index,
p_mask += sequence_length * all_sequence_length;
}
}
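
A tiny host-side walk-through of the 3D branch of PrepareMask above (illustration only; the dimensions and mask values are made up): batch_size = 1, sequence_length = 2, past_sequence_length = 1, is_unidirectional = true.

#include <cstdio>

int main() {
  const int sequence_length = 2;
  const int past_sequence_length = 1;
  const int all_sequence_length = past_sequence_length + sequence_length;  // 3
  const int mask_index[] = {1, 1, 0,   // query 0: attend to the past token and itself, last position masked
                            1, 1, 1};  // query 1: attend to everything
  float p_mask[sequence_length * all_sequence_length];
  // Raw mask values: 1 -> 0.0f, 0 -> -10000.0f (same conversion as above).
  for (int i = 0; i < sequence_length * all_sequence_length; i++) {
    p_mask[i] = (mask_index[i] > 0) ? 0.0f : -10000.0f;
  }
  // Unidirectional: positions after the query's own position get an extra -10000.
  for (int s_i = 0; s_i < sequence_length - 1; s_i++) {
    for (int m_i = past_sequence_length + s_i + 1; m_i < all_sequence_length; m_i++) {
      p_mask[s_i * all_sequence_length + m_i] += -10000.0f;
    }
  }
  for (int i = 0; i < sequence_length * all_sequence_length; i++) {
    printf("%g ", p_mask[i]);  // prints: 0 0 -20000 0 0 0
  }
  printf("\n");
  return 0;
}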
// Concatenate a past state chunk S'xH with input state chunk SxH into present state chunk S*xH


@ -21,13 +21,11 @@ limitations under the License.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "attention_impl.h"
#include "attention_softmax.h"
using namespace onnxruntime::cuda;
using namespace cub;
@ -40,7 +38,7 @@ static size_t AlignTo(size_t a, size_t b) {
return CeilDiv(a, b) * b;
}
size_t ScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length) {
const size_t len = batch_size * num_heads * sequence_length * all_sequence_length;
const size_t bytes = len * element_size;
@ -57,580 +55,7 @@ size_t GetAttentionWorkspaceSize(
int sequence_length,
int past_sequence_length) {
size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
return qkv_size + 2 * ScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
}
template <typename T, unsigned TPB>
__device__ inline void Softmax(const int all_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
const T* input,
T* output) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
float thread_data_max(-CUDART_INF_F);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more items are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
if (thread_data_max < float(input[index])) {
thread_data_max = float(input[index]);
}
}
}
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_sum(0.f);
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
const float val = input[index];
thread_data_sum += expf(val - max_block);
}
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
if (threadIdx.x == 0) {
sum_reverse_block = 1.f / sum;
}
__syncthreads();
for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
const int index = offset + i;
const float val = (i >= valid_start && i < valid_end) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f;
output[index] = T(val);
}
}
template <typename T, unsigned TPB>
__device__ inline void SoftmaxSmall(const int all_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
const T* input,
T* output,
bool is_unidirectional) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
const int index = offset + threadIdx.x;
bool is_valid = false; // whether it has attention mask == 1.
// Update end position for unidirectional.
int end = valid_end;
if (is_unidirectional) {
int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
if (end_unid <= valid_start) {
// In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000.
// So [0, end_unid) will also have value after softmax.
is_valid = threadIdx.x < end_unid;
} else {
end = min(valid_end, end_unid);
}
}
is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more items are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
float thread_data_max = is_valid ? float(input[index]) : float(-CUDART_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_exp(0.f);
if (is_valid) {
thread_data_exp = expf(float(input[index]) - max_block);
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
// Store value of 1.0/sum.
if (threadIdx.x == 0) {
sum_reverse_block = (1.f) / sum;
}
__syncthreads();
// threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
template <typename T, unsigned TPB>
__device__ inline void SoftmaxWithMask2DSmall(const int all_sequence_length,
const int sequence_length,
const int* attention_mask, // 2D attention mask
const T* input,
T* output,
const bool is_unidirectional,
const float scalar) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
float thread_data = -CUDART_INF_F;
if (threadIdx.x < all_sequence_length) {
const int& mask = attention_mask[blockIdx.y * all_sequence_length + threadIdx.x];
float mask_value = mask > 0 ? 0.0f : -10000.0f;
if (is_unidirectional) {
int from_index = all_sequence_length - sequence_length + (blockIdx.x % sequence_length); // offset of from token in all sequence length.
if (threadIdx.x > from_index) {
mask_value += -10000.0f;
}
}
thread_data = float(input[index]) * scalar + mask_value;
}
const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length);
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length);
// Store value of 1.0/sum
if (threadIdx.x == 0) {
sum_reverse_block = (1.f) / sum;
}
__syncthreads();
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) {
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, input, output, is_unidirectional);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* input, T* output) {
Softmax<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, input, output);
}
template <typename T>
bool ComputeSoftmax(
cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const T* input, T* output, bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (!is_unidirectional) {
const int blockSize = 1024;
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output);
} else {
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output, bool is_unidirectional) {
__shared__ int start_position;
__shared__ int end_position;
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
end_position = min(all_sequence_length, mask_end[batch]);
// Attending to no word has the same effect as attending to all words. This is added to get parity with the CPU result.
if (start_position >= end_position) {
start_position = 0;
end_position = all_sequence_length;
}
}
__syncthreads();
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, input, output, is_unidirectional);
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output) {
__shared__ int start_position;
__shared__ int end_position;
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
end_position = min(all_sequence_length, mask_end[batch]);
// Attending to no word has the same effect as attending to all words. This is added to get parity with the CPU result.
if (start_position >= end_position) {
start_position = 0;
end_position = all_sequence_length;
}
}
__syncthreads();
Softmax<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, input, output);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxWithMask2DSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar) {
SoftmaxWithMask2DSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
}
template <typename T>
bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* mask_index, const int* mask_start, const T* input, T* output, const bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (!is_unidirectional) {
const int blockSize = 1024;
MaskedSoftmaxKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output);
} else {
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T>
bool ComputeSoftmaxWithMask2D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxWithMask2DSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxWithMask2DSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxWithMask2DSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxWithMask2DSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxWithMask2DSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxWithMask2DSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar);
} else {
ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T>
__global__ void TransposeCtx(const int H, const T* input, T* output) {
// Input: BxNxSxH
// Output: BxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = s * H + n * sequence_length * H + b * NHS;
const int out_offset = n * H + s * NH + b * NHS;
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
bool LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
const dim3 block(H, num_heads, 1);
TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else {
const dim3 block(head_size, num_heads, 1);
TransposeCtx<float><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const dim3 block(H, num_heads, 1);
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const dim3 block(H, num_heads, 1);
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
TransposeCtx<half2><<<grid, block, 0, stream>>>(H, input2, output2);
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
const dim3 block(head_size, num_heads, 1);
TransposeCtx<half><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T>
__global__ void TransposeQKV(const int H, const T* input, T* output) {
// Input: BxSx3xNxH
// Output: 3xBxNxSxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
bool LaunchTransQkv(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, 3);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
const dim3 block(H, num_heads, 1);
TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else {
const dim3 block(head_size, num_heads, 1);
TransposeQKV<float><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchTransQkv(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, 3);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const dim3 block(H, num_heads, 1);
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const dim3 block(H, num_heads, 1);
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
TransposeQKV<half2><<<grid, block, 0, stream>>>(H, input2, output2);
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
const dim3 block(head_size, num_heads, 1);
TransposeQKV<half><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T>
__global__ void ConcatPastToPresent(const int sequence_length,
const T* past,
const T* k_v,
T* present) {
const int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int is_v = blockIdx.z; // 0 for k, 1 for v
const int all_sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int num_heads = blockDim.y;
const int H = blockDim.x;
// past: 2 x BxNxS'xH (past_k and past_v)
// k_v: 2 x BxNxSxH (k and v)
// present: 2 x BxNxS*xH (present_k and present_v)
const int past_sequence_length = all_sequence_length - sequence_length;
const int present_SH = all_sequence_length * H;
const int present_NSH = num_heads * present_SH;
int out_offset = b * present_NSH + n * present_SH + s * H + h + is_v * (present_NSH * batch_size);
if (s < past_sequence_length) {
const int past_SH = past_sequence_length * H;
const int past_NSH = num_heads * past_SH;
const int in_offset = b * past_NSH + n * past_SH + s * H + h + is_v * (past_NSH * batch_size);
present[out_offset] = past[in_offset];
} else if (s < all_sequence_length) {
const int SH = sequence_length * H;
const int NSH = num_heads * SH;
const int in_offset = b * NSH + n * SH + (s - past_sequence_length) * H + h + is_v * (NSH * batch_size);
present[out_offset] = k_v[in_offset];
}
}
bool LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const float* past,
const float* k_v,
float* present) {
const dim3 grid(all_sequence_length, batch_size, 2);
if (0 == (head_size & 1)) {
const dim3 block(head_size / 2, num_heads, 1);
ConcatPastToPresent<float2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const float2*>(past), reinterpret_cast<const float2*>(k_v), reinterpret_cast<float2*>(present));
} else {
const dim3 block(head_size, num_heads, 1);
ConcatPastToPresent<float><<<grid, block, 0, stream>>>(sequence_length, past, k_v, present);
}
return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const half* past,
const half* k_v,
half* present) {
const dim3 grid(all_sequence_length, batch_size, 2);
if (0 == (head_size % 4)) {
const dim3 block(head_size / 4, num_heads, 1);
ConcatPastToPresent<float2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const float2*>(past), reinterpret_cast<const float2*>(k_v), reinterpret_cast<float2*>(present));
} else if (0 == (head_size & 1)) {
const dim3 block(head_size / 2, num_heads, 1);
ConcatPastToPresent<half2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const half2*>(past), reinterpret_cast<const half2*>(k_v), reinterpret_cast<half2*>(present));
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
const dim3 block(head_size, num_heads, 1);
ConcatPastToPresent<half><<<grid, block, 0, stream>>>(sequence_length, past, k_v, present);
}
return CUDA_CALL(cudaPeekAtLastError());
}
cublasStatus_t inline CublasGemmStridedBatched(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const float alpha,
const float* A, int lda, long long int strideA, const float* B, int ldb, long long int strideB,
const float beta, float* C, int ldc, long long int strideC, int batchCount) {
return cublasSgemmStridedBatched(
handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount);
}
cublasStatus_t inline CublasGemmStridedBatched(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const half alpha,
const half* A, int lda, long long int strideA, const half* B, int ldb, long long int strideB,
const half beta, half* C, int ldc, long long int strideC, int batchCount) {
return cublasHgemmStridedBatched(
handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount);
return qkv_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, past_sequence_length + sequence_length);
}
template <typename T>
@ -641,7 +66,7 @@ bool QkvToContext(
const int* mask_index, const std::vector<int64_t>* mask_index_dims,
bool is_unidirectional, int past_sequence_length, const T* past, T* present) {
const int all_sequence_length = past_sequence_length + sequence_length;
const size_t bytes = ScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
const size_t bytes = GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length, all_sequence_length);
T* scratch1 = workspace;
T* scratch2 = scratch1 + (bytes / element_size);
T* scratch3 = scratch2 + (bytes / element_size);
@ -677,13 +102,15 @@ bool QkvToContext(
v = present + batches * present_size_per_batch;
}
bool use_2d_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() == 2);
// Raw attention mask could be 2D (BxS*) or 3D (BxSxS*)
bool use_raw_attention_mask = (nullptr != mask_index && nullptr != mask_index_dims && mask_index_dims->size() >= 2);
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
const int temp_matrix_size = sequence_length * all_sequence_length;
T alpha = (T)(use_2d_attention_mask ? 1.0f : rsqrt_head_size);
// For raw attention mask, the scalar of 1/sqrt(H) is moved to the softmax computation.
T alpha = (T)(use_raw_attention_mask ? 1.0f : rsqrt_head_size);
if (!CUBLAS_CALL(CublasGemmStridedBatched(
cublas, CUBLAS_OP_T, CUBLAS_OP_N, all_sequence_length, sequence_length, head_size, alpha, k, head_size, present_size_per_batch,
q, head_size, size_per_batch, 0.f, scratch1, all_sequence_length, temp_matrix_size, batches))) {
@ -691,8 +118,8 @@ bool QkvToContext(
}
// apply softmax and store result P to scratch2: BxNxSxS*
if (use_2d_attention_mask) { // 2d attention mask
if (!ComputeSoftmaxWithMask2D<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size)) {
if (use_raw_attention_mask) { // 2d or 3d attention mask
if (!ComputeSoftmaxWithRawMask<T>(stream, all_sequence_length, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2, is_unidirectional, rsqrt_head_size, static_cast<int>(mask_index_dims->size()))) {
return false;
}
} else if (nullptr != mask_index) { // 1d mask index
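
A back-of-the-envelope check of GetAttentionWorkspaceSize (illustration only): the workspace is the packed QKV buffer plus the two BxNxSxS* score scratch buffers used by QkvToContext. The concrete numbers below assume fp16 (element_size = 2), batch_size 1, 12 heads, head_size 64, sequence_length 128 and no past state, and ignore any AlignTo rounding inside GetAttentionScratchSize.

#include <cstdio>

int main() {
  // Example dimensions (assumptions for this sketch only).
  const size_t element_size = 2;  // fp16
  const size_t batch_size = 1, num_heads = 12, head_size = 64;
  const size_t sequence_length = 128, past_sequence_length = 0;
  const size_t all_sequence_length = past_sequence_length + sequence_length;
  // Packed Q, K, V after the input GEMM: 3 x B x S x N x H elements.
  const size_t qkv_size = 3 * batch_size * sequence_length * num_heads * head_size * element_size;
  // One scratch buffer holds the B x N x S x S* attention scores (alignment rounding ignored here).
  const size_t scratch = batch_size * num_heads * sequence_length * all_sequence_length * element_size;
  printf("qkv=%zu scratch=%zu workspace~=%zu bytes\n", qkv_size, scratch, qkv_size + 2 * scratch);
  // qkv=589824 scratch=393216 workspace~=1376256
  return 0;
}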


@ -3,10 +3,13 @@
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
#include <cublas_v2.h>
namespace onnxruntime {
namespace contrib {
namespace cuda {
size_t GetAttentionScratchSize(size_t element_size, int batch_size, int num_heads, int sequence_length, int all_sequence_length);
size_t GetAttentionWorkspaceSize(
size_t element_size,
int batchsize,
@ -34,6 +37,60 @@ bool LaunchAttentionKernel(
void* present // Present state output
);
cublasStatus_t inline CublasGemmStridedBatched(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const float alpha,
const float* A, int lda, long long int strideA, const float* B, int ldb, long long int strideB,
const float beta, float* C, int ldc, long long int strideC, int batchCount) {
return cublasSgemmStridedBatched(
handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount);
}
cublasStatus_t inline CublasGemmStridedBatched(
cublasHandle_t handle, cublasOperation_t transa, cublasOperation_t transb,
int m, int n, int k, const half alpha,
const half* A, int lda, long long int strideA, const half* B, int ldb, long long int strideB,
const half beta, half* C, int ldc, long long int strideC, int batchCount) {
return cublasHgemmStridedBatched(
handle, transa, transb, m, n, k, &alpha, A, lda, strideA, B, ldb, strideB, &beta, C, ldc, strideC, batchCount);
}
bool LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const float* input, float* output);
bool LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const half* input, half* output);
bool LaunchTransQkv(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const float* input, float* output);
bool LaunchTransQkv(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const half* input, half* output);
bool LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const float* past,
const float* k_v,
float* present);
bool LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const half* past,
const half* k_v,
half* present);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime


@ -0,0 +1,95 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "attention_impl.h"
using namespace onnxruntime::cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T>
__global__ void ConcatPastToPresent(const int sequence_length,
const T* past,
const T* k_v,
T* present) {
const int h = threadIdx.x;
const int n = threadIdx.y;
const int s = blockIdx.x;
const int b = blockIdx.y;
const int is_v = blockIdx.z; // 0 for k, 1 for v
const int all_sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int num_heads = blockDim.y;
const int H = blockDim.x;
// past: 2 x BxNxS'xH (past_k and past_v)
// k_v: 2 x BxNxSxH (k and v)
// present: 2 x BxNxS*xH (present_k and present_v)
const int past_sequence_length = all_sequence_length - sequence_length;
const int present_SH = all_sequence_length * H;
const int present_NSH = num_heads * present_SH;
int out_offset = b * present_NSH + n * present_SH + s * H + h + is_v * (present_NSH * batch_size);
if (s < past_sequence_length) {
const int past_SH = past_sequence_length * H;
const int past_NSH = num_heads * past_SH;
const int in_offset = b * past_NSH + n * past_SH + s * H + h + is_v * (past_NSH * batch_size);
present[out_offset] = past[in_offset];
} else if (s < all_sequence_length) {
const int SH = sequence_length * H;
const int NSH = num_heads * SH;
const int in_offset = b * NSH + n * SH + (s - past_sequence_length) * H + h + is_v * (NSH * batch_size);
present[out_offset] = k_v[in_offset];
}
}
bool LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const float* past,
const float* k_v,
float* present) {
const dim3 grid(all_sequence_length, batch_size, 2);
if (0 == (head_size & 1)) {
const dim3 block(head_size / 2, num_heads, 1);
ConcatPastToPresent<float2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const float2*>(past), reinterpret_cast<const float2*>(k_v), reinterpret_cast<float2*>(present));
} else {
const dim3 block(head_size, num_heads, 1);
ConcatPastToPresent<float><<<grid, block, 0, stream>>>(sequence_length, past, k_v, present);
}
return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchConcatPastToPresent(cudaStream_t stream,
const int all_sequence_length,
const int sequence_length,
const int batch_size,
const int head_size,
const int num_heads,
const half* past,
const half* k_v,
half* present) {
const dim3 grid(all_sequence_length, batch_size, 2);
if (0 == (head_size % 4)) {
const dim3 block(head_size / 4, num_heads, 1);
ConcatPastToPresent<float2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const float2*>(past), reinterpret_cast<const float2*>(k_v), reinterpret_cast<float2*>(present));
} else if (0 == (head_size & 1)) {
const dim3 block(head_size / 2, num_heads, 1);
ConcatPastToPresent<half2><<<grid, block, 0, stream>>>(sequence_length, reinterpret_cast<const half2*>(past), reinterpret_cast<const half2*>(k_v), reinterpret_cast<half2*>(present));
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
const dim3 block(head_size, num_heads, 1);
ConcatPastToPresent<half><<<grid, block, 0, stream>>>(sequence_length, past, k_v, present);
}
return CUDA_CALL(cudaPeekAtLastError());
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
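
A host-side reference of the same index mapping (illustration only; the CUDA kernel above performs it element-wise across the grid): past with layout (2, B, N, S', H) and the current k/v with layout (2, B, N, S, H) are concatenated along the sequence axis into present with layout (2, B, N, S*, H).

#include <cstddef>
#include <vector>

// present must be allocated by the caller with 2 * B * N * (S' + S) * H elements.
void ConcatPastToPresentCpu(int batch_size, int num_heads, int past_sequence_length,
                            int sequence_length, int head_size,
                            const std::vector<float>& past, const std::vector<float>& k_v,
                            std::vector<float>& present) {
  const int all_sequence_length = past_sequence_length + sequence_length;
  for (int is_v = 0; is_v < 2; is_v++) {  // 0: key, 1: value
    for (int b = 0; b < batch_size; b++) {
      for (int n = 0; n < num_heads; n++) {
        for (int s = 0; s < all_sequence_length; s++) {
          for (int h = 0; h < head_size; h++) {
            const size_t out = ((((size_t)is_v * batch_size + b) * num_heads + n) * all_sequence_length + s) * head_size + h;
            if (s < past_sequence_length) {
              // First S' positions come from the past state.
              present[out] = past[((((size_t)is_v * batch_size + b) * num_heads + n) * past_sequence_length + s) * head_size + h];
            } else {
              // Remaining S positions come from the current key/value.
              present[out] = k_v[((((size_t)is_v * batch_size + b) * num_heads + n) * sequence_length + (s - past_sequence_length)) * head_size + h];
            }
          }
        }
      }
    }
  }
}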


@ -0,0 +1,389 @@
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <cub/cub.cuh>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
using namespace onnxruntime::cuda;
using namespace cub;
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T, unsigned TPB>
__device__ inline void Softmax(const int all_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
const T* input,
T* output) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
float thread_data_max(-CUDART_INF_F);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more items are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
if (thread_data_max < float(input[index])) {
thread_data_max = float(input[index]);
}
}
}
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max());
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_sum(0.f);
for (int i = threadIdx.x; i < valid_end; i += TPB) {
if (i >= valid_start) {
const int index = offset + i;
const float val = input[index];
thread_data_sum += expf(val - max_block);
}
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_sum, cub::Sum());
if (threadIdx.x == 0) {
sum_reverse_block = 1.f / sum;
}
__syncthreads();
for (int i = threadIdx.x; i < all_sequence_length; i += TPB) {
const int index = offset + i;
const float val = (i >= valid_start && i < valid_end) ? expf(float(input[index]) - max_block) * sum_reverse_block : 0.f;
output[index] = T(val);
}
}
template <typename T, unsigned TPB>
__device__ inline void SoftmaxSmall(const int all_sequence_length,
const int sequence_length,
const int valid_end,
const int valid_start,
const T* input,
T* output,
bool is_unidirectional) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
const int offset = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length;
const int index = offset + threadIdx.x;
bool is_valid = false; // whether it has attention mask == 1.
// Update end position for unidirectional.
int end = valid_end;
if (is_unidirectional) {
int end_unid = all_sequence_length - sequence_length + (blockIdx.x % sequence_length) + 1;
if (end_unid <= valid_start) {
// In this situation, mask of [0, end_unid) and [valid_start, valid_end) has -10000, and [end_unid, valid_start) and [valid_end, all_seq_len) has -20000.
// So [0, end_unid) will also have value after softmax.
is_valid = threadIdx.x < end_unid;
} else {
end = min(valid_end, end_unid);
}
}
is_valid = is_valid || (threadIdx.x >= valid_start && threadIdx.x < end);
// e^x is represented as infinity if x is large enough, like 100.f.
// Infinity divided by Infinity is a NAN. Thus, softmax gets a NAN if one or more items are large enough.
// a math transform as below is leveraged to get a stable softmax:
// e^xi/(e^x1 + ...e^xn) = e^(xi - max) / (e^(x1 - max) + ... + e^(xn - max))
float thread_data_max = is_valid ? float(input[index]) : float(-CUDART_INF_F);
const auto max = BlockReduce(tmp_storage).Reduce(thread_data_max, cub::Max(), end);
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_exp(0.f);
if (is_valid) {
thread_data_exp = expf(float(input[index]) - max_block);
}
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), end);
// Store value of 1.0/sum.
if (threadIdx.x == 0) {
sum_reverse_block = (1.f) / sum;
}
__syncthreads();
// threadIdx.x might be larger than all_sequence_length due to alignment to 32x.
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
template <typename T, unsigned TPB>
__device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
const int sequence_length,
const int* attention_mask, // 2D or 3D attention mask
const T* input,
T* output,
const bool is_unidirectional,
const float scalar,
const int mask_dimension) {
using BlockReduce = cub::BlockReduce<float, TPB>;
__shared__ typename BlockReduce::TempStorage tmp_storage;
__shared__ float sum_reverse_block;
__shared__ float max_block;
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
float thread_data = -CUDART_INF_F;
if (threadIdx.x < all_sequence_length) {
const int batch_index = blockIdx.y;
const int sequence_index = blockIdx.x % sequence_length;
const int mask_offset = (mask_dimension == 2) ? batch_index * all_sequence_length + threadIdx.x : batch_index * sequence_length * all_sequence_length + sequence_index * all_sequence_length + threadIdx.x;
const int& mask = attention_mask[mask_offset];
float mask_value = mask > 0 ? 0.0f : -10000.0f;
if (is_unidirectional) {
int from_index = all_sequence_length - sequence_length + sequence_index; // offset of from token in all sequence length.
if (threadIdx.x > from_index) {
mask_value += -10000.0f;
}
}
thread_data = float(input[index]) * scalar + mask_value;
}
const float max = BlockReduce(tmp_storage).Reduce(thread_data, cub::Max(), all_sequence_length);
// Store max value
if (threadIdx.x == 0) {
max_block = max;
}
__syncthreads();
float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, cub::Sum(), all_sequence_length);
// Store value of 1.0/sum
if (threadIdx.x == 0) {
sum_reverse_block = (1.f) / sum;
}
__syncthreads();
if (threadIdx.x < all_sequence_length) {
output[index] = T(thread_data_exp * sum_reverse_block);
}
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const T* input, T* output, bool is_unidirectional) {
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, input, output, is_unidirectional);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxKernel(const int all_sequence_length, const int sequence_length, const T* input, T* output) {
Softmax<T, TPB>(all_sequence_length, sequence_length, all_sequence_length, 0, input, output);
}
template <typename T>
bool ComputeSoftmax(
cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const T* input, T* output, bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output, is_unidirectional);
} else if (!is_unidirectional) {
const int blockSize = 1024;
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, input, output);
} else {
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernelSmall(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output, bool is_unidirectional) {
__shared__ int start_position;
__shared__ int end_position;
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
end_position = min(all_sequence_length, mask_end[batch]);
// Attending to no word has the same effect as attending to all words. This is added to get parity with the CPU result.
if (start_position >= end_position) {
start_position = 0;
end_position = all_sequence_length;
}
}
__syncthreads();
SoftmaxSmall<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, input, output, is_unidirectional);
}
template <typename T, unsigned TPB>
__global__ void MaskedSoftmaxKernel(const int all_sequence_length, const int sequence_length, const int* mask_end, const int* mask_start, const T* input, T* output) {
__shared__ int start_position;
__shared__ int end_position;
if (threadIdx.x == 0) {
const int batch = blockIdx.y;
start_position = mask_start != nullptr ? max(0, mask_start[batch]) : 0;
end_position = min(all_sequence_length, mask_end[batch]);
// Attending to no word has the same effect as attending to all words. This is added to get parity with the CPU result.
if (start_position >= end_position) {
start_position = 0;
end_position = all_sequence_length;
}
}
__syncthreads();
Softmax<T, TPB>(all_sequence_length, sequence_length, end_position, start_position, input, output);
}
template <typename T, unsigned TPB>
__global__ void SoftmaxWithRawMaskSmallKernel(const int all_sequence_length, const int sequence_length, const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar, const int mask_dimension) {
SoftmaxWithRawMaskSmall<T, TPB>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
}
template <typename T>
bool ComputeSoftmaxWithMask1D(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* mask_index, const int* mask_start, const T* input, T* output, const bool is_unidirectional) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
MaskedSoftmaxKernelSmall<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output, is_unidirectional);
} else if (!is_unidirectional) {
const int blockSize = 1024;
MaskedSoftmaxKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, mask_index, mask_start, input, output);
} else {
ORT_THROW("Attention CUDA operator does not support unidirectional with total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T>
bool ComputeSoftmaxWithRawMask(cudaStream_t stream, const int all_sequence_length, const int sequence_length, const int batch_size, const int num_heads,
const int* attention_mask, const T* input, T* output, const bool is_unidirectional, const float scalar,
const int mask_dimension) {
const dim3 grid(sequence_length * num_heads, batch_size, 1);
if (all_sequence_length <= 32) {
const int blockSize = 32;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
} else if (all_sequence_length <= 64) {
const int blockSize = 64;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
} else if (all_sequence_length <= 128) {
const int blockSize = 128;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
} else if (all_sequence_length <= 256) {
const int blockSize = 256;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
} else if (all_sequence_length <= 512) {
const int blockSize = 512;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
} else if (all_sequence_length <= 1024) {
const int blockSize = 1024;
SoftmaxWithRawMaskSmallKernel<T, blockSize>
<<<grid, blockSize, 0, stream>>>(all_sequence_length, sequence_length, attention_mask, input, output, is_unidirectional, scalar, mask_dimension);
} else {
ORT_THROW("Attention CUDA operator does not supported 2D attention mask with total sequence length > 1024.");
}
return CUDA_CALL(cudaPeekAtLastError());
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
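
Two host-side references for the kernels in this header (illustrations only, not part of attention_softmax.h). First, the max-subtraction trick used by Softmax and SoftmaxSmall: softmax(x_i) = exp(x_i - max) / sum_j exp(x_j - max), which gives the same result as the naive formula but avoids inf/inf = NaN when some logits are large.

#include <algorithm>
#include <cmath>
#include <vector>

std::vector<float> StableSoftmax(const std::vector<float>& x) {
  const float max_value = *std::max_element(x.begin(), x.end());  // assumes x is non-empty
  std::vector<float> y(x.size());
  float sum = 0.0f;
  for (size_t i = 0; i < x.size(); i++) {
    y[i] = std::exp(x[i] - max_value);
    sum += y[i];
  }
  for (float& v : y) {
    v /= sum;  // equals exp(x_i) / sum_j exp(x_j), but exp never overflows here
  }
  return y;
}

Second, how SoftmaxWithRawMaskSmall addresses the raw mask for the two supported ranks: a 2D mask of shape (B, S*) is shared by every query position of a batch item, while a 3D mask of shape (B, S, S*) carries one row per query position; position below plays the role of threadIdx.x in the kernel.

#include <cstddef>

size_t RawMaskOffset(int mask_dimension, int batch_index, int sequence_index,
                     int position, int sequence_length, int all_sequence_length) {
  if (mask_dimension == 2) {
    return (size_t)batch_index * all_sequence_length + position;
  }
  // mask_dimension == 3
  return ((size_t)batch_index * sequence_length + sequence_index) * all_sequence_length + position;
}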


@ -0,0 +1,160 @@
/*
The implementation of this file is based on qkvToContext plugin in TensorRT demo:
https://github.com/NVIDIA/TensorRT/tree/release/5.1/demo/BERT/
Copyright 2019 NVIDIA Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "core/providers/cuda/cuda_common.h"
#include "attention_impl.h"
using namespace onnxruntime::cuda;
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T>
__global__ void TransposeCtx(const int H, const T* input, T* output) {
// Input: BxNxSxH
// Output: BxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int num_heads = blockDim.y;
int sequence_length = gridDim.x;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = s * H + n * sequence_length * H + b * NHS;
const int out_offset = n * H + s * NH + b * NHS;
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
bool LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
const dim3 block(H, num_heads, 1);
TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else {
const dim3 block(head_size, num_heads, 1);
TransposeCtx<float><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchTransCtx(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, 1);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const dim3 block(H, num_heads, 1);
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
TransposeCtx<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const dim3 block(H, num_heads, 1);
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
TransposeCtx<half2><<<grid, block, 0, stream>>>(H, input2, output2);
} else { // this should be an "odd" case. probably not worth catching it in the half2 kernel.
const dim3 block(head_size, num_heads, 1);
TransposeCtx<half><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
template <typename T>
__global__ void TransposeQKV(const int H, const T* input, T* output) {
// Input: BxSx3xNxH
// Output: 3xBxNxSxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
const int in_offset = n * H + m * NH + s * 3 * NH + b * NHS * 3;
const int out_offset = s * H + n * sequence_length * H + b * NHS + m * NHS * batch_size;
const int i = threadIdx.x;
if (i < H) {
output[out_offset + i] = input[in_offset + i];
}
}
bool LaunchTransQkv(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const float* input, float* output) {
const dim3 grid(sequence_length, batch_size, 3);
if (0 == (head_size & 1)) {
const int H = head_size / 2;
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
const dim3 block(H, num_heads, 1);
TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else {
const dim3 block(head_size, num_heads, 1);
TransposeQKV<float><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
bool LaunchTransQkv(cudaStream_t stream,
const int sequence_length, const int batch_size, const int head_size, const int num_heads,
const half* input, half* output) {
const dim3 grid(sequence_length, batch_size, 3);
if (0 == (head_size % 4)) {
const int H = head_size / 4;
const dim3 block(H, num_heads, 1);
const float2* input2 = reinterpret_cast<const float2*>(input);
float2* output2 = reinterpret_cast<float2*>(output);
TransposeQKV<float2><<<grid, block, 0, stream>>>(H, input2, output2);
} else if (0 == (head_size & 1)) {
const int H = head_size / 2;
const dim3 block(H, num_heads, 1);
const half2* input2 = reinterpret_cast<const half2*>(input);
half2* output2 = reinterpret_cast<half2*>(output);
TransposeQKV<half2><<<grid, block, 0, stream>>>(H, input2, output2);
} else { // head_size is odd; fall back to the scalar half kernel rather than adding a special case to the half2 path.
const dim3 block(head_size, num_heads, 1);
TransposeQKV<half><<<grid, block, 0, stream>>>(head_size, input, output);
}
return CUDA_CALL(cudaPeekAtLastError());
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@@ -277,7 +277,8 @@ void FusedMatMulShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
void RegisterBertSchemas() {
static const char* Attention_ver1_doc = R"DOC(
Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
The mask_index input is optional. Besides raw attention mask with shape (batch_size, past_sequence_length + sequence_length),
The mask_index input is optional. Besides the raw attention mask with shape (batch_size, past_sequence_length + sequence_length)
or (batch_size, sequence_length, past_sequence_length + sequence_length), where a value of 0 means masked and 1 means attended,
we also support two other formats: when the input has right-side padding, mask_index is one-dimensional with shape (batch_size),
where each element is the end position, i.e. the valid length of the actual sequence excluding padding. When the input has
left-side padding, mask_index has shape (2 * batch_size), where the values are the exclusive end positions followed by
@@ -297,7 +298,7 @@ and present state are optional. Present state could appear in output even when p
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
.Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
.Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
.Input(3, "mask_index", "Attention mask with shape (batch_size, past_sequence_length + sequence_length) or (batch_size, sequence_length, past_sequence_length + sequence_length), or index with shape (batch_size) or (2 * batch_size).", "M", OpSchema::Optional)
.Input(4, "past", "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size).", "T", OpSchema::Optional)
.Output(0, "output", "3D output tensor with shape (batch_size, append_length, hidden_size)", "T")
.Output(1, "present", "present state for key and value with shape (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)", "T", OpSchema::Optional)

View file

@@ -13,12 +13,13 @@ enum MaskIndexType {
kMaskIndexEnd = 0,
kMaskIndexEndAndStart,
kMaskRaw,
kMask3D,
kMaskDummy // Dummy mask with shape [1, 1] or [batch_size, 1]
};
static void RunAttentionTest(
const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size]
const std::vector<float>& weights_data, // weights: [hidden_size, 3 * hidden_size]
const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size]
const std::vector<float>& weights_data, // weights: [hidden_size, 3 * hidden_size]
bool is_weights_constant,
const std::vector<float>& bias_data, // bias: [3 * hidden_size]
const std::vector<int32_t>& mask_index_data, // mask_index: [batch_size] or [batch_size, past_sequence_length + sequence_length] or empty
@@ -52,6 +53,7 @@ static void RunAttentionTest(
std::vector<int64_t> mask_index_dims_2 = {2 * batch_size};
std::vector<int64_t> mask_index_dims_3 = {batch_size, past_sequence_length + sequence_length};
std::vector<int64_t> mask_index_dims_4 = {batch_size, 1};
std::vector<int64_t> mask_index_dims_5 = {batch_size, sequence_length, past_sequence_length + sequence_length};
std::vector<int64_t> mask_index_dims;
switch (mask_index_type) {
case kMaskIndexEnd:
@@ -66,8 +68,11 @@ static void RunAttentionTest(
case kMaskDummy:
mask_index_dims = mask_index_dims_4;
break;
case kMask3D:
mask_index_dims = mask_index_dims_5;
break;
default:
assert(0); // shall not reach here.
assert(0); // shall not reach here.
break;
}
@@ -973,6 +978,51 @@ TEST(AttentionTest, AttentionBatch2LeftPaddingMaskIndex2) {
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskIndexEndAndStart);
}
TEST(AttentionTest, Attention3DMask) {
int batch_size = 2;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.5f, 0.2f, 0.3f, -0.6f,
0.8f, -0.5f, 0.0f, 1.f,
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// Test 3D mask BxSxS*
std::vector<int32_t> mask_index_data = {
0, 1,
0, 1,
1, 1,
1, 1};
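// Interpreted as (batch_size, sequence_length, all_sequence_length) = (2, 2, 2): batch 0 uses
// mask [[0, 1], [0, 1]], so both query positions attend only to key position 1, while batch 1
// uses an all-ones mask and attends everywhere.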
std::vector<float> output_data = {
8.69f, -0.13f, 4.25f, 5.65f,
8.69f, -0.13f, 4.25f, 5.65f,
3.14959716796875f, 0.10843672603368759f, 4.25f, 5.65f,
3.9696791172027588f, 0.073143675923347473f, 4.25f, 5.65f};
bool use_float16 = false;
bool is_unidirectional = false;
bool use_past_state = false;
int past_sequence_length = 0;
const std::vector<float>* past_data = nullptr;
const std::vector<float>* present_data = nullptr;
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D);
}
TEST(AttentionTest, AttentionBatch2AttentionMask) {
int batch_size = 2;
int sequence_length = 2;
@@ -1014,6 +1064,51 @@ TEST(AttentionTest, AttentionBatch2AttentionMask) {
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
}
TEST(AttentionTest, AttentionUnidirectional3DMask) {
int batch_size = 2;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.5f, 0.2f, 0.3f, -0.6f,
0.8f, -0.5f, 0.0f, 1.f,
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// Test 3D mask BxSxS*
std::vector<int32_t> mask_index_data = {
0, 1,
0, 1,
1, 1,
1, 1};
std::vector<float> output_data = {
3.967245340f, 0.07324841f, 4.25f, 5.65f,
8.69f, -0.13f, 4.25f, 5.65f,
8.69f, -0.13f, 4.25f, 5.65f,
3.96967912f, 0.07314367f, 4.25f, 5.65f};
bool use_float16 = false;
bool is_unidirectional = true;
bool use_past_state = false;
int past_sequence_length = 0;
const std::vector<float>* past_data = nullptr;
const std::vector<float>* present_data = nullptr;
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D);
}
TEST(AttentionTest, AttentionUnidirectionalAttentionMask) {
int batch_size = 2;
int sequence_length = 2;
@@ -1181,6 +1276,47 @@ TEST(AttentionTest, AttentionMask2DNoWord) {
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMaskRaw);
}
TEST(AttentionTest, AttentionMask3DNoWord) {
int batch_size = 2;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.5f, 0.2f, 0.3f, -0.6f,
0.8f, -0.5f, 0.0f, 1.f,
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// Test a 3D mask in which every attention mask value is zero.
std::vector<int32_t> mask_index_data = {0, 0, 0, 0, 0, 0, 0, 0};
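// With every mask value zero, all key positions are masked; the additive bias is then
// uniform across each row, so attention degenerates to an even average over all positions.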
std::vector<float> output_data = {
3.96724534f, 0.07324841f, 4.25f, 5.65f,
3.14984703f, 0.10842596f, 4.25f, 5.65f,
3.14984703f, 0.10842596f, 4.25f, 5.65f,
3.96724534f, 0.07324841f, 4.25f, 5.65f};
bool use_float16 = false;
bool is_unidirectional = false;
bool use_past_state = false;
int past_sequence_length = 0;
const std::vector<float>* past_data = nullptr;
const std::vector<float>* present_data = nullptr;
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length, past_data, present_data, kMask3D);
}
TEST(AttentionTest, AttentionDummyMask2D) {
int batch_size = 2;
@@ -1294,4 +1430,4 @@ TEST(AttentionTest, AttentionPastState_dynamic) {
}
} // namespace test
} // namespace onnxruntime
} // namespace onnxruntime