From ed7ab1660dc59a91d33ccf0cc71da2a672469ddd Mon Sep 17 00:00:00 2001 From: Hariharan Seshadri Date: Wed, 15 Mar 2023 17:16:32 -0700 Subject: [PATCH] [CUDA] Add option to use DecoderMaskedMultiheadAttention in BeamSearch (#14990) --- docs/ContribOperators.md | 16 +- docs/OperatorKernels.md | 2 +- .../cpu/transformers/beam_search.cc | 12 +- .../cpu/transformers/beam_search.h | 6 + .../cpu/transformers/beam_search_impl_base.h | 45 +++-- .../cpu/transformers/beam_search_impl_gpt.h | 112 ++++++++++-- .../cpu/transformers/beam_search_impl_t5.h | 3 + .../cpu/transformers/beam_search_scorer.h | 2 +- .../transformers/generation_device_helper.cc | 49 ++--- .../transformers/generation_device_helper.h | 14 +- .../cpu/transformers/generation_shared.h | 7 + .../transformers/greedy_search_impl_base.h | 7 +- .../cpu/transformers/greedy_search_impl_gpt.h | 7 +- .../cpu/transformers/subgraph_base.cc | 3 + .../cpu/transformers/subgraph_gpt.cc | 45 ++++- .../cpu/transformers/subgraph_gpt.h | 9 +- .../decoder_masked_multihead_attention.cc | 50 ++++-- ...decoder_masked_multihead_attention_impl.cu | 38 ++-- .../cuda/transformers/beam_search.cc | 8 +- .../cuda/transformers/generation_cuda_impl.cu | 63 ++++++- .../cuda/transformers/generation_cuda_impl.h | 24 ++- .../transformers/generation_device_helper.cc | 86 +++++++-- .../transformers/generation_device_helper.h | 7 +- .../core/graph/contrib_ops/bert_defs.cc | 25 ++- .../tools/transformers/convert_generation.py | 167 +++++++++++++----- .../python/transformers/test_generation.py | 12 ++ 26 files changed, 645 insertions(+), 174 deletions(-) diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 5862716638..66afbebae5 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -1133,7 +1133,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs (3 - 7) +#### Inputs (7 - 9)
input : T
@@ -1144,20 +1144,24 @@ This version of the operator has been available since version 1 of the 'com.micr
Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection
mask_index (optional) : M
Mask values of shape (batch_size, total_sequence_length)
-
past (optional) : T
-
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
+
past : T
+
past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x), where `x = 16 / sizeof(T)`.&#13;
relative_position_bias (optional) : T
additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
-
past_sequence_length (optional) : M
+
past_sequence_length : M
When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
+
beam_width (optional) : M
+
The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.&#13;
+
cache_indirection (optional) : M
+
A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies which beam the 'k'-th token came from for the 'j'-th beam of batch 'i' in the current iteration&#13; -#### Outputs (1 - 2) +#### Outputs&#13;
-#### Outputs (1 - 2) +#### Outputs
output : T
3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
-
present (optional) : T
+
present : T
past state for key and value with shape (2, batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size), where effective_seq_length = (past_sequence_length + kv_sequence_length).&#13;
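The re-ordered key layout called out in the `past` input description above is easiest to follow as plain index arithmetic. Below is a host-side reference sketch of that mapping (not the ORT CUDA kernel): the key cache of shape (batch_size, num_heads, max_sequence_length, head_size) is viewed as (..., head_size / x, x) and transposed so that the innermost `x` elements form one 16-byte vector, with `x = 16 / sizeof(T)` (8 for fp16, 4 for fp32).

```cpp
#include <cstddef>

// Reference reorder of one K cache: (B, H, S_max, D) viewed as (B, H, S_max, D/x, x)
// becomes (B, H, D/x, S_max, x), where x = 16 / sizeof(T).
template <typename T>
void ReorderKCacheReference(const T* src, T* dst, int B, int H, int S_max, int D) {
  const int x = 16 / static_cast<int>(sizeof(T));
  for (int b = 0; b < B; ++b)
    for (int h = 0; h < H; ++h)
      for (int s = 0; s < S_max; ++s)
        for (int c = 0; c < D / x; ++c)    // 16-byte chunk index along head_size
          for (int e = 0; e < x; ++e) {    // element within the chunk
            size_t src_idx = ((static_cast<size_t>(b) * H + h) * S_max + s) * D + c * x + e;
            size_t dst_idx = (((static_cast<size_t>(b) * H + h) * (D / x) + c) * S_max + s) * x + e;
            dst[dst_idx] = src[src_idx];
          }
}
```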
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index 1e687e1cb4..3bd890e5b4 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -798,7 +798,7 @@ Do not modify directly.* |ComplexMulConj|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float), tensor(float16)| |ConvTransposeWithDynamicPads|*in* X:**T**
*in* W:**T**
*in* Pads:**tensor(int64)**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |DecoderAttention|*in* query:**T**
*in* key:**T**
*in* q_weight:**T**
*in* kv_weight:**T**
*in* bias:**T**
*in* key_padding_mask:**B**
*in* key_cache:**T**
*in* value_cache:**T**
*in* static_kv:**B**
*in* use_past:**B**
*in* has_layer_state:**B**
*in* has_key_padding_mask:**B**
*out* output:**T**
*out* new_key_cache:**T**
*out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)| -|DecoderMaskedMultiheadAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| +|DecoderMaskedMultiheadAttention|*in* input:**T**
*in* weights:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* relative_position_bias:**T**
*in* past_sequence_length:**M**
*in* beam_width:**M**
*in* cache_indirection:**M**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float), tensor(float16)| |DequantizeLinear|*in* x:**T1**
*in* x_scale:**T2**
*in* x_zero_point:**T1**
*out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)
**T2** = tensor(float16)| |DequantizeWithOrder|*in* input:**Q**
*in* scale_input:**S**
*out* output:**F**|1+|**F** = tensor(float), tensor(float16)
**Q** = tensor(int8)
**S** = tensor(float)| |EmbedLayerNormalization|*in* input_ids:**T1**
*in* segment_ids:**T1**
*in* word_embedding:**T**
*in* position_embedding:**T**
*in* segment_embedding:**T**
*in* gamma:**T**
*in* beta:**T**
*in* mask:**T1**
*in* position_ids:**T1**
*out* output:**T**
*out* mask_index:**T1**
*out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index c839bb4245..de2513789c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -163,6 +163,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { if (has_init_decoder_) { ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_, + "past_present_share_buffer mode must be same for init decoder and decoder subgraphes"); } concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); @@ -181,12 +183,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { thread_pool, ctx->GetComputeStream(), dumper_, parameters, GenerationCpuDeviceHelper::CreateGptInputs, add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits, init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState, device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy, - update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; + update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds, + cuda_device_prop_, + cuda_device_arch_}; ORT_RETURN_IF_ERROR(impl.Initialize()); return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); @@ -200,12 +205,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { thread_pool, ctx->GetComputeStream(), dumper_, parameters, GenerationCpuDeviceHelper::CreateGptInputs, add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, process_logits_fp16_func_, init_beam_state_fp16_func_, device_copy_func_, device_copy_int32_func_, - update_gpt_feeds_fp16_func_}; + update_gpt_feeds_fp16_func_, + cuda_device_prop_, + cuda_device_arch_}; ORT_RETURN_IF_ERROR(impl.Initialize()); return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index fe11a8a2ba..37b94bd8f1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -44,6 +44,7 @@ class BeamSearch : public IControlFlowKernel { // device helpers that is same for both GPT and encoder-decoder models. 
void SetDeviceHelpers( + const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, const GenerationDeviceHelper::TopkFunc& topk_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, @@ -52,6 +53,7 @@ class BeamSearch : public IControlFlowKernel { const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_fp16_func, const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_func, const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_fp16_func) { + reorder_past_state_func_ = reorder_past_state_func; add_to_feeds_func_ = add_to_feeds_func; topk_func_ = topk_func; device_copy_func_ = device_copy_func; @@ -83,8 +85,12 @@ class BeamSearch : public IControlFlowKernel { expand_buffer_float16_func_ = expand_buffer_float16_func; } + const void* cuda_device_prop_ = nullptr; + int cuda_device_arch_ = 0; + private: // Device specific functions + GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::TopkFunc topk_func_; GenerationDeviceHelper::DeviceCopyFunc device_copy_func_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index 490a68d240..75d161a2cd 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -20,6 +20,9 @@ struct BeamSearchState : public IBeamSearchState { int vocab_size, int sequence_length, int max_length, + int num_heads, + int head_size, + int has_decoder_masked_multihead_attention, bool output_scores, bool use_position) { size_t batch_beam_size = SafeInt(batch_size) * num_beams; @@ -49,6 +52,21 @@ struct BeamSearchState : public IBeamSearchState { this->scores = AllocateBuffer(allocator, scores_buffer_, elements); this->remaining_scores = this->scores; } + + if (has_decoder_masked_multihead_attention) { + // We need a temp staging buffer to do the past 'K' state re-ordering that is needed + // when using DecoderMaskedMultiheadAttention + TensorShape staging_for_past_state_reorder_buffer_shape = {static_cast(batch_beam_size), num_heads, max_length, head_size}; + + Tensor temp(DataTypeImpl::GetType(), staging_for_past_state_reorder_buffer_shape, allocator); + + this->staging_for_past_state_reorder = std::move(temp); + + // We need a buffer on GPU to hold the final chosen indices after BeamScorer has finished processing + // TODO: This is a temporary work-around as BeamScorer currently only runs on CPU. + // We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually. + this->chosen_indices = AllocateBuffer(allocator, chosen_indices_buffer_, batch_beam_size); + } } private: @@ -61,6 +79,7 @@ struct BeamSearchState : public IBeamSearchState { BufferUniquePtr beam_scores_buffer_; BufferUniquePtr scores_buffer_; BufferUniquePtr topk_temp_buffer_; + BufferUniquePtr chosen_indices_buffer_; }; struct BeamSearchCpuState : public IBeamSearchCpuState { @@ -124,7 +143,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState { // Base class of beam search implementation that is common for both GPT-2 and T5. 
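To get a feel for the staging buffer that `BeamSearchState` allocates above for the past 'K' re-ordering, here is a worked size example with illustrative values (none of these numbers come from the patch): batch_size = 1, num_beams = 4, num_heads = 16, max_length = 1024, head_size = 64, fp16 elements. Because the staging copy and the transpose run on the same stream, this single buffer is reused for every layer.

```cpp
// Staging buffer shape: (batch_size * num_beams, num_heads, max_length, head_size).
constexpr size_t batch_beam_size = 1 * 4;                      // batch_size * num_beams
constexpr size_t elements = batch_beam_size * 16 * 1024 * 64;  // 4,194,304 elements
constexpr size_t bytes_fp16 = elements * 2;                    // 8 MiB per staging buffer
static_assert(bytes_fp16 == 8u * 1024 * 1024,
              "one buffer reused across all layers, since staging copy and "
              "transpose share a stream");
```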
template -class BeamSearchBase : public GenerateBase { +class BeamSearchBase : public GenerateBase { public: BeamSearchBase(OpKernelContextInternal& context, const SessionState& decoder_session_state, @@ -136,13 +155,13 @@ class BeamSearchBase : public GenerateBase { const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_int32_func) - : GenerateBase(context, - decoder_session_state, - thread_pool, - ort_stream, - cuda_dumper, - topk_func, - device_copy_func), + : GenerateBase(context, + decoder_session_state, + thread_pool, + ort_stream, + cuda_dumper, + topk_func, + device_copy_func), parameters_(¶ms), process_logits_func_(process_logits_func), device_copy_int32_func_(device_copy_int32_func) { @@ -188,11 +207,11 @@ Status BeamSearchBase::CheckInputs(const OpKernelContextInternal& context) { // input_ids : (batch_size, sequence_length) // vocab_mask : (vocab_size) or nullptr ORT_RETURN_IF_ERROR(this->CheckInputsImpl(parameters_, - context.Input(0), // input_ids - context.Input(7), // vocab_mask - context.Input(8), // prefix_vocab_mask - context.Input(9), // attention_mask - nullptr)); // presence_mask + context.Input(0), // input_ids + context.Input(7), // vocab_mask + context.Input(8), // prefix_vocab_mask + context.Input(9), // attention_mask + nullptr)); // presence_mask return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 46697a202d..dc6d33da99 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -27,12 +27,15 @@ class BeamSearchGpt : public BeamSearchBase { BeamSearchParameters& params, const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func, const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, const GenerationDeviceHelper::TopkFunc& topk_func, const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_func, const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_int32_func, - const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func) + const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func, + const void* cuda_device_prop, + int cuda_device_arch) : BeamSearchBase(context, decoder_session_state, thread_pool, ort_stream, cuda_dumper, params, topk_func, process_logits_func, device_copy_func, device_copy_int32_func), @@ -42,7 +45,17 @@ class BeamSearchGpt : public BeamSearchBase { create_inputs_func_(create_inputs_func), add_to_feeds_func_(add_to_feeds_func), init_beam_state_func_(init_beam_state_func), - update_feeds_func_(update_feeds_func) { + reorder_past_state_func_(reorder_past_state_func), + update_feeds_func_(update_feeds_func), + cuda_device_prop_(cuda_device_prop), + cuda_device_arch_(cuda_device_arch) { + if (gpt_subgraph_.has_decoder_masked_multihead_attention_) { + ORT_ENFORCE(cuda_device_arch_ >= 530, + "Decoder masked multihead attention can only be used on " + "GPU cards of compute capability 5.3 or higher. " + "This card has compute capability ", + cuda_device_arch_); + } } // Execute beam search in iterations util stopping criteria is reached. 
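The `cuda_device_arch_ >= 530` guard in the constructor above matches how the CUDA `BeamSearch` kernel derives the value later in this patch (`major * 100 + minor * 10`). A minimal standalone sketch of that derivation:

```cpp
#include <cuda_runtime.h>

// Derive the integer arch value compared against 530 (compute capability 5.3):
// SM 5.3 -> 530, SM 7.5 -> 750, SM 8.6 -> 860.
int GetCudaDeviceArch(int device = 0) {
  cudaDeviceProp prop;
  cudaGetDeviceProperties(&prop, device);
  return prop.major * 100 + prop.minor * 10;
}
```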
@@ -55,7 +68,8 @@ class BeamSearchGpt : public BeamSearchBase { Status CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, - IAllocatorUniquePtr& buffer); + IAllocatorUniquePtr& buffer, + bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention); // Update the input for next iteration. Status UpdateFeeds( @@ -65,7 +79,11 @@ class BeamSearchGpt : public BeamSearchBase { OrtValue& position_ids, bool increase_position, gsl::span beam_next_tokens, - gsl::span beam_indices); + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, + int past_sequence_length, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention); const SessionState* init_run_decoder_session_state_ = nullptr; GptSubgraph* init_run_gpt_subgraph_ = nullptr; @@ -75,14 +93,19 @@ class BeamSearchGpt : public BeamSearchBase { GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; + GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; + + const void* cuda_device_prop_ = nullptr; + int cuda_device_arch_ = 0; }; template Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, - IAllocatorUniquePtr& buffer) { + IAllocatorUniquePtr& buffer, + bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) { const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0); const Tensor& input_ids = input_ids_value->Get(); const OrtValue* attn_mask_value = this->context_.GetInputOrtValue(9); @@ -99,7 +122,9 @@ Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths this->create_inputs_func_, this->add_to_feeds_func_, buffer, - this->ort_stream_); + this->ort_stream_, + this->parameters_->max_length, + add_beam_search_specific_inputs_for_decoder_masked_multihead_attention); } return gpt_subgraph_.CreateInitialFeeds(input_ids, @@ -113,7 +138,9 @@ Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths this->create_inputs_func_, this->add_to_feeds_func_, buffer, - this->ort_stream_); + this->ort_stream_, + this->parameters_->max_length, + add_beam_search_specific_inputs_for_decoder_masked_multihead_attention); } template @@ -124,7 +151,11 @@ Status BeamSearchGpt::UpdateFeeds( OrtValue& position_ids, bool increase_position, gsl::span beam_next_tokens, - gsl::span beam_indices) { + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, + int past_sequence_length, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) { return update_feeds_func_(this->temp_space_allocator_, this->ort_stream_, last_outputs, @@ -133,12 +164,15 @@ Status BeamSearchGpt::UpdateFeeds( position_ids, increase_position, beam_next_tokens, - beam_indices, + beam_indices_cpu, + beam_indices_gpu, this->parameters_->num_beams, gpt_subgraph_.GetFirstPastInputIndex(), gpt_subgraph_.GetFirstPresentOutputIndex(), - false, - -1); + gpt_subgraph_.past_present_share_buffer_, + past_sequence_length, + input_sequence_len, + has_beam_search_specific_inputs_for_decoder_masked_multihead_attention); } template @@ -192,7 +226,22 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch // buffer in GPU for input_ids, position_ids and attention_mask IAllocatorUniquePtr buffer; OrtValue 
expanded_input_ids_in_cpu; - ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer)); + ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer, + gpt_subgraph_.has_decoder_masked_multihead_attention_)); + + if (gpt_subgraph_.past_present_share_buffer_) { // Reuse past and present + fetches.reserve(static_cast(gpt_subgraph_.GetFirstPresentOutputIndex()) + gpt_subgraph_.num_layers); + fetches.resize(gpt_subgraph_.GetFirstPresentOutputIndex(), OrtValue()); + for (int layer = 0; layer < gpt_subgraph_.num_layers; layer++) { + int feed_idx = gpt_subgraph_.GetFirstPastInputIndex() + layer; + OrtValue& past_tensor_value = feeds[feed_idx]; + Tensor* past_tensor = past_tensor_value.GetMutable(); + OrtValue present_tensor_value; + Tensor::InitOrtValue(past_tensor->DataType(), past_tensor->Shape(), past_tensor->MutableData(), + past_tensor->Location(), present_tensor_value); + fetches.push_back(present_tensor_value); + } + } BeamSearchState beam_state; constexpr bool use_position = true; @@ -202,6 +251,9 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch parameters->vocab_size, parameters->sequence_length, parameters->max_length, + parameters->num_heads, + parameters->head_size, + gpt_subgraph_.has_decoder_masked_multihead_attention_, parameters->output_scores, use_position); @@ -297,17 +349,49 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch // Increase sequence length after a new token is generated. ++current_length; + // Reorder past state after first run if the GPT subgraph (the one used after the first iteration) + // contains DecoderMaskedMultiheadAttention nodes + if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) { + size_t offset = static_cast(gpt_subgraph_.GetFirstPresentOutputIndex()); + // We will use the same staging buffer while transposing all the layers' past state + // and this is okay because we use the same stream to do the staging copy and the transpose + // operations. + // If we ever do them in different streams, we must use different staging buffers to avoid data + // races. + for (size_t i = 0; i < static_cast(gpt_subgraph_.num_layers); ++i) { + ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_, + *fetches[offset + i].GetMutable(), + beam_state.staging_for_past_state_reorder, + this->ort_stream_)); + } + } + // Prepare inputs for next round of subgraph call. if (current_length < parameters->max_length) { + gsl::span place_holder; // For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly. // For the remaining iterations, we need increase position_ids first, then add it to feeds. bool increase_position = (iteration_counter > 1); ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, position_ids, increase_position, ReinterpretAsSpan(beam_next_tokens), - ReinterpretAsSpan(beam_indices))); + ReinterpretAsSpan(beam_indices), + gpt_subgraph_.has_decoder_masked_multihead_attention_ + ? 
ReinterpretAsSpan(beam_state.chosen_indices) + : place_holder, + current_length - 1, + parameters->sequence_length, + gpt_subgraph_.has_decoder_masked_multihead_attention_)); + } + + if (gpt_subgraph_.past_present_share_buffer_) { + // clear fetched values before presents[] + for (int idx = 0; idx < gpt_subgraph_.GetFirstPresentOutputIndex(); idx++) { + fetches[idx] = OrtValue(); + } + } else { + fetches.clear(); } - fetches.clear(); } gsl::span final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index 9841f425d2..562f10c615 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -193,6 +193,9 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches parameters->vocab_size, parameters->sequence_length, parameters->max_length, + parameters->num_heads, + parameters->head_size, + false, // TODO: Support past/present state buffer re-use for T5 parameters->output_scores, use_position); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index e2cd625093..fdd404efa0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -94,7 +94,7 @@ class BeamSearchScorer : public IBeamScorer { gsl::span& GetNextScores() { return next_beam_scores_; } gsl::span& GetNextTokens() { return next_beam_tokens_; } - gsl::span& GetNextIndices() { return next_beam_indices_; } + gsl::span& GetNextIndices() override { return next_beam_indices_; } private: size_t batch_size_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 17c7f1e6c6..a9cde364a3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -407,18 +407,18 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits( - const OrtValue& logits, // logits output of subgraph - transformers::IGreedySearchState* greedy_state, // state - transformers::ISamplingState* sampling_state, // sampling_state - transformers::ISequences* sequences, // sequences - AllocatorPtr& allocator, // default allocator - onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) - transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IGenerationParameters* parameters, // parameters - bool do_sampling, // whether to do sampling - int step, // iteration counter - Stream* stream, // cuda stream (for CUDA only) - const transformers::IConsoleDumper* dumper) { // tensor dumper + const OrtValue& logits, // logits output of subgraph + transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingState* sampling_state, // sampling_state + transformers::ISequences* sequences, // sequences + AllocatorPtr& allocator, // default allocator + onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) + transformers::ILogitsProcessorList* logits_processors, // logits processors + const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling + int step, 
// iteration counter + Stream* stream, // cuda stream (for CUDA only) + const transformers::IConsoleDumper* dumper) { // tensor dumper int batch_size = parameters->batch_size; int vocab_size = parameters->vocab_size; @@ -499,8 +499,8 @@ Status GreedySearchProcessLogits( topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", topk_scores); - dumper->Print("topk_indices", topk_indices); + dumper->Print("topk_scores", topk_scores); + dumper->Print("topk_indices", topk_indices); #endif gsl::span next_token_indices = topk_indices.DataAsSpan(); @@ -574,15 +574,21 @@ Status UpdateGptFeeds( OrtValue& position_ids, bool increase_position, gsl::span beam_next_tokens, - gsl::span beam_indices, + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, int num_beams, int gpt_subgraph_first_past_input_idx, int gpt_subgraph_first_present_output_idx, bool past_present_share_buffer, - int past_sequence_len) { + int past_sequence_len, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) { // last_outputs: logits, present_0, present_1, ... // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 ORT_UNUSED_PARAMETER(stream); + ORT_UNUSED_PARAMETER(beam_indices_gpu); + ORT_UNUSED_PARAMETER(input_sequence_len); + ORT_UNUSED_PARAMETER(has_beam_search_specific_inputs_for_decoder_masked_multihead_attention); // The following updates inputs for subgraph @@ -631,14 +637,14 @@ Status UpdateGptFeeds( return Status::OK(); } - if (num_beams == 1) { // Update past state + if (num_beams == 1) { // Update past state // feed present_* output to past_* inputs one by one const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx; for (size_t i = gpt_subgraph_first_present_output_idx; i < last_outputs.size(); ++i) { next_inputs[i + k] = last_outputs[i]; } } else { - PickGptPastState(last_outputs, next_inputs, beam_indices, + PickGptPastState(last_outputs, next_inputs, beam_indices_cpu, gpt_subgraph_first_past_input_idx, gpt_subgraph_first_present_output_idx, allocator); } @@ -890,12 +896,15 @@ template Status UpdateGptFeeds( OrtValue& position_ids, bool increase_position, gsl::span beam_next_tokens, - gsl::span beam_indices, + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, int num_beams, int gpt_subgraph_first_past_input_idx, int gpt_subgraph_first_present_output_idx, bool past_present_share_buffer, - int past_sequence_len); + int past_sequence_len, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention); template Status UpdateDecoderFeeds( AllocatorPtr allocator, diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 473f194205..6a9b2e93ec 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -128,12 +128,15 @@ using UpdateGptFeedsFunc = std::function beam_next_tokens, - gsl::span beam_indices, + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, int num_beams, int gpt_subgraph_first_past_input_idx, int gpt_subgraph_first_present_output_idx, bool past_present_share_buffer, - int past_sequence_len)>; + int past_sequence_len, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention)>; // Create encoder inputs (for encoder-decoder model like T5). 
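Note the split of the old `beam_indices` argument into `beam_indices_cpu` and `beam_indices_gpu` in the `UpdateGptFeedsFunc` signature above: the two spans have different consumers. A condensed, hedged sketch of the wiring as it appears in the CUDA `UpdateGptFeeds` later in this patch (control flow simplified, local names abbreviated, not verbatim):

```cpp
if (past_present_share_buffer && has_beam_search_specific_inputs) {
  // Device-side consumer: the cache-indirection kernel reads the beam indices that
  // ProcessLogits copied host -> device after the CPU BeamScorer ran.
  cuda::UpdateDecoderMaskedMultiheadAttentionCacheIndirection(
      new_cache_indir_data, old_cache_indir_data,
      reinterpret_cast<const int32_t*>(beam_indices_gpu.data()),
      batch_size, num_beams, input_sequence_len, max_sequence_length,
      current_length, cuda_stream);
} else if (num_beams > 1) {
  // Host-side consumer: gather per-beam past tensors when buffers are not shared.
  ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices_cpu,
                                          allocator, first_past_input_idx,
                                          first_present_output_idx, ort_stream));
}
```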
using CreateEncoderInputsFunc = std::function beam_next_tokens, - gsl::span beam_indices, + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, int num_beams, int gpt_subgraph_first_past_input_idx, int gpt_subgraph_first_present_output_idx, bool past_present_share_buffer, - int past_sequence_len); + int past_sequence_len, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention); // --------------------------------------------------------------- // Functions for encoder-decoder model like T5 diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index 9aed3eb3fd..3faba9a856 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -38,6 +38,11 @@ struct IBeamSearchState { // temp token: (batch_size * num_beams, 2 * num_beams) // in total, it will be: // 2 * (batch_size * num_beams * (parts_vocab + 1), 2 * num_beams) + + // The final chosen indices after BeamScorer has finished processing + gsl::span chosen_indices; // shape (batch_size, num_beams) + + Tensor staging_for_past_state_reorder; // Tensor of shape (batch_size * num_beams, num_heads, max_length, head_size) }; struct IBeamSearchCpuState { @@ -117,6 +122,8 @@ class IBeamScorer { gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) = 0; + + virtual gsl::span& GetNextIndices() = 0; }; struct IGenerationParameters { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index ecd40bfdfa..ccf387a895 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -81,7 +81,7 @@ struct GreedySearchState : public IGreedySearchState { int max_length, int num_heads, int head_size, - bool allocate_staging_buffer_for_past_state_reorder, + bool has_decoder_masked_multihead_attention, bool is_cuda) { // below buffers are on cpu this->sequences_space = AllocateBuffer(cpu_allocator, @@ -111,8 +111,9 @@ struct GreedySearchState : public IGreedySearchState { this->topk_scores_buffer, this->topk_tokens_buffer); - // If at all we need to, we only need to re-order past state for CUDA - if (allocate_staging_buffer_for_past_state_reorder) { + // If at all we need to, we only need to re-order past state for CUDA as + //`DecoderMaskedMultiheadAttention` is only supported on CUDA + if (has_decoder_masked_multihead_attention) { TensorShape staging_for_past_state_reorder_buffer_shape = {batch_size, num_heads, max_length, head_size}; Tensor temp(DataTypeImpl::GetType(), staging_for_past_state_reorder_buffer_shape, allocator); diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index fc9cc01c86..d1afcbc6c6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -170,11 +170,14 @@ Status GreedySearchGpt::UpdateFeeds( increase_position, next_tokens, place_holder, + place_holder, this->parameters_->num_beams, gpt_subgraph_.GetFirstPastInputIndex(), gpt_subgraph_.GetFirstPresentOutputIndex(), gpt_subgraph_.past_present_share_buffer_, - past_sequence_length); + past_sequence_length, + -1, // Input sequence length needn't be passed in for 
GreedySearch + false); } template @@ -329,7 +332,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_ // Reorder past state after first run if the GPT subgraph (the one used after the first iteration) // contains DecoderMaskedMultiheadAttention nodes if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) { - size_t offset = static_cast(gpt_subgraph_.GetFirstPresentOutputIndex()); + size_t offset = static_cast(gpt_subgraph_.GetFirstPresentOutputIndex()); // We will use the same staging buffer while transposing all the layers' past state // and this is okay because we use the same stream to do the staging copy and the transpose // operations. diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc index 6703c6eec9..012102b45c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -91,6 +91,9 @@ Status Subgraph::Setup(const SessionState& session_state, past_present_share_buffer_ = true; // past_sequence_length is on CPU memory feed_locations.push_back(OrtDevice()); + } else if (feed_names[i] == "beam_width") { + // beam_width is on CPU memory + feed_locations.push_back(OrtDevice()); } else { feed_locations.push_back(default_location.device); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc index 3d8af10737..c12301803a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -27,7 +27,8 @@ Status GptSubgraph::CreateInitialFeeds( const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, IAllocatorUniquePtr& buffer, Stream* ort_stream, - int max_seq_len_past_present_share_buffer) { + int past_present_share_buffer_max_seq_len, + bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) { ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); const IExecutionProvider* provider = GetProvider(); @@ -83,20 +84,48 @@ Status GptSubgraph::CreateInitialFeeds( feeds.push_back(empty_past); } } else { - int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, max_seq_len_past_present_share_buffer, head_size}; + // Past state feeds + int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, past_present_share_buffer_max_seq_len, head_size}; TensorShape past_shape(&past_state_dims[0], 5); - // The remaining inputs are past state except the last one - for (int i = first_past_input_index_; i < num_subgraph_inputs - 1; ++i) { + + // The remaining inputs are past state except the last one or three (see below for details) + // If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is false, then the last input is `past_sequence_length` + + // If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is true, then the last inputs are `past_sequence_length`, + // `beam_width`, and `cache_indirection` + auto past_end_iter = add_beam_search_specific_inputs_for_decoder_masked_multihead_attention ? 
num_subgraph_inputs - 3 : num_subgraph_inputs - 1; + for (int i = first_past_input_index_; i < past_end_iter; ++i) { OrtValue past_tensor; Tensor::InitOrtValue(past_type, past_shape, default_allocator, past_tensor); feeds.push_back(past_tensor); } + + // Past sequence length feed int64_t past_seq_len_dims[] = {1}; TensorShape past_seq_len_shape(&past_seq_len_dims[0], 1); OrtValue past_seq_len_tensor_value; Tensor::InitOrtValue(DataTypeImpl::GetType(), past_seq_len_shape, cpu_allocator, past_seq_len_tensor_value); feeds.push_back(past_seq_len_tensor_value); *past_seq_len_tensor_value.GetMutable()->MutableData() = 0; + + // Add beam search specific inputs + if (add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) { + // Beam width feed + int64_t num_beams_dims[] = {1}; + TensorShape num_beams_shape(&num_beams_dims[0], 1); + OrtValue num_beams_tensor_value; + Tensor::InitOrtValue(DataTypeImpl::GetType(), num_beams_shape, cpu_allocator, num_beams_tensor_value); + feeds.push_back(num_beams_tensor_value); + *num_beams_tensor_value.GetMutable()->MutableData() = static_cast(num_beams); + + // Cache indirection feed + int64_t cache_indirection_dims[] = {batch_size, num_beams, past_present_share_buffer_max_seq_len}; + TensorShape cache_indirection_shape(&cache_indirection_dims[0], 3); + OrtValue default_cache_indirection; + Tensor::InitOrtValue(DataTypeImpl::GetType(), cache_indirection_shape, + default_allocator, default_cache_indirection); + feeds.push_back(default_cache_indirection); + } } // Pass in implicit inputs @@ -112,8 +141,12 @@ Status GptSubgraph::Validate(const std::vector& subgraph_inputs, ORT_RETURN_IF(num_subgraph_outputs <= first_present_output_index_, "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs)."); - ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) || (num_subgraph_inputs == num_subgraph_outputs + 3)), - "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or 3 (if past_present_share_buffer)"); + ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) || + (num_subgraph_inputs == num_subgraph_outputs + 3) || + (num_subgraph_inputs == num_subgraph_outputs + 5)), + "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or " + "3 (if past_present_share_buffer) or " + "5 (if past_present_share_buffer and use_decoder_masked_multihead_attention for BeamSearch)"); ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h index 855d017e64..c24ac3a5d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h @@ -16,9 +16,9 @@ class GptSubgraph : public Subgraph { const onnxruntime::Node& node_in, const std::string& attribute_name, const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) { - first_past_input_index_ = 3; - first_present_output_index_ = 1; - } + first_past_input_index_ = 3; + first_present_output_index_ = 1; + } // Create inputs for first inference of subgraph. 
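Taken together with the `CreateInitialFeeds` changes above, the subgraph feed layout in the shared-buffer + beam-search-specific configuration works out as follows (a sketch with num_layers = 2; shapes, dtypes, and placements are as set up in this patch):

```cpp
// feeds[0] input_ids             int32  (batch_size * num_beams, sequence_length)
// feeds[1] position_ids          int32
// feeds[2] attention_mask
// feeds[3] past_0                T      (2, batch_size * num_beams, num_heads,
// feeds[4] past_1                        past_present_share_buffer_max_seq_len, head_size)
// feeds[5] past_sequence_length  int32  shape {1}, CPU memory, initialized to 0
// feeds[6] beam_width            int32  shape {1}, CPU memory, set to num_beams
// feeds[7] cache_indirection     int32  (batch_size, num_beams,
//                                        past_present_share_buffer_max_seq_len), device memory
//
// Consistent with Validate(): 8 inputs == 3 outputs (logits, present_0, present_1) + 5.
```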
Status CreateInitialFeeds( @@ -34,7 +34,8 @@ class GptSubgraph : public Subgraph { const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, IAllocatorUniquePtr& buffer, Stream* ort_stream, - int max_seq_len_past_present_share_buffer = -1); + int past_present_share_buffer_max_seq_len = -1, + bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention = false); Status Validate(const std::vector& subgraph_inputs, const std::vector& subgraph_outputs) override; diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc index 8e28bb253e..49000855f2 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc +++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc @@ -16,20 +16,23 @@ namespace contrib { namespace cuda { static constexpr int kPastSequenceLengthInputIndex = 6; +static constexpr int kBeamWidthInputIndex = 7; +static constexpr int kCacheIndirectionInputIndex = 8; static constexpr int kPastInputIndex = 4; static constexpr int kPresentOutputIndex = 1; -#define REGISTER_KERNEL_TYPED(T1, T2) \ - ONNX_OPERATOR_TYPED_KERNEL_EX( \ - DecoderMaskedMultiheadAttention, \ - kMSDomain, \ - 1, \ - T1, \ - kCudaExecutionProvider, \ - (*KernelDefBuilder::Create()) \ - .MayInplace(kPastInputIndex, kPresentOutputIndex) \ - .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ - .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \ +#define REGISTER_KERNEL_TYPED(T1, T2) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + DecoderMaskedMultiheadAttention, \ + kMSDomain, \ + 1, \ + T1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .MayInplace(kPastInputIndex, kPresentOutputIndex) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()) \ + .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \ + .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \ DecoderMaskedMultiheadAttention); REGISTER_KERNEL_TYPED(float, float) @@ -44,6 +47,8 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext* const Tensor* past = context->Input(kPastInputIndex); const Tensor* relative_position_bias = context->Input(5); const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex); + const Tensor* beam_width = context->Input(kBeamWidthInputIndex); + const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex); auto& device_prop = GetDeviceProp(); DecoderMaskedMultiheadAttentionParams parameters; @@ -105,7 +110,7 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext* auto* present_data = present->MutableData(); auto* past_data = past->Data(); - // No production use-case will incure this copy cost as the implementation of + // No production use-case will incur this copy cost as the implementation of // GreedySearch/BeamSearch is written in such a way that the past and present buffers // will be shared. 
// This is just to circumvent the OpTester's limitation of not being able to bind a specific @@ -142,13 +147,13 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext* // Update the q, k, and v buffers parameters.q = gemm_buffer.get(); parameters.k = reinterpret_cast(gemm_buffer.get()) + parameters.hidden_size; - parameters.v = reinterpret_cast(gemm_buffer.get()) + 2 * parameters.hidden_size; + parameters.v = reinterpret_cast(gemm_buffer.get()) + 2 * static_cast(parameters.hidden_size); // Update the q, k, and v bias const T1* bias_data = bias->Data(); parameters.q_bias = const_cast(bias_data); parameters.k_bias = const_cast(bias_data + parameters.hidden_size); - parameters.v_bias = const_cast(bias_data + 2 * parameters.hidden_size); + parameters.v_bias = const_cast(bias_data + 2 * static_cast(parameters.hidden_size)); // Half of the past/present buffer correspond to K - the other half is V. auto k_size = present->Shape().Size() / 2; @@ -167,6 +172,22 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext* parameters.mask = mask_index->Data(); } + // Beam width (in case we are using this op inside BeamSearch) + if (beam_width != nullptr) { + parameters.beam_width = static_cast(*beam_width->Data()); + } + + // Cache indirection (in case we are using this op inside BeamSearch) + if (parameters.beam_width > 1) { + // If beam width > 1, then cache indirection buffer MUST be present + if (cache_indir == nullptr) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "If beam width is greater than 1, then cache indirection buffer MUST be present"); + } + + parameters.cache_indir = cache_indir->Data(); + } + switch (parameters.head_size) { case 64: mmha_launch_kernel(parameters, cuda_stream); @@ -182,7 +203,6 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext* "Got head size: ", parameters.head_size); } - return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu index 39dec84663..27a2b6de83 100644 --- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu +++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu @@ -22,6 +22,8 @@ // Modifications: // (1) Removed some code paths from the original implementation that had features which is not supported by // corresponding ORT kernel - for example- CrossAttention support, FP8, INT8, supports, etc. +// (2) When dealing with masked tokens, this kernel implementation deviates from FasterTransformer by applying +// mask filter values. Appropriate commentary exists in the code below. #include "decoder_masked_multihead_attention_impl.h" #include "decoder_masked_multihead_attention_impl_utils.h" @@ -286,7 +288,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. bool has_beams = params.cache_indir != nullptr; - const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_max_seq_length] : nullptr; for (int ti = ko; ti < ti_end; ti += K_PER_ITER) { @@ -294,16 +295,24 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio // The keys loaded from the key cache. 
K_vec_k k_vec[K_VECS_PER_THREAD]; -#pragma unroll - for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { - int jj = ii * params.max_sequence_length + ti; - if (ti < tlength) { - if (has_beams) { + if (has_beams) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; + + if (ti < tlength) { const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size; k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]))); - } else { + } + } + } else { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.max_sequence_length + ti; + + if (ti < tlength) { k_vec[ii] = vec_conversion( (*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]))); } @@ -314,9 +323,16 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! float qk = Qk_dot::dot(q_vec, k_vec) * inv_sqrt_dh; + // This is a deviation from FasterTransformer kernel implementation + // but this aligns with ORT's other Attention kernels which strives to + // mimic PyTorch when dealing with mask filter values + if (is_masked) { + qk += params.mask_filter_value; + } + // Store the product to shared memory. There's one qk value per timestep. Update the max. if (ti < tlength && tidx % THREADS_PER_KEY == 0) { - qk_max = is_masked ? qk_max : fmaxf(qk_max, qk); + qk_max = fmaxf(qk_max, qk); qk_smem[ti] = qk; } } @@ -355,8 +371,10 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio // Compute the logits and start the sum. float sum = 0.f; for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { - bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0); - float logit = is_masked ? 
0.f : __expf(qk_smem[ti] - qk_max); + // This is a deviation from FasterTransformer kernel implementation + // but this aligns with ORT's other Attention kernels which strives to + // mimic PyTorch when dealing with mask filter values + float logit = __expf(qk_smem[ti] - qk_max); sum += logit; qk_smem[ti] = logit; } diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 49aff75eb8..e2ced4786d 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -35,7 +35,8 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper; BeamSearch::BeamSearch(const OpKernelInfo& info) : onnxruntime::contrib::transformers::BeamSearch(info) { - SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds, + SetDeviceHelpers(GenerationCudaDeviceHelper::ReorderPastState, + GenerationCudaDeviceHelper::AddToFeeds, GenerationCudaDeviceHelper::TopK, GenerationCudaDeviceHelper::DeviceCopy, GenerationCudaDeviceHelper::DeviceCopy, @@ -54,6 +55,11 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) GenerationCudaDeviceHelper::ExpandBuffer); SetConsoleDumper(&g_cuda_dumper); + + cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp(); + + cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 + + static_cast(cuda_device_prop_)->minor * 10; } Status BeamSearch::ComputeInternal(OpKernelContext* context) const { diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 90c9122820..b354b1c9a6 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -424,7 +424,7 @@ void LaunchSortPairs(void* d_temp_storage, d_offsets + 1, 0, sizeof(T) * 8, - stream)); + stream)); } else { CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, @@ -469,7 +469,7 @@ template void LaunchSortPairs(void* d_temp_storage, // A stateful callback functor that maintains a running prefix to be applied // during consecutive scan operations. struct BlockPrefixCallbackOp { - float running_total; // running prefix + float running_total; // running prefix __device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {} // Callback operator to be entered by the first warp of threads in the block. @@ -746,6 +746,65 @@ void TorchMultinomialKernelLauncher(float* d_input, d_presence_mask); } +__global__ void UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel(int32_t* tgt_indir_cache, + const int32_t* src_indir_cache, + const int32_t* beam_ids, + int batch_size, + int beam_width, + int input_seq_length, + int max_seq_length, + int current_length) { + int time_step = threadIdx.x + blockIdx.x * blockDim.x; + int bb_id = threadIdx.y + blockIdx.y * blockDim.y; + const int batch_id = bb_id / beam_width; + const int beam_id = bb_id % beam_width; + + if (bb_id >= beam_width * batch_size || time_step >= current_length) { + return; + } + + const int src_beam = beam_ids[batch_id * beam_width + beam_id] % beam_width; + + const int tgt_offset = batch_id * beam_width * max_seq_length + beam_id * max_seq_length + time_step; + + if (time_step < input_seq_length) { + // For time steps that correspond to the input sequence, + // the beam that it comes from is always 0. 
+ tgt_indir_cache[tgt_offset] = static_cast(0); + } else if (time_step == (current_length - 1)) { + // For the final (newly generated) time step, + // the beam that it comes from is always the beam that we + // are currently processing (i.e.) from this point on, these time-steps + // form the new beams. + tgt_indir_cache[tgt_offset] = static_cast(beam_id); + } else { + // For all other time-steps, we look up the source indirection, to + // see which beam it came from based on the `src_beam`. + const int src_offset = batch_id * beam_width * max_seq_length + src_beam * max_seq_length + time_step; + tgt_indir_cache[tgt_offset] = src_indir_cache[src_offset]; + } +} + +void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache, + const int32_t* src_indir_cache, + const int32_t* beam_ids, + int batch_size, + int beam_width, + int input_seq_length, + int max_seq_length, + int current_length, + cudaStream_t stream) { + const dim3 block(32); + const dim3 grid((current_length + block.x - 1) / block.x, batch_size * beam_width); + UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel<<>>(tgt_indir_cache, + src_indir_cache, + beam_ids, + batch_size, + beam_width, + input_seq_length, + max_seq_length, + current_length); +} } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h index 5209622823..85b97df7a0 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -61,7 +61,7 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data, cudaStream_t stream); template -void GetTempStorageSize(const T *d_keys_in, +void GetTempStorageSize(const T* d_keys_in, const int* d_values_in, int* d_offsets, int num_items, @@ -77,15 +77,15 @@ void LaunchSetupParamsKernel(int* d_values_in, cudaStream_t stream); template -void LaunchSortPairs(void *d_temp_storage, +void LaunchSortPairs(void* d_temp_storage, size_t temp_storage_bytes, - const T *d_keys_in, - T *d_keys_out, - const int *d_values_in, - int *d_values_out, + const T* d_keys_in, + T* d_keys_out, + const int* d_values_in, + int* d_values_out, int num_items, int num_segments, - int *d_offsets, + int* d_offsets, cudaStream_t stream, bool is_descending); @@ -109,6 +109,16 @@ void TorchMultinomialKernelLauncher(float* d_input, int* d_presence_mask, cudaStream_t stream); +void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache, + const int32_t* src_indir_cache, + const int32_t* beam_ids, + int batch_size, + int beam_width, + int input_seq_length, + int max_seq_length, + int current_length, + cudaStream_t stream); + } // namespace cuda } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index e4d75971c9..eed24b4a1f 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -481,6 +481,8 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams); dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams); #endif + + // TODO: Remove these kinds of cross-device copies once 
BeamScorer runs on CUDA CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), beam_state->next_scores.data(), beam_state->next_scores.size_bytes(), @@ -528,6 +530,7 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k); #endif + // TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), data, topk_scores->SizeInBytes(), @@ -535,6 +538,7 @@ Status ProcessLogits(const OrtValue& logits, // cuda_stream)); } + // TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(), @@ -551,13 +555,31 @@ Status ProcessLogits(const OrtValue& logits, // gsl::span next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size()); gsl::span next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size()); - // Limitation: beam scorer runs in CPU. It might be better to use CUDA kernel to replace it. + // TODO: Implement BeamScorer on CUDA beam_scorer->Process( sequences, next_scores, next_tokens, next_indices); + // TODO: This is a temporary work-around as BeamScorer currently only runs on CPU. + // We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually. + auto chosen_indices = beam_scorer->GetNextIndices(); + auto beam_state_chosen_indices = beam_state->chosen_indices; + + if (!beam_state_chosen_indices.empty()) { + // If we have allocated `chosen_indices` in beam_state, it means that we + // will be needing the chosen indices from BeamScorer as we are using + // DecoderMaskedMultiheadAttention, so copy it over. + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state_chosen_indices.data(), + chosen_indices.data(), + chosen_indices.size_bytes(), + cudaMemcpyHostToDevice, + cuda_stream)); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + } + #ifdef ENABLE_NVTX_PROFILE processLogitsRange.End(); #endif @@ -869,12 +891,15 @@ Status UpdateGptFeeds( OrtValue& position_ids, bool increase_position, gsl::span beam_next_tokens, - gsl::span beam_indices, + gsl::span beam_indices_cpu, + gsl::span beam_indices_gpu, int num_beams, int gpt_subgraph_first_past_input_idx, int gpt_subgraph_first_present_output_idx, bool past_present_share_buffer, - int past_sequence_len) { + int past_sequence_len, + int input_sequence_len, + bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) { #ifdef ENABLE_NVTX_PROFILE profile::NvtxNestedRangeCreator updateFeedsRange("UpdateGptFeeds", profile::Color::Yellow); updateFeedsRange.Begin(); @@ -889,6 +914,8 @@ Status UpdateGptFeeds( Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); cudaStream_t cuda_stream = ort_stream ? 
@@ -869,12 +891,15 @@ Status UpdateGptFeeds(
     OrtValue& position_ids,
     bool increase_position,
     gsl::span<const int32_t> beam_next_tokens,
-    gsl::span<const int32_t> beam_indices,
+    gsl::span<const int32_t> beam_indices_cpu,
+    gsl::span<const int32_t> beam_indices_gpu,
     int num_beams,
     int gpt_subgraph_first_past_input_idx,
     int gpt_subgraph_first_present_output_idx,
     bool past_present_share_buffer,
-    int past_sequence_len) {
+    int past_sequence_len,
+    int input_sequence_len,
+    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
 #ifdef ENABLE_NVTX_PROFILE
   profile::NvtxNestedRangeCreator updateFeedsRange("UpdateGptFeeds", profile::Color::Yellow);
   updateFeedsRange.Begin();
@@ -889,6 +914,8 @@ Status UpdateGptFeeds(
   Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
   int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
   cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
+
+  // TODO(hasesh): When BeamScorer is implemented on CUDA, figure out a way to avoid this cross-device copy
   CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data,
                                        beam_next_tokens.data(),
                                        beam_next_tokens.size_bytes(),
                                        cudaMemcpyHostToDevice,
                                        cuda_stream));
   next_inputs[0] = input_ids;
@@ -914,8 +941,41 @@ Status UpdateGptFeeds(
   next_inputs[2] = attention_mask;
 
   if (past_present_share_buffer) {
-    const int k = (static_cast<int>(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
-    *(next_inputs[k].GetMutable<Tensor>()->MutableData<int32_t>()) = past_sequence_len;
+    // Update the past sequence length input
+    const int past_sequence_length_idx = (static_cast<int>(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
+    *(next_inputs[past_sequence_length_idx].GetMutable<Tensor>()->MutableData<int32_t>()) = past_sequence_len;
+
+    // Update the beam search specific input for DecoderMaskedMultiheadAttention (cache indirection) if present.
+    // If the last input is not `past_sequence_length`, then the beam search specific
+    // inputs for `DecoderMaskedMultiheadAttention` are present.
+    if (has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
+      ORT_ENFORCE(!beam_indices_gpu.empty(),
+                  "Beam indices must be present on CUDA while using DecoderMaskedMultiheadAttention with BeamSearch");
+
+      // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed
+      const OrtValue& old_cache_indirection = next_inputs[past_sequence_length_idx + 2];
+
+      // New cache indirection updated for the next decoding run
+      OrtValue cache_indirection;
+      Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), old_cache_indirection.Get<Tensor>().Shape(),
+                           allocator, cache_indirection);
+
+      // Index 3 of the past/present tensor shape (2, batch_size, num_heads, max_sequence_length, head_size)
+      // is max_sequence_length
+      int max_sequence_length = static_cast<int>(last_outputs[gpt_subgraph_first_present_output_idx].Get<Tensor>().Shape()[3]);
+
+      // Launch kernel to update the cache indirection buffer
+      cuda::UpdateDecoderMaskedMultiheadAttentionCacheIndirection(cache_indirection.GetMutable<Tensor>()->MutableData<int32_t>(),
+                                                                  old_cache_indirection.Get<Tensor>().Data<int32_t>(),
+                                                                  reinterpret_cast<const int32_t*>(beam_indices_gpu.data()),
+                                                                  batch_beam_size / num_beams,
+                                                                  num_beams,
+                                                                  input_sequence_len,
+                                                                  max_sequence_length,
+                                                                  current_length,
+                                                                  cuda_stream);
+
+      // Update the cache indirection for the next decoding run
+      next_inputs[past_sequence_length_idx + 2] = cache_indirection;
+    }
   } else {
     if (num_beams == 1) {
       const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
       for (size_t i = gpt_subgraph_first_present_output_idx; i < last_outputs.size(); ++i) {
         next_inputs[i + k] = last_outputs[i];
       }
     } else {
-      ORT_RETURN_IF_ERROR(PickGptPastState(last_outputs, next_inputs, beam_indices, allocator,
+      ORT_RETURN_IF_ERROR(PickGptPastState(last_outputs, next_inputs, beam_indices_cpu, allocator,
                                            gpt_subgraph_first_past_input_idx,
                                            gpt_subgraph_first_present_output_idx, ort_stream));
     }
   }
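The feed layout the code above relies on: when past/present buffers are shared, the past_sequence_length feed sits right after the last past feed, and the two DecoderMaskedMultiheadAttention feeds (beam_width, then cache_indirection) follow it. A worked example with illustrative GPT-2 sizes (12 layers; the first-index constants match input_past_0 = 3 / output_past_0 = 1 used in convert_generation.py below):

    # Feeds: input_ids, position_ids, attention_mask, past_0 ... past_11, ...
    first_past_input_idx = 3
    first_present_output_idx = 1
    num_outputs = 1 + 12  # logits + 12 present tensors

    past_sequence_length_idx = (num_outputs - first_present_output_idx) + first_past_input_idx
    assert past_sequence_length_idx == 15

    # beam_width sits in between, so cache indirection is 2 feeds later:
    cache_indirection_idx = past_sequence_length_idx + 2
    assert cache_indirection_idx == 17

    # max_sequence_length is dim 3 of a present tensor shaped
    # (2, batch_size, num_heads, max_sequence_length, head_size):
    present_shape = (2, 1, 12, 1024, 64)
    max_sequence_length = present_shape[3]
    assert max_sequence_length == 1024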
@@ -1114,12 +1174,15 @@ template Status UpdateGptFeeds(
     OrtValue& position_ids,
     bool increase_position,
     gsl::span<const int32_t> beam_next_tokens,
-    gsl::span<const int32_t> beam_indices,
+    gsl::span<const int32_t> beam_indices_cpu,
+    gsl::span<const int32_t> beam_indices_gpu,
     int num_beams,
     int gpt_subgraph_first_past_input_idx,
     int gpt_subgraph_first_present_output_idx,
     bool past_present_share_buffer,
-    int past_sequence_len);
+    int past_sequence_len,
+    int input_sequence_len,
+    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
 
 // Float16
 template void InitBeamState<MLFloat16>(
@@ -1171,12 +1234,15 @@ template Status UpdateGptFeeds(
     OrtValue& position_ids,
     bool increase_position,
     gsl::span<const int32_t> beam_next_tokens,
-    gsl::span<const int32_t> beam_indices,
+    gsl::span<const int32_t> beam_indices_cpu,
+    gsl::span<const int32_t> beam_indices_gpu,
     int num_beams,
     int gpt_subgraph_first_past_input_idx,
     int gpt_subgraph_first_present_output_idx,
     bool past_present_share_buffer,
-    int past_sequence_len);
+    int past_sequence_len,
+    int input_sequence_len,
+    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
 
 template Status UpdateDecoderFeeds<MLFloat16>(
     AllocatorPtr allocator,
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h
index 13ea4931a5..42a7a32959 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h
@@ -97,12 +97,15 @@ Status UpdateGptFeeds(
     OrtValue& position_ids,
     bool increase_position,
     gsl::span<const int32_t> beam_next_tokens,
-    gsl::span<const int32_t> beam_indices,
+    gsl::span<const int32_t> beam_indices_cpu,
+    gsl::span<const int32_t> beam_indices_gpu,
     int num_beams,
     int gpt_subgraph_first_past_input_idx,
     int gpt_subgraph_first_present_output_idx,
     bool past_present_share_buffer,
-    int past_sequence_len);
+    int past_sequence_len,
+    int input_sequence_len,
+    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
 
 // ---------------------------------------------------------------
 // Functions for encoder-decoder model like T5
+ "If not provided, the beam width will be assumed to be 1.", + "M", + OpSchema::Optional) + .Input(8, + "cache_indirection", + "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies" + "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration", "M", OpSchema::Optional) .Output(0, @@ -390,8 +406,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA( "If past_present_share_buffer is set, " "its shape is (2, batch_size, num_heads, max_sequence_length, head_size), " "while effective_seq_length = (past_sequence_length + kv_sequence_length).", - "T", - OpSchema::Optional) + "T") .TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 37855c8c02..a7ea3b4e05 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -1026,7 +1026,7 @@ def shape_of(vi): return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim]) -def update_decoder_subgraph_past_present_share_buffer(subg): +def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto): input_past_0 = 3 output_past_0 = 1 new_inputs = [] @@ -1074,37 +1074,82 @@ def update_decoder_subgraph_past_present_share_buffer(subg): return subg -def update_decoder_subgraph_use_decoder_masked_multihead_attention(subg) -> bool: +def update_decoder_subgraph_use_decoder_masked_multihead_attention( + subg: GraphProto, is_beam_search: bool, switch_attention: bool +) -> bool: """Update the Attention nodes to DecoderMaskedMultiheadAttention. Args: subg (GraphProto): GraphProto of the decoder subgraph + is_beam_search (bool): Boolean specifying if the sampling algo is BeamSearch + switch_attention (bool): Boolean specifying if `Attention` is to be switched with `DecoderMaskedMultiheadAttention` """ - decoder_masked_attention_supported_attr = [ - "past_present_share_buffer", - "num_heads", - "scale", - "domain", - ] - new_nodes = [] - for node in subg.node: - if node.op_type == "Attention": - kwargs = kwargs_of(node) - for k in kwargs.copy(): - # The Attention operator does not support different qkv hidden sizes when past/present - # input/output exists (GPT2 model). Hence, we should never run into this. - # But, if we do, do not go ahead with the optimization. - if k == "qkv_hidden_sizes": - return False + if is_beam_search: + new_inputs = [] + for i, vi in enumerate(subg.input): + new_inputs.extend([vi]) + + # Add 2 BeamSearch specific inputs + new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])]) + new_inputs.extend( + [ + onnx.helper.make_tensor_value_info( + "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"] + ) + ] + ) + subg.ClearField("input") + subg.input.extend(new_inputs) + + if switch_attention: + decoder_masked_attention_supported_attr = [ + "past_present_share_buffer", + "num_heads", + "scale", + "domain", + ] + + new_nodes = [] + for node in subg.node: + if node.op_type == "Attention": + kwargs = kwargs_of(node) + for k in kwargs.copy(): + # The Attention operator does not support different qkv hidden sizes when past/present + # input/output exists (GPT2 model). Hence, we should never run into this. + # But, if we do, do not go ahead with the optimization. 
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 37855c8c02..a7ea3b4e05 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1026,7 +1026,7 @@ def shape_of(vi):
     return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
 
 
-def update_decoder_subgraph_past_present_share_buffer(subg):
+def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
     input_past_0 = 3
     output_past_0 = 1
     new_inputs = []
@@ -1074,37 +1074,82 @@ def update_decoder_subgraph_past_present_share_buffer(subg):
     return subg
 
 
-def update_decoder_subgraph_use_decoder_masked_multihead_attention(subg) -> bool:
+def update_decoder_subgraph_use_decoder_masked_multihead_attention(
+    subg: GraphProto, is_beam_search: bool, switch_attention: bool
+) -> bool:
     """Update the Attention nodes to DecoderMaskedMultiheadAttention.
 
     Args:
         subg (GraphProto): GraphProto of the decoder subgraph
+        is_beam_search (bool): Boolean specifying if the generation algorithm is BeamSearch
+        switch_attention (bool): Boolean specifying if `Attention` is to be switched with `DecoderMaskedMultiheadAttention`
     """
-    decoder_masked_attention_supported_attr = [
-        "past_present_share_buffer",
-        "num_heads",
-        "scale",
-        "domain",
-    ]
-    new_nodes = []
-    for node in subg.node:
-        if node.op_type == "Attention":
-            kwargs = kwargs_of(node)
-            for k in kwargs.copy():
-                # The Attention operator does not support different qkv hidden sizes when past/present
-                # input/output exists (GPT2 model). Hence, we should never run into this.
-                # But, if we do, do not go ahead with the optimization.
-                if k == "qkv_hidden_sizes":
-                    return False
+    if is_beam_search:
+        new_inputs = []
+        for vi in subg.input:
+            new_inputs.extend([vi])
+
+        # Add the 2 BeamSearch specific inputs
+        new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
+        new_inputs.extend(
+            [
+                onnx.helper.make_tensor_value_info(
+                    "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
+                )
+            ]
+        )
+        subg.ClearField("input")
+        subg.input.extend(new_inputs)
+
+    if switch_attention:
+        decoder_masked_attention_supported_attr = [
+            "past_present_share_buffer",
+            "num_heads",
+            "scale",
+            "domain",
+        ]
+
+        new_nodes = []
+        for node in subg.node:
+            if node.op_type == "Attention":
+                kwargs = kwargs_of(node)
+                for k in kwargs.copy():
+                    # The Attention operator does not support different qkv hidden sizes when past/present
+                    # input/output exists (GPT2 model). Hence, we should never run into this.
+                    # But, if we do, do not go ahead with the optimization.
+                    if k == "qkv_hidden_sizes":
+                        return False
+
+                    if k not in decoder_masked_attention_supported_attr:
+                        # Log the fact that we are removing certain attributes from the node.
+                        # We don't need to log it for "unidirectional" as we are aware that
+                        # decoding attention kernels are unidirectional by definition.
+                        if k != "unidirectional":
+                            logger.warning(
+                                f"Removing attribute: {k} from Attention node while switching to DecoderMaskedMultiheadAttention"
+                            )
+
+                        del kwargs[k]
+
+                nis = []
+                nis.extend(node.input)
+
+                # Add the 2 BeamSearch specific inputs
+                if is_beam_search:
+                    while len(nis) < 7:
+                        nis.extend([""])
+                    if len(nis) < 8:
+                        nis.extend(["beam_width"])
+                    if len(nis) < 9:
+                        nis.extend(["cache_indirection"])
+
+                node = onnx.helper.make_node(
+                    "DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs
+                )
+                new_nodes.extend([node])
+        subg.ClearField("node")
+        subg.node.extend(new_nodes)
 
-                if k not in decoder_masked_attention_supported_attr:
-                    del kwargs[k]
-            nis = []
-            nis.extend(node.input)
-            node = onnx.helper.make_node("DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs)
-            new_nodes.extend([node])
-    subg.ClearField("node")
-    subg.node.extend(new_nodes)
     return True
@@ -1452,9 +1497,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
     is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
     is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
     is_sampling: bool = generation_type == GenerationType.SAMPLING
-    past_present_share_buffer: bool = args.past_present_share_buffer and (is_greedysearch or is_sampling)
+    past_present_share_buffer: bool = args.past_present_share_buffer
 
-    logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}")
+    logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")
 
     if is_greedysearch or is_sampling:
         if not is_gpt2:
@@ -1464,6 +1509,29 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
         if args.output_token_scores:
             raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
 
+    # For BeamSearch, sharing buffers for past and present states is only supported
+    # when `use_decoder_masked_multihead_attention` is also enabled
+    if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_multihead_attention:
+        raise ValueError(
+            "`use_decoder_masked_multihead_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
+        )
+
+    # Using decoder masked multihead attention is only supported when
+    # `past_present_share_buffer` is enabled
+    if args.use_decoder_masked_multihead_attention and not past_present_share_buffer:
+        raise ValueError(
+            "`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
+        )
+
+    # Using decoder masked multihead attention is only supported on GPUs
+    if args.use_decoder_masked_multihead_attention and not args.use_gpu:
+        raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")
+
+    # Using decoder masked multihead attention is only supported for GPT2
+    if args.use_decoder_masked_multihead_attention and args.model_type in ["t5", "mt5"]:
+        raise ValueError("`use_decoder_masked_multihead_attention` option is only supported for GPT2")
+
     if is_gpt2:
         if args.decoder_onnx and os.path.exists(args.decoder_onnx):
             logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
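The input-padding dance in update_decoder_subgraph_use_decoder_masked_multihead_attention works because ONNX encodes a skipped optional input as an empty string: padding the node's input list out to position 7 keeps beam_width and cache_indirection at the operator's input indices 7 and 8. An isolated sketch of the same wiring (the existing input names here are hypothetical):

    import onnx

    inputs = ["input", "weights", "bias", "mask_index", "past"]  # hypothetical existing inputs
    while len(inputs) < 7:
        inputs.append("")  # "" = optional input intentionally left unwired
    inputs.extend(["beam_width", "cache_indirection"])

    node = onnx.helper.make_node(
        "DecoderMaskedMultiheadAttention",
        inputs,
        ["output", "present"],
        name="dmmha_0",
        domain="com.microsoft",
    )
    assert len(node.input) == 9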
existed: {args.decoder_onnx}") @@ -1699,6 +1767,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati node.attribute.extend(attr_to_extend) initializers = [] + if args.model_type in ["t5", "mt5"]: if args.run_shape_inference: logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.") @@ -1745,37 +1814,43 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati logger.info( f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph" ) - # Update for init decoder subgraph + + # Update init decoder subgraph in preparation to use past present share buffer if past_present_share_buffer: logger.info("*****update init decoder subgraph to make past and present share buffer******************") update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph) + + # Update init decoder subgraph in preparation to use DecoderMaskedMultiheadAttention + # NOTE: Even if we will not use DecoderMaskedMultiheadAttention in the init decoder subgraph + # it makes the runtime changes cleaner if we keep both the init decoder and decoder subgraphs + # same in terms of the subgraph inputs. + if ( + args.use_decoder_masked_multihead_attention + and not update_decoder_subgraph_use_decoder_masked_multihead_attention( + gpt2_init_decoder_model.graph, is_beamsearch, False + ) + ): + raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedMultiheadAttention") + node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph)) else: # Move initializer from subgraph to main graph could reduce memory usage in inference. initializers = move_initializers(decoder_model.graph) logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph") - # Update for non-init decoder subgraph + # Update decoder subgraph in preparation to use past present share buffer if past_present_share_buffer: logger.info("*****update decoder subgraph to make past and present share buffer******************") update_decoder_subgraph_past_present_share_buffer(decoder_model.graph) - # If at all, the user requests the use of `use_decoder_masked_multihead_attention`, - # we can only use it in the decoder subgraph - it cannot be used in the init_decoder subgraph - # as the kernel only supports the decoding use-case (i.e.) when input sequence length is 1. 
-    # If at all, the user requests the use of `use_decoder_masked_multihead_attention`,
-    # we can only use it in the decoder subgraph - it cannot be used in the init_decoder subgraph
-    # as the kernel only supports the decoding use-case (i.e.) when input sequence length is 1.
-    if args.use_decoder_masked_multihead_attention:
-        logger.info("Update decoder subgraph to use DecoderMaskedMultiheadAttention")
-
-        if not past_present_share_buffer:
-            raise ValueError(
-                "`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
-            )
-
-        if not args.use_gpu:
-            raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")
-
-        if not update_decoder_subgraph_use_decoder_masked_multihead_attention(decoder_model.graph):
-            raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
+    # Update the decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
+    if (
+        args.use_decoder_masked_multihead_attention
+        and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
+            decoder_model.graph, is_beamsearch, True
+        )
+    ):
+        raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
 
     node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))
diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py
index 7bc8c1db45..33122f08b7 100644
--- a/onnxruntime/test/python/transformers/test_generation.py
+++ b/onnxruntime/test/python/transformers/test_generation.py
@@ -156,6 +156,18 @@ class TestBeamSearchGpt(unittest.TestCase):
         if self.enable_cuda:
             self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True)
 
+    @pytest.mark.slow
+    def test_beam_search_use_decoder_masked_multihead_attention(self):
+        if self.enable_cuda:
+            self.run_beam_search("--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu")
+
+    @pytest.mark.slow
+    def test_beam_search_use_decoder_masked_multihead_attention_fp16(self):
+        if self.enable_cuda:
+            self.run_beam_search(
+                "--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu -p fp16"
+            )
+
     @pytest.mark.slow
     def test_external_data(self):
         self.run_beam_search(
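For reference, the new tests drive the converter with the gating flags together. A hypothetical end-to-end invocation might look like the sketch below; only --use_gpu, --past_present_share_buffer, --use_decoder_masked_multihead_attention, and -p are taken from this patch, while the model and output arguments are placeholders (consult the script's argument parser for the actual CLI):

    import subprocess

    subprocess.run(
        [
            "python",
            "convert_generation.py",  # onnxruntime/python/tools/transformers/convert_generation.py
            "-m", "gpt2",             # placeholder model argument
            "--output", "gpt2_beam_search_fp16.onnx",  # placeholder output path
            "--use_gpu",
            "--past_present_share_buffer",
            "--use_decoder_masked_multihead_attention",
            "-p", "fp16",
        ],
        check=True,
    )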