Support flash attention on 2d attention mask for gpt2 left padding. (#14215)

This commit is contained in:
Zhang Lei 2023-01-17 16:45:29 -08:00 committed by GitHub
parent 30b9f5dde1
commit a8df6c35f8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 61 additions and 3 deletions

View file

@ -91,11 +91,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
if (is_unidirectional_) { // GPT
// Fused kernels requires left side padding (The mask shall be sequence lengths or no mask)
// GPT fused kernels requires left side padding. mask can be:
// none (no padding), 1D sequence lengths or 2d mask.
// Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
// where past state is empty.
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
bool use_causal_fused_runner = !disable_fused_runner_ &&
(nullptr == mask_index || is_mask_1d_seq_len) &&
(nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
nullptr == extra_add_qk &&
parameters.past_sequence_length == 0 &&
parameters.hidden_size == parameters.v_hidden_size &&

View file

@ -324,7 +324,11 @@ Status QkvToContext(
if (use_fused_kernel || use_fused_causal) {
int* sequence_offset = reinterpret_cast<int*>(scratch1);
LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
} else {
LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);

View file

@ -20,6 +20,7 @@
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include <cub/cub.cuh>
using namespace onnxruntime::cuda;
@ -361,6 +362,50 @@ void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
}
}
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
getTrtSequenceOffset2d(int* trt_mha_padding_offset,
const int* attention_masks,
const int batch_size,
const int sequence_length) {
typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
__shared__ typename BlockReduce::TempStorage temp_storage;
const int batch_id = blockIdx.x;
const int* batch_mask = attention_masks + (batch_id * sequence_length);
const bool leftmost_non_zero = (batch_mask[0] != 0);
int biggest_position = 0;
for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
if (leftmost_non_zero == (batch_mask[i] != 0)) {
biggest_position = i;
} else {
break;
}
}
int last_leading_position = BlockReduce(temp_storage).Reduce(biggest_position, cub::Max(), blockDim.x);
if (threadIdx.x == 0) {
int batch_offset = batch_id * sequence_length;
trt_mha_padding_offset[2 * batch_id] = batch_offset;
trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + last_leading_position + 1;
if (batch_id == gridDim.x - 1) {
trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
}
}
}
// only support simple left padding with mask 0s on leading left,
// or simple right padding with mask 1s on leading left.
void LaunchTrtSequenceOffset2d(int* trt_mha_padding_offset,
const int* attention_masks,
const int batch_size,
const int sequence_length,
cudaStream_t stream) {
getTrtSequenceOffset2d<<<batch_size, kMAX_THREADS_PER_BLOCK, 0, stream>>>(
trt_mha_padding_offset, attention_masks, batch_size, sequence_length);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -43,6 +43,13 @@ void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
const int batch_size,
const int sequence_length,
cudaStream_t stream);
void LaunchTrtSequenceOffset2d(int* trt_mha_padding_offset,
const int* mask_index,
const int batch_size,
const int sequence_length,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime