mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-18 21:21:17 +00:00
Change mask_index input of Attention op to be optional (#3459)
Change Mask Index to optional
This commit is contained in:
parent
7f6e407e09
commit
54bbbb78ae
6 changed files with 133 additions and 46 deletions
|
|
@ -35,7 +35,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
|
|||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - mask_index : (batch_size)
|
||||
// Input 3 - mask_index : (batch_size) if present (optional input)
|
||||
// Output : (batch_size, sequence_length, hidden_size)
|
||||
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
|
|
@ -77,13 +77,15 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
const auto mask_dims = mask_index->Shape().GetDims();
|
||||
if (mask_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ",
|
||||
mask_dims.size());
|
||||
}
|
||||
if (static_cast<int>(mask_dims[0]) != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0");
|
||||
if (mask_index != nullptr) { // mask_index is optional
|
||||
const auto mask_dims = mask_index->Shape().GetDims();
|
||||
if (mask_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 is expected to have 1 dimension, got ",
|
||||
mask_dims.size());
|
||||
}
|
||||
if (static_cast<int>(mask_dims[0]) != batch_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Inputs 3 and 0 shall have same length at dimension 0");
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
@ -179,22 +181,36 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
|
||||
// STEP.2: scratch(B, N, S, S) = 1/sqrt(H) x Q(B, N, S, H) x K'(B, N, S, H -> B, N, H, S) + 1 x mask_index(B -> B, 1,
|
||||
// 1, 1)
|
||||
auto scratch_data =
|
||||
allocator->Alloc(SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * sequence_length * element_size);
|
||||
size_t scratch_data_bytes = SafeInt<size_t>(batch_size) * num_heads_ * sequence_length * sequence_length * element_size;
|
||||
auto scratch_data = allocator->Alloc(scratch_data_bytes);
|
||||
BufferUniquePtr scratch_buffer(scratch_data, BufferDeleter(allocator));
|
||||
|
||||
{
|
||||
auto scratch_broadcast_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * element_size);
|
||||
BufferUniquePtr scratch_broadcast_buffer(scratch_broadcast_data, BufferDeleter(allocator));
|
||||
memset(scratch_broadcast_data, 0, batch_size * sequence_length * element_size);
|
||||
T* p_scratch_broadcast_current_data = reinterpret_cast<T*>(scratch_broadcast_data);
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
// TODO: mask_index can be used in softmax to save some calculation.
|
||||
int mask = mask_index->template Data<int32_t>()[b_i];
|
||||
for (int m_i = mask; m_i < sequence_length; m_i++) {
|
||||
p_scratch_broadcast_current_data[m_i] = static_cast<T>(-10000.0);
|
||||
size_t mask_data_bytes = 0;
|
||||
if (mask_index != nullptr) {
|
||||
mask_data_bytes = SafeInt<size_t>(batch_size) * sequence_length * element_size;
|
||||
}
|
||||
|
||||
void* mask_data = nullptr;
|
||||
if (mask_data_bytes > 0) {
|
||||
mask_data = allocator->Alloc(mask_data_bytes);
|
||||
memset(mask_data, 0, mask_data_bytes);
|
||||
}
|
||||
BufferUniquePtr mask_data_buffer(mask_data, BufferDeleter(allocator));
|
||||
|
||||
if (mask_index != nullptr) {
|
||||
T* p_mask = reinterpret_cast<T*>(mask_data);
|
||||
for (int b_i = 0; b_i < batch_size; b_i++) {
|
||||
// TODO: mask_index can be used in softmax to save some calculation.
|
||||
// Convert mask_index to mask (-10000 means out of range, which will be 0 after softmax): B => BxS
|
||||
int valid_length = mask_index->template Data<int32_t>()[b_i];
|
||||
for (int m_i = valid_length; m_i < sequence_length; m_i++) {
|
||||
p_mask[m_i] = static_cast<T>(-10000.0);
|
||||
}
|
||||
p_mask += sequence_length;
|
||||
}
|
||||
p_scratch_broadcast_current_data += sequence_length;
|
||||
} else {
|
||||
memset(scratch_data, 0, scratch_data_bytes);
|
||||
}
|
||||
|
||||
const int loop_len = batch_size * num_heads_;
|
||||
|
|
@ -206,12 +222,15 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
ThreadPool::TryParallelFor(tp, loop_len, cost, [&](std::ptrdiff_t begin, std::ptrdiff_t end) {
|
||||
for (std::ptrdiff_t i = begin; i != end; ++i) {
|
||||
const std::ptrdiff_t batch_index = i / num_heads_;
|
||||
|
||||
// broadcast masks (B) -> (B.N.)S.S
|
||||
const T* broadcast_data_src = reinterpret_cast<T*>(scratch_broadcast_data) + batch_index * sequence_length;
|
||||
T* broadcast_data_dest = reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i;
|
||||
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
|
||||
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T));
|
||||
broadcast_data_dest += sequence_length;
|
||||
if (mask_index != nullptr) {
|
||||
const T* broadcast_data_src = reinterpret_cast<T*>(mask_data) + batch_index * sequence_length;
|
||||
T* broadcast_data_dest = reinterpret_cast<T*>(scratch_data) + sequence_length * sequence_length * i;
|
||||
for (int seq_index = 0; seq_index < sequence_length; seq_index++) {
|
||||
memcpy(broadcast_data_dest, broadcast_data_src, sequence_length * sizeof(T));
|
||||
broadcast_data_dest += sequence_length;
|
||||
}
|
||||
}
|
||||
|
||||
// gemm
|
||||
|
|
|
|||
|
|
@ -40,7 +40,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - mask_index : (batch_size)
|
||||
// Input 3 - mask_index : (batch_size) if present (optional input)
|
||||
// Output : (batch_size, sequence_length, hidden_size)
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
|
|
@ -88,7 +88,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
|
||||
if (!LaunchAttentionKernel(
|
||||
reinterpret_cast<const CudaT*>(gemm_buffer.get()),
|
||||
mask_index->template Data<int>(),
|
||||
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
|
||||
output->template MutableData<T>(),
|
||||
batch_size,
|
||||
sequence_length,
|
||||
|
|
|
|||
|
|
@ -152,6 +152,38 @@ __device__ inline void SoftmaxSmall(const int ld, const int num_valid, const T*
|
|||
}
|
||||
}
|
||||
|
||||
// Kernel entry point for softmax without a mask_index, built on the SoftmaxSmall device routine.
// sequence_length is passed for both of SoftmaxSmall's leading arguments (ld and num_valid),
// so presumably every element of a row is treated as valid.
// TPB is the compile-time template argument forwarded to SoftmaxSmall.
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxKernelSmall(const int sequence_length, const T* input, T* output) {
|
||||
SoftmaxSmall<T, TPB>(sequence_length, sequence_length, input, output);
|
||||
}
|
||||
|
||||
// Kernel entry point for softmax without a mask_index, built on the Softmax device routine.
// sequence_length is passed for both of Softmax's leading integer arguments.
// NOTE(review): Softmax's parameter names are not visible in this excerpt; presumably they are
// (ld, num_valid) like SoftmaxSmall - confirm against its definition.
template <typename T, unsigned TPB>
|
||||
__global__ void SoftmaxKernel(const int sequence_length, const T* input, T* output) {
|
||||
Softmax<T, TPB>(sequence_length, sequence_length, input, output);
|
||||
}
|
||||
|
||||
// Launches the unmasked softmax kernels on `stream` (the path used when no mask_index is given).
//
// stream          - CUDA stream the kernel is enqueued on.
// sequence_length - passed to the kernel and used to pick the kernel variant / block size.
// batch_size      - y-dimension of the launch grid.
// num_heads       - together with sequence_length forms the x-dimension of the launch grid.
// input / output  - device buffers handed straight to the kernel.
//                   NOTE(review): presumably BxNxSxS attention scores, one thread block per row - confirm.
//
// Returns the value of CUDA_CALL(cudaPeekAtLastError()), i.e. whether an error is pending after the launch.
template <typename T>
|
||||
bool ComputeSoftmax(
|
||||
cudaStream_t stream, const int sequence_length, const int batch_size, const int num_heads,
|
||||
const T* input, T* output) {
|
||||
// Grid: (sequence_length * num_heads) x batch_size blocks.
const dim3 grid(sequence_length * num_heads, batch_size, 1);
|
||||
// The block size is a template argument, so each supported size needs its own instantiation.
if (sequence_length <= 32) {
|
||||
const int blockSize = 32;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
|
||||
} else if (sequence_length <= 128) {
|
||||
const int blockSize = 128;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
|
||||
} else if (sequence_length == 384) {
|
||||
// Exactly 384 is special-cased to keep using the "small" kernel with a matching block size.
const int blockSize = 384;
|
||||
SoftmaxKernelSmall<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
|
||||
} else {
|
||||
// Every other length (129..383 and above 384) falls back to the general kernel with 256 threads.
const int blockSize = 256;
|
||||
SoftmaxKernel<T, blockSize><<<grid, blockSize, 0, stream>>>(sequence_length, input, output);
|
||||
}
|
||||
|
||||
// Kernel launches are asynchronous; this only surfaces an error that is already pending.
return CUDA_CALL(cudaPeekAtLastError());
|
||||
}
|
||||
|
||||
template <typename T, unsigned TPB>
|
||||
__global__ void MaskedSoftmaxKernelSmall(const int sequence_length, const int* mask_index, const T* input, T* output) {
|
||||
__shared__ int num_valid;
|
||||
|
|
@ -390,8 +422,14 @@ bool QkvToContext(
|
|||
}
|
||||
|
||||
// apply softmax and store result P to scratch2: BxNxSxS
|
||||
if (!ComputeMaskedSoftmax<T>(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) {
|
||||
return false;
|
||||
if (nullptr != mask_index) {
|
||||
if (!ComputeMaskedSoftmax<T>(stream, sequence_length, batch_size, num_heads, mask_index, scratch1, scratch2)) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
if (!ComputeSoftmax<T>(stream, sequence_length, batch_size, num_heads, scratch1, scratch2)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// compute P*V (as V*P), and store in scratch3: BxNxSxH
|
||||
|
|
|
|||
|
|
@ -7,20 +7,20 @@
|
|||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
// NOTE(review): presumably returns the size of the temporary `workspace` buffer that
// LaunchAttentionKernel needs for the given dimensions - confirm against the definition.
size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length);
|
||||
size_t GetAttentionWorkspaceSize(size_t element_size, int batchsize, int num_heads, int head_size, int sequence_length);
|
||||
|
||||
bool LaunchAttentionKernel(
|
||||
const void* input, // Input tensor
|
||||
const int* mask_index, // Mask index where each element is length of a sequence
|
||||
void* output, // Output tensor
|
||||
int batch_size, // Batch size (B)
|
||||
int sequence_length, // Sequence length (S)
|
||||
int num_heads, // Number of attention heads (N)
|
||||
int head_size, // Hidden layer size per head (H)
|
||||
void* workspace, // Temporary buffer
|
||||
cublasHandle_t& cublas, // Cublas handle
|
||||
const size_t element_size // Element size of input tensor
|
||||
);
|
||||
// Runs the attention computation for one call of the Attention op.
// mask_index is optional: pass NULL to compute attention without a mask.
// NOTE(review): the bool result is presumably true on success - the caller treats a false
// return as failure (`if (!LaunchAttentionKernel(...))`); confirm against the definition.
bool LaunchAttentionKernel(
|
||||
const void* input, // Input tensor
|
||||
const int* mask_index, // Mask index (length of each sequence). NULL means no mask.
|
||||
void* output, // Output tensor
|
||||
int batch_size, // Batch size (B)
|
||||
int sequence_length, // Sequence length (S)
|
||||
int num_heads, // Number of attention heads (N)
|
||||
int head_size, // Hidden layer size per head (H)
|
||||
void* workspace, // Temporary buffer
|
||||
cublasHandle_t& cublas, // Cublas handle
|
||||
const size_t element_size // Element size of input tensor
|
||||
);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -299,7 +299,7 @@ void RegisterBertSchemas() {
|
|||
.Input(0, "input", "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", "T")
|
||||
.Input(1, "weight", "2D input tensor with shape (hidden_size, 3 * hidden_size)", "T")
|
||||
.Input(2, "bias", "1D input tensor with shape (3 * hidden_size)", "T")
|
||||
.Input(3, "mask_index", "Attention mask index with shape (batch_size)", "M")
|
||||
.Input(3, "mask_index", "Attention mask index with shape (batch_size).", "M", OpSchema::Optional)
|
||||
.Output(0, "output", "3D output tensor with shape (batch_size, sequence_length, hidden_size)", "T")
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
|
||||
|
|
@ -2325,7 +2325,6 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
|
|||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
RegisterBertSchemas();
|
||||
|
||||
}
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -39,16 +39,18 @@ static void RunAttentionTest(
|
|||
tester.AddInput<MLFloat16>("input", input_dims, ToFloat16(input_data));
|
||||
tester.AddInput<MLFloat16>("weight", weights_dims, ToFloat16(weights_data));
|
||||
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias_data));
|
||||
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
|
||||
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
|
||||
} else {
|
||||
tester.AddInput<float>("input", input_dims, input_data);
|
||||
tester.AddInput<float>("weight", weights_dims, weights_data);
|
||||
tester.AddInput<float>("bias", bias_dims, bias_data);
|
||||
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
|
||||
tester.AddOutput<float>("output", output_dims, output_data);
|
||||
}
|
||||
|
||||
if (mask_index_data.size() > 0) { // mask index is optional.
|
||||
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
|
||||
}
|
||||
|
||||
tester.Run();
|
||||
}
|
||||
}
|
||||
|
|
@ -204,5 +206,34 @@ TEST(AttentionTest, AttentionMaskExceedSequence) {
|
|||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
|
||||
// Checks the Attention op when the optional mask_index input is omitted:
// batch_size 1, sequence_length 2, hidden_size 4, 2 heads.
TEST(AttentionTest, AttentionNoMaskIndex) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
// Input: batch_size x sequence_length x hidden_size = 8 values.
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
// Weights: hidden_size x (3 * hidden_size) = 4 rows of 12 values.
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
// Bias: 3 * hidden_size = 12 values.
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
// No mask_index: RunAttentionTest only adds the mask_index input when this vector is non-empty.
|
||||
std::vector<int32_t> mask_index_data = {};
|
||||
|
||||
// Expected output: batch_size x sequence_length x hidden_size = 8 values.
std::vector<float> output_data = {
|
||||
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
|
||||
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue