From a8df6c35f84cd8136000df1d80e504e055ce2363 Mon Sep 17 00:00:00 2001 From: Zhang Lei Date: Tue, 17 Jan 2023 16:45:29 -0800 Subject: [PATCH] Support flash attention on 2d attention mask for gpt2 left padding. (#14215) --- .../contrib_ops/cuda/bert/attention.cc | 6 ++- .../contrib_ops/cuda/bert/attention_impl.cu | 6 ++- .../contrib_ops/cuda/bert/bert_padding.cu | 45 +++++++++++++++++++ .../contrib_ops/cuda/bert/bert_padding.h | 7 +++ 4 files changed, 61 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index bba941f1e4..4efa5b611a 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -91,11 +91,13 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN; if (is_unidirectional_) { // GPT - // Fused kernels requires left side padding (The mask shall be sequence lengths or no mask) + // GPT fused kernels requires left side padding. mask can be: + // none (no padding), 1D sequence lengths or 2d mask. // Fused kernels don't support different sequence lengths of q and kv, so only apply to the first token // where past state is empty. 
+ bool is_mask_2d_key_padding = parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING; bool use_causal_fused_runner = !disable_fused_runner_ && - (nullptr == mask_index || is_mask_1d_seq_len) && + (nullptr == mask_index || is_mask_1d_seq_len || is_mask_2d_key_padding) && nullptr == extra_add_qk && parameters.past_sequence_length == 0 && parameters.hidden_size == parameters.v_hidden_size && diff --git a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu index 0818feedbb..848e5f5a84 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/bert/attention_impl.cu @@ -324,7 +324,11 @@ Status QkvToContext( if (use_fused_kernel || use_fused_causal) { int* sequence_offset = reinterpret_cast(scratch1); - LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + if (parameters.mask_type == AttentionMaskType::MASK_2D_KEY_PADDING) { + LaunchTrtSequenceOffset2d(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } else { + LaunchTrtSequenceOffset(sequence_offset, data.mask_index, batch_size, sequence_length, stream); + } CUDA_RETURN_IF_ERROR(cudaGetLastError()); FusedMHARunnerFP16v2* fused_fp16_runner = reinterpret_cast(fused_runner); diff --git a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu index 67ceb4b72e..2af748d8d4 100644 --- a/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu +++ b/onnxruntime/contrib_ops/cuda/bert/bert_padding.cu @@ -20,6 +20,7 @@ #include "core/providers/cuda/cuda_common.h" #include "contrib_ops/cuda/bert/bert_padding.h" +#include <cub/cub.cuh> using namespace onnxruntime::cuda; @@ -361,6 +362,71 @@ void LaunchTrtSequenceOffset(int* trt_mha_padding_offset, } }
+// Computes, for every row of a 2D attention mask, where the leading run of
+// identical mask values ends, and stores the resulting offsets.
+//
+// Launch layout: one block per batch row (blockIdx.x selects the row, and the
+// block with blockIdx.x == gridDim.x - 1 also writes the closing entry).
+// blockDim.x must not exceed kMAX_THREADS_PER_BLOCK.
+//
+// Output for row b, with batch_offset = b * sequence_length:
+//   trt_mha_padding_offset[2 * b]     = batch_offset
+//   trt_mha_padding_offset[2 * b + 1] = batch_offset + length of the leading run
+//   trt_mha_padding_offset[2 * b + 2] = batch_offset + sequence_length (last row only)
+//
+// attention_masks is read at [b * sequence_length + position]; any non-zero value
+// counts as set. batch_size is not read by the kernel.
+__global__ void __launch_bounds__(kMAX_THREADS_PER_BLOCK)
+    getTrtSequenceOffset2d(int* trt_mha_padding_offset,
+                           const int* attention_masks,
+                           const int batch_size,
+                           const int sequence_length) {
+  typedef cub::BlockReduce<int, kMAX_THREADS_PER_BLOCK> BlockReduce;
+  __shared__ typename BlockReduce::TempStorage temp_storage;
+
+  const int batch_id = blockIdx.x;
+  const int batch_offset = batch_id * sequence_length;
+  const int* batch_mask = attention_masks + batch_offset;
+
+  // Guarded so that an empty row is never dereferenced.
+  const bool leftmost_non_zero = (sequence_length > 0) && (batch_mask[0] != 0);
+
+  // Each thread looks for the first of its strided positions whose value differs
+  // from the leftmost one. A thread visits increasing positions, so its first hit
+  // is its minimum and it can stop scanning there.
+  int first_mismatch = sequence_length;
+  for (int i = threadIdx.x; i < sequence_length; i += blockDim.x) {
+    if (leftmost_non_zero != (batch_mask[i] != 0)) {
+      first_mismatch = i;
+      break;
+    }
+  }
+
+  // The block-wide minimum is the length of the leading run; values equal to the
+  // leftmost one that reappear later in the row do not extend it. Every thread
+  // must reach the reduction, and CUB only defines the result in thread 0.
+  const int leading_length = BlockReduce(temp_storage).Reduce(first_mismatch, cub::Min(), blockDim.x);
+
+  if (threadIdx.x == 0) {
+    trt_mha_padding_offset[2 * batch_id] = batch_offset;
+    trt_mha_padding_offset[2 * batch_id + 1] = batch_offset + leading_length;
+    if (batch_id == gridDim.x - 1) {
+      trt_mha_padding_offset[2 * batch_id + 2] = batch_offset + sequence_length;
+    }
+  }
+}
+
+// Only supports simple left padding (mask 0s in the leading positions)
+// or simple right padding (mask 1s in the leading positions).
+void LaunchTrtSequenceOffset2d(int* trt_mha_padding_offset,
+                               const int* attention_masks,
+                               const int batch_size,
+                               const int sequence_length,
+                               cudaStream_t stream) {  // enqueues on `stream`; launch status is not checked here
+  getTrtSequenceOffset2d<<<batch_size, kMAX_THREADS_PER_BLOCK, 0, stream>>>(  // one block per batch row
+      trt_mha_padding_offset, attention_masks, batch_size, sequence_length);
+}
+ } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/bert/bert_padding.h b/onnxruntime/contrib_ops/cuda/bert/bert_padding.h index 4e5d30b3d3..331caefeea 100644 --- a/onnxruntime/contrib_ops/cuda/bert/bert_padding.h +++ b/onnxruntime/contrib_ops/cuda/bert/bert_padding.h @@ -43,6 +43,13 @@ void LaunchTrtSequenceOffset(int* trt_mha_padding_offset, const int batch_size, const int sequence_length, cudaStream_t stream); + +void LaunchTrtSequenceOffset2d(int* trt_mha_padding_offset, + const int* mask_index, + const int batch_size, + const int sequence_length, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime