mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-04 04:07:22 +00:00
Support flash attention on 2d attention mask for gpt2 left padding. (#14215)
This commit is contained in:
parent
30b9f5dde1
commit
a8df6c35f8
4 changed files with 61 additions and 3 deletions
|
|
@ -91,11 +91,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
|
||||
if (is_unidirectional_) { // GPT
|
||||
// Fused kernels requires left side padding (The mask shall be sequence lengths or no mask)
|
||||
// GPT fused kernels require left side padding. The mask can be:
|
||||
// none (no padding), 1D sequence lengths or 2d mask.
|
||||
// Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token
|
||||
// where past state is empty.
|
||||
bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING;
|
||||
bool use_causal_fused_runner = !disable_fused_runner_ &&
|
||||
(nullptr == mask_index || is_mask_1d_seq_len) &&
|
||||
(nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) &&
|
||||
nullptr == extra_add_qk &&
|
||||
parameters.past_sequence_length == 0 &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
|
|
|
|||
|
|
@ -324,7 +324,11 @@ Status QkvToContext(
|
|||
|
||||
if (use_fused_kernel || use_fused_causal) {
|
||||
int* sequence_offset = reinterpret_cast<int*>(scratch1);
|
||||
LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
|
||||
if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) {
|
||||
LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
|
||||
} else {
|
||||
LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream);
|
||||
}
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
||||
FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast<FusedMHARunnerFP16v2*>(fused_runner);
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@
|
|||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cuda/bert/bert_padding.h"
|
||||
#include <cub/cub.cuh>
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
|
|
@ -361,6 +362,50 @@ void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
|
|||
}
|
||||
}
|
||||
|
||||
// Derives the sequence offsets consumed by the TensorRT fused attention kernels from a
// 2D attention mask.
//
// Launch layout: one block per batch entry (blockIdx.x is the batch index). The block may
// have at most kMAX_THREADS_PER_BLOCK threads, because that is the width the BlockReduce
// below is instantiated with. Shared memory: only the static cub::BlockReduce temp storage.
//
// Inputs:
//   attention_masks  - mask values, read at index batch_id * sequence_length + i.
//   batch_size       - not read by the kernel; the batch count is taken from gridDim.x.
//   sequence_length  - number of mask entries per batch row.
//
// Output, for every batch entry b (written by thread 0 of block b):
//   trt_mha_padding_offset[2 * b]     = b * sequence_length
//   trt_mha_padding_offset[2 * b + 1] = b * sequence_length + L
// where L is the length of the leading run of mask entries that share the zero/non-zero
// state of the row's first entry. The last block additionally writes
//   trt_mha_padding_offset[2 * b + 2] = (b + 1) * sequence_length.
__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
getTrtSequenceOffset2d(int* trt_mha_padding_offset,
                       const int* attention_masks,
                       const int batch_size,
                       const int sequence_length) {
  typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
  __shared__ typename BlockReduce::TempStorage temp_storage;

  const int batch_id = blockIdx.x;
  const int* batch_mask = attention_masks + (batch_id * sequence_length);

  // Guard the read so that an empty row never dereferences batch_mask[0].
  const bool leftmost_non_zero = (sequence_length > 0) && (batch_mask[0] != 0);

  // Each thread walks its own strided slice of the row and records the first index whose
  // zero/non-zero state differs from the leftmost entry. sequence_length means "no
  // mismatch seen", so a row without any transition yields a run covering the whole row.
  int first_mismatch = sequence_length;
  for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
    if (leftmost_non_zero != (batch_mask[i] != 0)) {
      first_mismatch = i;
      break;
    }
  }

  // The smallest mismatching index over the block is exactly the leading run length.
  // Taking the minimum (rather than the maximum matching index) stays correct even when
  // the leftmost value shows up again after the first transition.
  // The reduced value is only valid in thread 0, which is the only thread that uses it.
  int leading_run_length = BlockReduce(temp_storage).Reduce(first_mismatch, cub::Min(), blockDim.x);

  if (threadIdx.x == 0) {
    int batch_offset = batch_id * sequence_length;
    trt_mha_padding_offset[2 * batch_id] = batch_offset;
    trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + leading_run_length;
    // Only the last block writes the closing offset that terminates the offset list.
    if (batch_id == gridDim.x - 1) {
      trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
    }
  }
}
|
||||
|
||||
// Enqueues getTrtSequenceOffset2d on `stream`, using one block per batch entry and
// kMAX_THREADS_PER_BLOCK threads per block, with no dynamic shared memory.
//
// Only two mask layouts are supported:
//   - simple left padding: each row starts with its mask 0s;
//   - simple right padding: each row starts with its mask 1s.
//
// The launch status is not checked here.
void LaunchTrtSequenceOffset2d(int* trt_mha_padding_offset,
                               const int* attention_masks,
                               const int batch_size,
                               const int sequence_length,
                               cudaStream_t stream) {
  const dim3 grid(batch_size);
  const dim3 block(kMAX_THREADS_PER_BLOCK);
  getTrtSequenceOffset2d<<<grid, block, 0, stream>>>(trt_mha_padding_offset,
                                                     attention_masks,
                                                     batch_size,
                                                     sequence_length);
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -43,6 +43,13 @@ void LaunchTrtSequenceOffset(int* trt_mha_padding_offset,
|
|||
const int batch_size,
|
||||
const int sequence_length,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Launches, on `stream`, the kernel that builds the TensorRT fused attention sequence
// offsets from a 2D attention mask (batch_size rows of sequence_length entries).
// Only simple left padding (rows start with mask 0s) or simple right padding (rows start
// with mask 1s) is supported.
// trt_mha_padding_offset receives 2 * batch_size + 1 integers.
void LaunchTrtSequenceOffset2d(int* trt_mha_padding_offset,
                               const int* attention_masks,
                               const int batch_size,
                               const int sequence_length,
                               cudaStream_t stream);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
Loading…
Reference in a new issue