diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 5862716638..66afbebae5 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -1133,7 +1133,7 @@ This version of the operator has been available since version 1 of the 'com.micr
Custom scale will be used if specified. Default value is 1/sqrt(head_size)
-#### Inputs (3 - 7)
+#### Inputs (7 - 9)
- input : T
@@ -1144,20 +1144,24 @@ This version of the operator has been available since version 1 of the 'com.micr
- Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection
- mask_index (optional) : M
- Mask values of shape (batch_size, total_sequence_length)
-- past (optional) : T
-- past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)
+- past : T
+- past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered: its sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size), viewed as (batch_size, num_heads, max_sequence_length, head_size / x, x), is stored as (batch_size, num_heads, head_size / x, max_sequence_length, x), where `x = 16 / sizeof(T)`.
- relative_position_bias (optional) : T
- additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)
-- past_sequence_length (optional) : M
+- past_sequence_length : M
- When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).
+- beam_width (optional) : M
+- The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.
+- cache_indirection (optional) : M
+- A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies which beam the k-th token came from for the j-th beam of batch i in the current iteration.
-#### Outputs (1 - 2)
+#### Outputs
- output : T
- 3D output tensor with shape (batch_size, sequence_length, v_hidden_size)
-- present (optional) : T
+- present : T
- past state for key and value with shape (2, batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).
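The key re-ordering described for the `past` input can be illustrated with a small host-side sketch. This is not the onnxruntime implementation (the actual re-order is performed on the GPU); `ReorderedKeyOffset` and `ReorderKeys` are hypothetical helper names introduced only for illustration, and the sketch assumes `head_size` is divisible by `x = 16 / sizeof(T)`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper (not an onnxruntime API): offset of logical element
// (b, n, l, h) inside a keys buffer stored as
// (batch_size, num_heads, head_size / x, max_sequence_length, x).
inline std::size_t ReorderedKeyOffset(std::size_t b, std::size_t n, std::size_t l, std::size_t h,
                                      std::size_t num_heads, std::size_t max_sequence_length,
                                      std::size_t head_size, std::size_t x) {
  const std::size_t chunks = head_size / x;  // number of x-element chunks per head
  const std::size_t chunk = h / x;           // which chunk of the head dimension
  const std::size_t lane = h % x;            // position inside that chunk
  return (((b * num_heads + n) * chunks + chunk) * max_sequence_length + l) * x + lane;
}

// Hypothetical helper: copies keys laid out as
// (batch_size, num_heads, max_sequence_length, head_size) into the re-ordered layout.
// Assumes head_size % (16 / sizeof(T)) == 0.
template <typename T>
std::vector<T> ReorderKeys(const std::vector<T>& keys, std::size_t batch_size,
                           std::size_t num_heads, std::size_t max_sequence_length,
                           std::size_t head_size) {
  const std::size_t x = 16 / sizeof(T);  // elements per 16-byte chunk
  std::vector<T> out(keys.size());
  for (std::size_t b = 0; b < batch_size; ++b)
    for (std::size_t n = 0; n < num_heads; ++n)
      for (std::size_t l = 0; l < max_sequence_length; ++l)
        for (std::size_t h = 0; h < head_size; ++h) {
          // Row-major offset in the original (B, N, L, H) layout.
          const std::size_t src = ((b * num_heads + n) * max_sequence_length + l) * head_size + h;
          out[ReorderedKeyOffset(b, n, l, h, num_heads, max_sequence_length, head_size, x)] = keys[src];
        }
  return out;
}
```

For `float` (so `x = 4`), with one batch, one head, `max_sequence_length = 2` and `head_size = 8`, the first four elements of each token's head end up adjacent to the first four elements of the next token's head, so consecutive tokens can be read in 16-byte chunks.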
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 1e687e1cb4..3bd890e5b4 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -798,7 +798,7 @@ Do not modify directly.*
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
-|DecoderMaskedMultiheadAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiheadAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br> **Q** = tensor(int8)<br> **S** = tensor(float)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index c839bb4245..de2513789c 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -163,6 +163,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
if (has_init_decoder_) {
ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
+ ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
+ "past_present_share_buffer mode must be the same for the init decoder and decoder subgraphs");
}
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
@@ -181,12 +183,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
thread_pool, ctx->GetComputeStream(), dumper_, parameters,
GenerationCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+ reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy,
- update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds};
+ update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds,
+ cuda_device_prop_,
+ cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -200,12 +205,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
thread_pool, ctx->GetComputeStream(), dumper_, parameters,
GenerationCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+ reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
- update_gpt_feeds_fp16_func_};
+ update_gpt_feeds_fp16_func_,
+ cuda_device_prop_,
+ cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
index fe11a8a2ba..37b94bd8f1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -44,6 +44,7 @@ class BeamSearch : public IControlFlowKernel {
// device helpers that is same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
+ const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func,
@@ -52,6 +53,7 @@ class BeamSearch : public IControlFlowKernel {
const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_fp16_func,
const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_func,
const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_fp16_func) {
+ reorder_past_state_func_ = reorder_past_state_func;
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
device_copy_func_ = device_copy_func;
@@ -83,8 +85,12 @@ class BeamSearch : public IControlFlowKernel {
expand_buffer_float16_func_ = expand_buffer_float16_func;
}
+ const void* cuda_device_prop_ = nullptr;
+ int cuda_device_arch_ = 0;
+
private:
// Device specific functions
+ GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::TopkFunc topk_func_;
GenerationDeviceHelper::DeviceCopyFunc device_copy_func_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
index 490a68d240..75d161a2cd 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -20,6 +20,9 @@ struct BeamSearchState : public IBeamSearchState {
int vocab_size,
int sequence_length,
int max_length,
+ int num_heads,
+ int head_size,
+ int has_decoder_masked_multihead_attention,
bool output_scores,
bool use_position) {
size_t batch_beam_size = SafeInt(batch_size) * num_beams;
@@ -49,6 +52,21 @@ struct BeamSearchState : public IBeamSearchState {
this->scores = AllocateBuffer(allocator, scores_buffer_, elements);
this->remaining_scores = this->scores;
}
+
+ if (has_decoder_masked_multihead_attention) {
+ // We need a temp staging buffer to do the past 'K' state re-ordering that is needed
+ // when using DecoderMaskedMultiheadAttention
+ TensorShape staging_for_past_state_reorder_buffer_shape = {static_cast(batch_beam_size), num_heads, max_length, head_size};
+
+ Tensor temp(DataTypeImpl::GetType(), staging_for_past_state_reorder_buffer_shape, allocator);
+
+ this->staging_for_past_state_reorder = std::move(temp);
+
+ // We need a buffer on GPU to hold the final chosen indices after BeamScorer has finished processing
+ // TODO: This is a temporary work-around as BeamScorer currently only runs on CPU.
+ // We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually.
+ this->chosen_indices = AllocateBuffer(allocator, chosen_indices_buffer_, batch_beam_size);
+ }
}
private:
@@ -61,6 +79,7 @@ struct BeamSearchState : public IBeamSearchState {
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr scores_buffer_;
BufferUniquePtr topk_temp_buffer_;
+ BufferUniquePtr chosen_indices_buffer_;
};
struct BeamSearchCpuState : public IBeamSearchCpuState {
@@ -124,7 +143,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
// Base class of beam search implementation that is common for both GPT-2 and T5.
template
-class BeamSearchBase : public GenerateBase {
+class BeamSearchBase : public GenerateBase {
public:
BeamSearchBase(OpKernelContextInternal& context,
const SessionState& decoder_session_state,
@@ -136,13 +155,13 @@ class BeamSearchBase : public GenerateBase {
const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_int32_func)
- : GenerateBase(context,
- decoder_session_state,
- thread_pool,
- ort_stream,
- cuda_dumper,
- topk_func,
- device_copy_func),
+ : GenerateBase(context,
+ decoder_session_state,
+ thread_pool,
+ ort_stream,
+ cuda_dumper,
+ topk_func,
+ device_copy_func),
parameters_(&params),
process_logits_func_(process_logits_func),
device_copy_int32_func_(device_copy_int32_func) {
@@ -188,11 +207,11 @@ Status BeamSearchBase::CheckInputs(const OpKernelContextInternal& context) {
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
ORT_RETURN_IF_ERROR(this->CheckInputsImpl(parameters_,
- context.Input(0), // input_ids
- context.Input(7), // vocab_mask
- context.Input(8), // prefix_vocab_mask
- context.Input(9), // attention_mask
- nullptr)); // presence_mask
+ context.Input(0), // input_ids
+ context.Input(7), // vocab_mask
+ context.Input(8), // prefix_vocab_mask
+ context.Input(9), // attention_mask
+ nullptr)); // presence_mask
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
index 46697a202d..dc6d33da99 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
@@ -27,12 +27,15 @@ class BeamSearchGpt : public BeamSearchBase {
BeamSearchParameters& params,
const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
+ const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_func,
const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_int32_func,
- const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func)
+ const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func,
+ const void* cuda_device_prop,
+ int cuda_device_arch)
: BeamSearchBase(context, decoder_session_state, thread_pool,
ort_stream, cuda_dumper, params,
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
@@ -42,7 +45,17 @@ class BeamSearchGpt : public BeamSearchBase {
create_inputs_func_(create_inputs_func),
add_to_feeds_func_(add_to_feeds_func),
init_beam_state_func_(init_beam_state_func),
- update_feeds_func_(update_feeds_func) {
+ reorder_past_state_func_(reorder_past_state_func),
+ update_feeds_func_(update_feeds_func),
+ cuda_device_prop_(cuda_device_prop),
+ cuda_device_arch_(cuda_device_arch) {
+ if (gpt_subgraph_.has_decoder_masked_multihead_attention_) {
+ ORT_ENFORCE(cuda_device_arch_ >= 530,
+ "Decoder masked multihead attention can only be used on "
+ "GPU cards of compute capability 5.3 or higher. "
+ "This card has compute capability ",
+ cuda_device_arch_);
+ }
}
// Execute beam search in iterations util stopping criteria is reached.
@@ -55,7 +68,8 @@ class BeamSearchGpt : public BeamSearchBase {
Status CreateInitialFeeds(gsl::span& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector& feeds,
- IAllocatorUniquePtr& buffer);
+ IAllocatorUniquePtr& buffer,
+ bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// Update the input for next iteration.
Status UpdateFeeds(
@@ -65,7 +79,11 @@ class BeamSearchGpt : public BeamSearchBase {
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices);
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
+ int past_sequence_length,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
const SessionState* init_run_decoder_session_state_ = nullptr;
GptSubgraph* init_run_gpt_subgraph_ = nullptr;
@@ -75,14 +93,19 @@ class BeamSearchGpt : public BeamSearchBase {
GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_;
+ GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_;
+
+ const void* cuda_device_prop_ = nullptr;
+ int cuda_device_arch_ = 0;
};
template
Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector& feeds,
- IAllocatorUniquePtr& buffer) {
+ IAllocatorUniquePtr& buffer,
+ bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0);
const Tensor& input_ids = input_ids_value->Get();
const OrtValue* attn_mask_value = this->context_.GetInputOrtValue(9);
@@ -99,7 +122,9 @@ Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths
this->create_inputs_func_,
this->add_to_feeds_func_,
buffer,
- this->ort_stream_);
+ this->ort_stream_,
+ this->parameters_->max_length,
+ add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}
return gpt_subgraph_.CreateInitialFeeds(input_ids,
@@ -113,7 +138,9 @@ Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths
this->create_inputs_func_,
this->add_to_feeds_func_,
buffer,
- this->ort_stream_);
+ this->ort_stream_,
+ this->parameters_->max_length,
+ add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}
template
@@ -124,7 +151,11 @@ Status BeamSearchGpt::UpdateFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices) {
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
+ int past_sequence_length,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
return update_feeds_func_(this->temp_space_allocator_,
this->ort_stream_,
last_outputs,
@@ -133,12 +164,15 @@ Status BeamSearchGpt::UpdateFeeds(
position_ids,
increase_position,
beam_next_tokens,
- beam_indices,
+ beam_indices_cpu,
+ beam_indices_gpu,
this->parameters_->num_beams,
gpt_subgraph_.GetFirstPastInputIndex(),
gpt_subgraph_.GetFirstPresentOutputIndex(),
- false,
- -1);
+ gpt_subgraph_.past_present_share_buffer_,
+ past_sequence_length,
+ input_sequence_len,
+ has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}
template
@@ -192,7 +226,22 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch
// buffer in GPU for input_ids, position_ids and attention_mask
IAllocatorUniquePtr buffer;
OrtValue expanded_input_ids_in_cpu;
- ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
+ ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer,
+ gpt_subgraph_.has_decoder_masked_multihead_attention_));
+
+ if (gpt_subgraph_.past_present_share_buffer_) { // Reuse past and present
+ fetches.reserve(static_cast(gpt_subgraph_.GetFirstPresentOutputIndex()) + gpt_subgraph_.num_layers);
+ fetches.resize(gpt_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
+ for (int layer = 0; layer < gpt_subgraph_.num_layers; layer++) {
+ int feed_idx = gpt_subgraph_.GetFirstPastInputIndex() + layer;
+ OrtValue& past_tensor_value = feeds[feed_idx];
+ Tensor* past_tensor = past_tensor_value.GetMutable();
+ OrtValue present_tensor_value;
+ Tensor::InitOrtValue(past_tensor->DataType(), past_tensor->Shape(), past_tensor->MutableData(),
+ past_tensor->Location(), present_tensor_value);
+ fetches.push_back(present_tensor_value);
+ }
+ }
BeamSearchState beam_state;
constexpr bool use_position = true;
@@ -202,6 +251,9 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch
parameters->vocab_size,
parameters->sequence_length,
parameters->max_length,
+ parameters->num_heads,
+ parameters->head_size,
+ gpt_subgraph_.has_decoder_masked_multihead_attention_,
parameters->output_scores,
use_position);
@@ -297,17 +349,49 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch
// Increase sequence length after a new token is generated.
++current_length;
+ // Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
+ // contains DecoderMaskedMultiheadAttention nodes
+ if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) {
+ size_t offset = static_cast<size_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
+ // We will use the same staging buffer while transposing all the layers' past state
+ // and this is okay because we use the same stream to do the staging copy and the transpose
+ // operations.
+ // If we ever do them in different streams, we must use different staging buffers to avoid data
+ // races.
+ for (size_t i = 0; i < static_cast<size_t>(gpt_subgraph_.num_layers); ++i) {
+ ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
+ *fetches[offset + i].GetMutable(),
+ beam_state.staging_for_past_state_reorder,
+ this->ort_stream_));
+ }
+ }
+
// Prepare inputs for next round of subgraph call.
if (current_length < parameters->max_length) {
+ gsl::span place_holder;
// For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly.
// For the remaining iterations, we need increase position_ids first, then add it to feeds.
bool increase_position = (iteration_counter > 1);
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
position_ids, increase_position,
ReinterpretAsSpan(beam_next_tokens),
- ReinterpretAsSpan(beam_indices)));
+ ReinterpretAsSpan(beam_indices),
+ gpt_subgraph_.has_decoder_masked_multihead_attention_
+ ? ReinterpretAsSpan(beam_state.chosen_indices)
+ : place_holder,
+ current_length - 1,
+ parameters->sequence_length,
+ gpt_subgraph_.has_decoder_masked_multihead_attention_));
+ }
+
+ if (gpt_subgraph_.past_present_share_buffer_) {
+ // clear fetched values before presents[]
+ for (int idx = 0; idx < gpt_subgraph_.GetFirstPresentOutputIndex(); idx++) {
+ fetches[idx] = OrtValue();
+ }
+ } else {
+ fetches.clear();
}
- fetches.clear();
}
gsl::span final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
index 9841f425d2..562f10c615 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
@@ -193,6 +193,9 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches
parameters->vocab_size,
parameters->sequence_length,
parameters->max_length,
+ parameters->num_heads,
+ parameters->head_size,
+ false, // TODO: Support past/present state buffer re-use for T5
parameters->output_scores,
use_position);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
index e2cd625093..fdd404efa0 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
@@ -94,7 +94,7 @@ class BeamSearchScorer : public IBeamScorer {
gsl::span& GetNextScores() { return next_beam_scores_; }
gsl::span& GetNextTokens() { return next_beam_tokens_; }
- gsl::span& GetNextIndices() { return next_beam_indices_; }
+ gsl::span& GetNextIndices() override { return next_beam_indices_; }
private:
size_t batch_size_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
index 17c7f1e6c6..a9cde364a3 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -407,18 +407,18 @@ Status ProcessLogits(const OrtValue& logits, //
template
Status GreedySearchProcessLogits(
- const OrtValue& logits, // logits output of subgraph
- transformers::IGreedySearchState* greedy_state, // state
- transformers::ISamplingState* sampling_state, // sampling_state
- transformers::ISequences* sequences, // sequences
- AllocatorPtr& allocator, // default allocator
- onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
- transformers::ILogitsProcessorList* logits_processors, // logits processors
- const transformers::IGenerationParameters* parameters, // parameters
- bool do_sampling, // whether to do sampling
- int step, // iteration counter
- Stream* stream, // cuda stream (for CUDA only)
- const transformers::IConsoleDumper* dumper) { // tensor dumper
+ const OrtValue& logits, // logits output of subgraph
+ transformers::IGreedySearchState* greedy_state, // state
+ transformers::ISamplingState* sampling_state, // sampling_state
+ transformers::ISequences* sequences, // sequences
+ AllocatorPtr& allocator, // default allocator
+ onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
+ transformers::ILogitsProcessorList* logits_processors, // logits processors
+ const transformers::IGenerationParameters* parameters, // parameters
+ bool do_sampling, // whether to do sampling
+ int step, // iteration counter
+ Stream* stream, // cuda stream (for CUDA only)
+ const transformers::IConsoleDumper* dumper) { // tensor dumper
int batch_size = parameters->batch_size;
int vocab_size = parameters->vocab_size;
@@ -499,8 +499,8 @@ Status GreedySearchProcessLogits(
topk_indices));
#ifdef DEBUG_GENERATION
- dumper->Print("topk_scores", topk_scores);
- dumper->Print("topk_indices", topk_indices);
+ dumper->Print("topk_scores", topk_scores);
+ dumper->Print("topk_indices", topk_indices);
#endif
gsl::span next_token_indices = topk_indices.DataAsSpan();
@@ -574,15 +574,21 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len) {
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
// last_outputs: logits, present_0, present_1, ...
// next_inputs: input_ids, position_id, attention_mask, past_0, past_1
ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(beam_indices_gpu);
+ ORT_UNUSED_PARAMETER(input_sequence_len);
+ ORT_UNUSED_PARAMETER(has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// The following updates inputs for subgraph
@@ -631,14 +637,14 @@ Status UpdateGptFeeds(
return Status::OK();
}
- if (num_beams == 1) { // Update past state
+ if (num_beams == 1) { // Update past state
// feed present_* output to past_* inputs one by one
const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
for (size_t i = gpt_subgraph_first_present_output_idx; i < last_outputs.size(); ++i) {
next_inputs[i + k] = last_outputs[i];
}
} else {
- PickGptPastState(last_outputs, next_inputs, beam_indices,
+ PickGptPastState(last_outputs, next_inputs, beam_indices_cpu,
gpt_subgraph_first_past_input_idx,
gpt_subgraph_first_present_output_idx, allocator);
}
@@ -890,12 +896,15 @@ template Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len);
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
template Status UpdateDecoderFeeds(
AllocatorPtr allocator,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
index 473f194205..6a9b2e93ec 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
@@ -128,12 +128,15 @@ using UpdateGptFeedsFunc = std::function beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len)>;
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention)>;
// Create encoder inputs (for encoder-decoder model like T5).
using CreateEncoderInputsFunc = std::function beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len);
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index 9aed3eb3fd..3faba9a856 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -38,6 +38,11 @@ struct IBeamSearchState {
// temp token: (batch_size * num_beams, 2 * num_beams)
// in total, it will be:
// 2 * (batch_size * num_beams * (parts_vocab + 1), 2 * num_beams)
+
+ // The final chosen indices after BeamScorer has finished processing
+ gsl::span chosen_indices; // shape (batch_size, num_beams)
+
+ Tensor staging_for_past_state_reorder; // Tensor of shape (batch_size * num_beams, num_heads, max_length, head_size)
};
struct IBeamSearchCpuState {
@@ -117,6 +122,8 @@ class IBeamScorer {
gsl::span& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;
+
+ virtual gsl::span& GetNextIndices() = 0;
};
struct IGenerationParameters {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
index ecd40bfdfa..ccf387a895 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
@@ -81,7 +81,7 @@ struct GreedySearchState : public IGreedySearchState {
int max_length,
int num_heads,
int head_size,
- bool allocate_staging_buffer_for_past_state_reorder,
+ bool has_decoder_masked_multihead_attention,
bool is_cuda) {
// below buffers are on cpu
this->sequences_space = AllocateBuffer(cpu_allocator,
@@ -111,8 +111,9 @@ struct GreedySearchState : public IGreedySearchState {
this->topk_scores_buffer,
this->topk_tokens_buffer);
- // If at all we need to, we only need to re-order past state for CUDA
- if (allocate_staging_buffer_for_past_state_reorder) {
+ // Past state only ever needs re-ordering on CUDA, because
+ // `DecoderMaskedMultiheadAttention` is only supported on CUDA
+ if (has_decoder_masked_multihead_attention) {
TensorShape staging_for_past_state_reorder_buffer_shape = {batch_size, num_heads, max_length, head_size};
Tensor temp(DataTypeImpl::GetType(), staging_for_past_state_reorder_buffer_shape, allocator);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h
index fc9cc01c86..d1afcbc6c6 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h
@@ -170,11 +170,14 @@ Status GreedySearchGpt::UpdateFeeds(
increase_position,
next_tokens,
place_holder,
+ place_holder,
this->parameters_->num_beams,
gpt_subgraph_.GetFirstPastInputIndex(),
gpt_subgraph_.GetFirstPresentOutputIndex(),
gpt_subgraph_.past_present_share_buffer_,
- past_sequence_length);
+ past_sequence_length,
+ -1, // Input sequence length needn't be passed in for GreedySearch
+ false);
}
template
@@ -329,7 +332,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedMultiheadAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) {
- size_t offset = static_cast(gpt_subgraph_.GetFirstPresentOutputIndex());
+ size_t offset = static_cast(gpt_subgraph_.GetFirstPresentOutputIndex());
// We will use the same staging buffer while transposing all the layers' past state
// and this is okay because we use the same stream to do the staging copy and the transpose
// operations.
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
index 6703c6eec9..012102b45c 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
@@ -91,6 +91,9 @@ Status Subgraph::Setup(const SessionState& session_state,
past_present_share_buffer_ = true;
// past_sequence_length is on CPU memory
feed_locations.push_back(OrtDevice());
+ } else if (feed_names[i] == "beam_width") {
+ // beam_width is on CPU memory
+ feed_locations.push_back(OrtDevice());
} else {
feed_locations.push_back(default_location.device);
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc
index 3d8af10737..c12301803a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc
@@ -27,7 +27,8 @@ Status GptSubgraph::CreateInitialFeeds(
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr& buffer,
Stream* ort_stream,
- int max_seq_len_past_present_share_buffer) {
+ int past_present_share_buffer_max_seq_len,
+ bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
const IExecutionProvider* provider = GetProvider();
@@ -83,20 +84,48 @@ Status GptSubgraph::CreateInitialFeeds(
feeds.push_back(empty_past);
}
} else {
- int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, max_seq_len_past_present_share_buffer, head_size};
+ // Past state feeds
+ int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, past_present_share_buffer_max_seq_len, head_size};
TensorShape past_shape(&past_state_dims[0], 5);
- // The remaining inputs are past state except the last one
- for (int i = first_past_input_index_; i < num_subgraph_inputs - 1; ++i) {
+
+ // The remaining inputs are past state except the last one or three (see below for details)
+ // If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is false, then the last input is `past_sequence_length`
+
+ // If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is true, then the last inputs are `past_sequence_length`,
+ // `beam_width`, and `cache_indirection`
+ auto past_end_iter = add_beam_search_specific_inputs_for_decoder_masked_multihead_attention ? num_subgraph_inputs - 3 : num_subgraph_inputs - 1;
+ for (int i = first_past_input_index_; i < past_end_iter; ++i) {
OrtValue past_tensor;
Tensor::InitOrtValue(past_type, past_shape, default_allocator, past_tensor);
feeds.push_back(past_tensor);
}
+
+ // Past sequence length feed
int64_t past_seq_len_dims[] = {1};
TensorShape past_seq_len_shape(&past_seq_len_dims[0], 1);
OrtValue past_seq_len_tensor_value;
Tensor::InitOrtValue(DataTypeImpl::GetType(), past_seq_len_shape, cpu_allocator, past_seq_len_tensor_value);
feeds.push_back(past_seq_len_tensor_value);
*past_seq_len_tensor_value.GetMutable()->MutableData() = 0;
+
+ // Add beam search specific inputs
+ if (add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
+ // Beam width feed
+ int64_t num_beams_dims[] = {1};
+ TensorShape num_beams_shape(&num_beams_dims[0], 1);
+ OrtValue num_beams_tensor_value;
+ Tensor::InitOrtValue(DataTypeImpl::GetType(), num_beams_shape, cpu_allocator, num_beams_tensor_value);
+ feeds.push_back(num_beams_tensor_value);
+ *num_beams_tensor_value.GetMutable()->MutableData() = static_cast(num_beams);
+
+ // Cache indirection feed
+ int64_t cache_indirection_dims[] = {batch_size, num_beams, past_present_share_buffer_max_seq_len};
+ TensorShape cache_indirection_shape(&cache_indirection_dims[0], 3);
+ OrtValue default_cache_indirection;
+ Tensor::InitOrtValue(DataTypeImpl::GetType(), cache_indirection_shape,
+ default_allocator, default_cache_indirection);
+ feeds.push_back(default_cache_indirection);
+ }
}
// Pass in implicit inputs
@@ -112,8 +141,12 @@ Status GptSubgraph::Validate(const std::vector& subgraph_inputs,
ORT_RETURN_IF(num_subgraph_outputs <= first_present_output_index_,
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs).");
- ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) || (num_subgraph_inputs == num_subgraph_outputs + 3)),
- "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or 3 (if past_present_share_buffer)");
+ ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) ||
+ (num_subgraph_inputs == num_subgraph_outputs + 3) ||
+ (num_subgraph_inputs == num_subgraph_outputs + 5)),
+ "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or "
+ "3 (if past_present_share_buffer) or "
+ "5 (if past_present_share_buffer and use_decoder_masked_multihead_attention for BeamSearch)");
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h
index 855d017e64..c24ac3a5d1 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h
@@ -16,9 +16,9 @@ class GptSubgraph : public Subgraph {
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {
- first_past_input_index_ = 3;
- first_present_output_index_ = 1;
- }
+ first_past_input_index_ = 3;
+ first_present_output_index_ = 1;
+ }
// Create inputs for first inference of subgraph.
Status CreateInitialFeeds(
@@ -34,7 +34,8 @@ class GptSubgraph : public Subgraph {
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr& buffer,
Stream* ort_stream,
- int max_seq_len_past_present_share_buffer = -1);
+ int past_present_share_buffer_max_seq_len = -1,
+ bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention = false);
Status Validate(const std::vector& subgraph_inputs,
const std::vector& subgraph_outputs) override;
diff --git a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc
index 8e28bb253e..49000855f2 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc
+++ b/onnxruntime/contrib_ops/cuda/decoder/decoder_masked_multihead_attention.cc
@@ -16,20 +16,23 @@ namespace contrib {
namespace cuda {
static constexpr int kPastSequenceLengthInputIndex = 6;
+static constexpr int kBeamWidthInputIndex = 7;
+static constexpr int kCacheIndirectionInputIndex = 8;
static constexpr int kPastInputIndex = 4;
static constexpr int kPresentOutputIndex = 1;
-#define REGISTER_KERNEL_TYPED(T1, T2) \
- ONNX_OPERATOR_TYPED_KERNEL_EX( \
- DecoderMaskedMultiheadAttention, \
- kMSDomain, \
- 1, \
- T1, \
- kCudaExecutionProvider, \
- (*KernelDefBuilder::Create()) \
- .MayInplace(kPastInputIndex, kPresentOutputIndex) \
- .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
- .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
+#define REGISTER_KERNEL_TYPED(T1, T2) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ DecoderMaskedMultiheadAttention, \
+ kMSDomain, \
+ 1, \
+ T1, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .MayInplace(kPastInputIndex, kPresentOutputIndex) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()) \
+ .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \
+ .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \
DecoderMaskedMultiheadAttention);
REGISTER_KERNEL_TYPED(float, float)
@@ -44,6 +47,8 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext*
const Tensor* past = context->Input(kPastInputIndex);
const Tensor* relative_position_bias = context->Input(5);
const Tensor* past_seq_len = context->Input(kPastSequenceLengthInputIndex);
+ const Tensor* beam_width = context->Input(kBeamWidthInputIndex);
+ const Tensor* cache_indir = context->Input(kCacheIndirectionInputIndex);
auto& device_prop = GetDeviceProp();
DecoderMaskedMultiheadAttentionParams parameters;
@@ -105,7 +110,7 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext*
auto* present_data = present->MutableData();
auto* past_data = past->Data();
- // No production use-case will incure this copy cost as the implementation of
+ // No production use-case will incur this copy cost as the implementation of
// GreedySearch/BeamSearch is written in such a way that the past and present buffers
// will be shared.
// This is just to circumvent the OpTester's limitation of not being able to bind a specific
@@ -142,13 +147,13 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext*
// Update the q, k, and v buffers
parameters.q = gemm_buffer.get();
parameters.k = reinterpret_cast(gemm_buffer.get()) + parameters.hidden_size;
- parameters.v = reinterpret_cast(gemm_buffer.get()) + 2 * parameters.hidden_size;
+ parameters.v = reinterpret_cast(gemm_buffer.get()) + 2 * static_cast(parameters.hidden_size);
// Update the q, k, and v bias
const T1* bias_data = bias->Data();
parameters.q_bias = const_cast(bias_data);
parameters.k_bias = const_cast(bias_data + parameters.hidden_size);
- parameters.v_bias = const_cast(bias_data + 2 * parameters.hidden_size);
+ parameters.v_bias = const_cast(bias_data + 2 * static_cast(parameters.hidden_size));
// Half of the past/present buffer correspond to K - the other half is V.
auto k_size = present->Shape().Size() / 2;
@@ -167,6 +172,22 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext*
parameters.mask = mask_index->Data();
}
+ // Beam width (in case we are using this op inside BeamSearch)
+ if (beam_width != nullptr) {
+ parameters.beam_width = static_cast(*beam_width->Data());
+ }
+
+ // Cache indirection (in case we are using this op inside BeamSearch)
+ if (parameters.beam_width > 1) {
+ // If beam width > 1, then cache indirection buffer MUST be present
+ if (cache_indir == nullptr) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "If beam width is greater than 1, then cache indirection buffer MUST be present");
+ }
+
+ parameters.cache_indir = cache_indir->Data();
+ }
+
switch (parameters.head_size) {
case 64:
mmha_launch_kernel(parameters, cuda_stream);
@@ -182,7 +203,6 @@ Status DecoderMaskedMultiheadAttention::ComputeInternal(OpKernelContext*
"Got head size: ",
parameters.head_size);
}
-
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
index 39dec84663..27a2b6de83 100644
--- a/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/decoder/fastertransformer_decoder_attention/decoder_masked_multihead_attention_impl.cu
@@ -22,6 +22,8 @@
// Modifications:
// (1) Removed some code paths from the original implementation that had features which are not supported by
// the corresponding ORT kernel, for example CrossAttention, FP8, and INT8 support.
+// (2) When dealing with masked tokens, this kernel implementation deviates from FasterTransformer by applying
+// mask filter values. Appropriate commentary exists in the code below.
#include "decoder_masked_multihead_attention_impl.h"
#include "decoder_masked_multihead_attention_impl_utils.h"
@@ -286,7 +288,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
bool has_beams = params.cache_indir != nullptr;
-
const int* beam_indices = has_beams ? &params.cache_indir[bi_max_seq_length] : nullptr;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
@@ -294,16 +295,24 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// The keys loaded from the key cache.
K_vec_k k_vec[K_VECS_PER_THREAD];
-#pragma unroll
- for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
- int jj = ii * params.max_sequence_length + ti;
- if (ti < tlength) {
- if (has_beams) {
+ if (has_beams) {
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.max_sequence_length + ti;
+
+ if (ti < tlength) {
const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
k_vec[ii] = vec_conversion(
(*reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
- } else {
+ }
+ }
+ } else {
+#pragma unroll
+ for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+ int jj = ii * params.max_sequence_length + ti;
+
+ if (ti < tlength) {
k_vec[ii] = vec_conversion(
(*reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B])));
}
@@ -314,9 +323,16 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
float qk = Qk_dot::dot(q_vec, k_vec) * inv_sqrt_dh;
+ // This is a deviation from the FasterTransformer kernel implementation,
+ // but it aligns with ORT's other Attention kernels, which strive to
+ // mimic PyTorch when dealing with mask filter values
+ if (is_masked) {
+ qk += params.mask_filter_value;
+ }
+
// Store the product to shared memory. There's one qk value per timestep. Update the max.
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
- qk_max = is_masked ? qk_max : fmaxf(qk_max, qk);
+ qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
}
@@ -355,8 +371,10 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// Compute the logits and start the sum.
float sum = 0.f;
for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
- bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0);
- float logit = is_masked ? 0.f : __expf(qk_smem[ti] - qk_max);
+ // This is a deviation from the FasterTransformer kernel implementation,
+ // but it aligns with ORT's other Attention kernels, which strive to
+ // mimic PyTorch when dealing with mask filter values
+ float logit = __expf(qk_smem[ti] - qk_max);
sum += logit;
qk_smem[ti] = logit;
}
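The masking change in the hunks above (add `mask_filter_value` to the score before the max and exp steps, instead of forcing the logit to zero afterwards) can be illustrated with a small host-side sketch. This is not part of the patch: the function name `MaskedSoftmaxReference` is hypothetical, and the mask convention (0 means masked) is taken from the removed `params.mask[...] == 0` check.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference, for illustration only.
// scores: scaled Q*K^T values, one per timestep.
// mask:   1 = attend, 0 = masked.
std::vector<float> MaskedSoftmaxReference(std::vector<float> scores,
                                          const std::vector<int>& mask,
                                          float mask_filter_value) {
  assert(scores.size() == mask.size());
  if (scores.empty()) return scores;

  // Additive masking: masked positions are pushed far down before the max is taken.
  for (std::size_t i = 0; i < scores.size(); ++i) {
    if (mask[i] == 0) scores[i] += mask_filter_value;
  }

  // Numerically stable softmax over all positions, masked ones included.
  const float max_score = *std::max_element(scores.begin(), scores.end());
  float sum = 0.f;
  for (float& s : scores) {
    s = std::exp(s - max_score);
    sum += s;
  }
  for (float& s : scores) s /= sum;
  return scores;
}
```

With a large negative filter value, a masked position still takes part in the sum but contributes a weight that is effectively zero.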
diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
index 49aff75eb8..e2ced4786d 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc
@@ -35,7 +35,8 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper;
BeamSearch::BeamSearch(const OpKernelInfo& info)
: onnxruntime::contrib::transformers::BeamSearch(info) {
- SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds,
+ SetDeviceHelpers(GenerationCudaDeviceHelper::ReorderPastState,
+ GenerationCudaDeviceHelper::AddToFeeds,
GenerationCudaDeviceHelper::TopK,
GenerationCudaDeviceHelper::DeviceCopy,
GenerationCudaDeviceHelper::DeviceCopy,
@@ -54,6 +55,11 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
GenerationCudaDeviceHelper::ExpandBuffer);
SetConsoleDumper(&g_cuda_dumper);
+
+ cuda_device_prop_ = &reinterpret_cast(info.GetExecutionProvider())->GetDeviceProp();
+
+ cuda_device_arch_ = static_cast(cuda_device_prop_)->major * 100 +
+ static_cast(cuda_device_prop_)->minor * 10;
}
Status BeamSearch::ComputeInternal(OpKernelContext* context) const {
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
index 90c9122820..b354b1c9a6 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu
@@ -424,7 +424,7 @@ void LaunchSortPairs(void* d_temp_storage,
d_offsets + 1,
0,
sizeof(T) * 8,
- stream));
+ stream));
} else {
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
@@ -469,7 +469,7 @@ template void LaunchSortPairs(void* d_temp_storage,
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
struct BlockPrefixCallbackOp {
- float running_total; // running prefix
+ float running_total; // running prefix
__device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
@@ -746,6 +746,65 @@ void TorchMultinomialKernelLauncher(float* d_input,
d_presence_mask);
}
+__global__ void UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel(int32_t* tgt_indir_cache,
+ const int32_t* src_indir_cache,
+ const int32_t* beam_ids,
+ int batch_size,
+ int beam_width,
+ int input_seq_length,
+ int max_seq_length,
+ int current_length) {
+ int time_step = threadIdx.x + blockIdx.x * blockDim.x;
+ int bb_id = threadIdx.y + blockIdx.y * blockDim.y;
+ const int batch_id = bb_id / beam_width;
+ const int beam_id = bb_id % beam_width;
+
+ if (bb_id >= beam_width * batch_size || time_step >= current_length) {
+ return;
+ }
+
+ const int src_beam = beam_ids[batch_id * beam_width + beam_id] % beam_width;
+
+ const int tgt_offset = batch_id * beam_width * max_seq_length + beam_id * max_seq_length + time_step;
+
+ if (time_step < input_seq_length) {
+ // For time steps that correspond to the input sequence,
+ // the beam that it comes from is always 0.
+ tgt_indir_cache[tgt_offset] = static_cast(0);
+ } else if (time_step == (current_length - 1)) {
+ // For the final (newly generated) time step,
+ // the beam that it comes from is always the beam that we
+ // are currently processing, i.e., from this point on, these time-steps
+ // form the new beams.
+ tgt_indir_cache[tgt_offset] = static_cast(beam_id);
+ } else {
+ // For all other time-steps, we look up the source indirection, to
+ // see which beam it came from based on the `src_beam`.
+ const int src_offset = batch_id * beam_width * max_seq_length + src_beam * max_seq_length + time_step;
+ tgt_indir_cache[tgt_offset] = src_indir_cache[src_offset];
+ }
+}
+
+void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache,
+ const int32_t* src_indir_cache,
+ const int32_t* beam_ids,
+ int batch_size,
+ int beam_width,
+ int input_seq_length,
+ int max_seq_length,
+ int current_length,
+ cudaStream_t stream) {
+ const dim3 block(32);
+ const dim3 grid((current_length + block.x - 1) / block.x, batch_size * beam_width);
+ UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel<<<grid, block, 0, stream>>>(tgt_indir_cache,
+ src_indir_cache,
+ beam_ids,
+ batch_size,
+ beam_width,
+ input_seq_length,
+ max_seq_length,
+ current_length);
+}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
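The per-element logic of `UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel` can be followed more easily as a plain loop. The sketch below is a hypothetical CPU reference written for illustration, not code from the patch; the name `UpdateCacheIndirectionReference` and the use of `std::vector` are assumptions.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference mirroring the kernel's three cases.
// Buffers are laid out as (batch_size, beam_width, max_seq_length).
void UpdateCacheIndirectionReference(std::vector<int32_t>& tgt,
                                     const std::vector<int32_t>& src,
                                     const std::vector<int32_t>& beam_ids,
                                     int batch_size, int beam_width,
                                     int input_seq_length, int max_seq_length,
                                     int current_length) {
  assert(tgt.size() == src.size());
  assert(tgt.size() == static_cast<std::size_t>(batch_size) * beam_width * max_seq_length);

  for (int batch_id = 0; batch_id < batch_size; ++batch_id) {
    for (int beam_id = 0; beam_id < beam_width; ++beam_id) {
      // Beam (within this batch) that the current beam was derived from.
      const int src_beam = beam_ids[batch_id * beam_width + beam_id] % beam_width;
      for (int t = 0; t < current_length; ++t) {
        const int tgt_offset = (batch_id * beam_width + beam_id) * max_seq_length + t;
        if (t < input_seq_length) {
          tgt[tgt_offset] = 0;        // prompt tokens always come from beam 0
        } else if (t == current_length - 1) {
          tgt[tgt_offset] = beam_id;  // newest token belongs to the current beam
        } else {
          // Older generated tokens: inherit the history of the source beam.
          const int src_offset = (batch_id * beam_width + src_beam) * max_seq_length + t;
          tgt[tgt_offset] = src[src_offset];
        }
      }
    }
  }
}
```

Entries at time steps at or beyond `current_length` are left untouched, matching the early return in the kernel.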
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
index 5209622823..85b97df7a0 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
@@ -61,7 +61,7 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data,
cudaStream_t stream);
template
-void GetTempStorageSize(const T *d_keys_in,
+void GetTempStorageSize(const T* d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
@@ -77,15 +77,15 @@ void LaunchSetupParamsKernel(int* d_values_in,
cudaStream_t stream);
template
-void LaunchSortPairs(void *d_temp_storage,
+void LaunchSortPairs(void* d_temp_storage,
size_t temp_storage_bytes,
- const T *d_keys_in,
- T *d_keys_out,
- const int *d_values_in,
- int *d_values_out,
+ const T* d_keys_in,
+ T* d_keys_out,
+ const int* d_values_in,
+ int* d_values_out,
int num_items,
int num_segments,
- int *d_offsets,
+ int* d_offsets,
cudaStream_t stream,
bool is_descending);
@@ -109,6 +109,16 @@ void TorchMultinomialKernelLauncher(float* d_input,
int* d_presence_mask,
cudaStream_t stream);
+void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache,
+ const int32_t* src_indir_cache,
+ const int32_t* beam_ids,
+ int batch_size,
+ int beam_width,
+ int input_seq_length,
+ int max_seq_length,
+ int current_length,
+ cudaStream_t stream);
+
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
index e4d75971c9..eed24b4a1f 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc
@@ -481,6 +481,8 @@ Status ProcessLogits(const OrtValue& logits, //
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
#endif
+
+ // TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
beam_state->next_scores.data(),
beam_state->next_scores.size_bytes(),
@@ -528,6 +530,7 @@ Status ProcessLogits(const OrtValue& logits, //
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
#endif
+ // TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
data,
topk_scores->SizeInBytes(),
@@ -535,6 +538,7 @@ Status ProcessLogits(const OrtValue& logits, //
cuda_stream));
}
+ // TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
beam_state->next_tokens.data(),
beam_state->next_tokens.size_bytes(),
@@ -551,13 +555,31 @@ Status ProcessLogits(const OrtValue& logits, //
gsl::span next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size());
gsl::span next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size());
- // Limitation: beam scorer runs in CPU. It might be better to use CUDA kernel to replace it.
+ // TODO: Implement BeamScorer on CUDA
beam_scorer->Process(
sequences,
next_scores,
next_tokens,
next_indices);
+ // TODO: This is a temporary work-around as BeamScorer currently only runs on CPU.
+ // We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually.
+ auto chosen_indices = beam_scorer->GetNextIndices();
+ auto beam_state_chosen_indices = beam_state->chosen_indices;
+
+ if (!beam_state_chosen_indices.empty()) {
+ // If we have allocated `chosen_indices` in beam_state, it means that we
+ // will be needing the chosen indices from BeamScorer as we are using
+ // DecoderMaskedMultiheadAttention, so copy it over.
+ CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state_chosen_indices.data(),
+ chosen_indices.data(),
+ chosen_indices.size_bytes(),
+ cudaMemcpyHostToDevice,
+ cuda_stream));
+
+ CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
+ }
+
#ifdef ENABLE_NVTX_PROFILE
processLogitsRange.End();
#endif
@@ -869,12 +891,15 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len) {
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
#ifdef ENABLE_NVTX_PROFILE
profile::NvtxNestedRangeCreator updateFeedsRange("UpdateGptFeeds", profile::Color::Yellow);
updateFeedsRange.Begin();
@@ -889,6 +914,8 @@ Status UpdateGptFeeds(
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable()->MutableData();
cudaStream_t cuda_stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr;
+
+ // TODO(hasesh): When BeamScorer is implemented on CUDA, figure out a way to avoid this cross-device copy
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));
next_inputs[0] = input_ids;
@@ -914,8 +941,41 @@ Status UpdateGptFeeds(
next_inputs[2] = attention_mask;
if (past_present_share_buffer) {
- const int k = (static_cast(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
- *(next_inputs[k].GetMutable()->MutableData()) = past_sequence_len;
+ // Update past sequence length input
+ const int past_sequence_length_idx = (static_cast(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
+ *(next_inputs[past_sequence_length_idx].GetMutable()->MutableData()) = past_sequence_len;
+
+ // Update beam search specific input for DecoderMaskedMultiheadAttention (cache indirection) if present
+
+ // If the last input is not `past_sequence_length`, then the beam search specific inputs
+ // for `DecoderMaskedMultiheadAttention` are present
+ if (has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
+ ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiheadAttention with BeamSearch");
+
+ // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed
+ const OrtValue& old_cache_indirection = next_inputs[past_sequence_length_idx + 2];
+
+ // New cache indirection updated for next decoding run
+ OrtValue cache_indirection;
+
+ Tensor::InitOrtValue(DataTypeImpl::GetType(), old_cache_indirection.Get().Shape(), allocator, cache_indirection);
+
+ // Dimension 3 (0-based) of the past/present tensor is the max_sequence_length
+ int max_sequence_length = static_cast(last_outputs[gpt_subgraph_first_present_output_idx].Get().Shape()[3]);
+
+ // Launch kernel to update the cache indirection buffer
+ cuda::UpdateDecoderMaskedMultiheadAttentionCacheIndirection(cache_indirection.GetMutable()->MutableData(),
+ old_cache_indirection.Get().Data(),
+ reinterpret_cast(beam_indices_gpu.data()),
+ batch_beam_size / num_beams,
+ num_beams,
+ input_sequence_len,
+ max_sequence_length,
+ current_length,
+ cuda_stream);
+ // Update cache indirection for next decoding run
+ next_inputs[past_sequence_length_idx + 2] = cache_indirection;
+ }
} else {
if (num_beams == 1) {
const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
@@ -924,7 +984,7 @@ Status UpdateGptFeeds(
next_inputs[i + k] = last_outputs[i];
}
} else {
- ORT_RETURN_IF_ERROR(PickGptPastState(last_outputs, next_inputs, beam_indices, allocator,
+ ORT_RETURN_IF_ERROR(PickGptPastState(last_outputs, next_inputs, beam_indices_cpu, allocator,
gpt_subgraph_first_past_input_idx,
gpt_subgraph_first_present_output_idx, ort_stream));
}
@@ -1114,12 +1174,15 @@ template Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len);
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// Float16
template void InitBeamState(
@@ -1171,12 +1234,15 @@ template Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len);
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
template Status UpdateDecoderFeeds(
AllocatorPtr allocator,
diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h
index 13ea4931a5..42a7a32959 100644
--- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h
@@ -97,12 +97,15 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span beam_next_tokens,
- gsl::span beam_indices,
+ gsl::span beam_indices_cpu,
+ gsl::span beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
- int past_sequence_len);
+ int past_sequence_len,
+ int input_sequence_len,
+ bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5
diff --git a/onnxruntime/core/graph/contrib_ops/bert_defs.cc b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
index 0b1d319272..b48aab3813 100644
--- a/onnxruntime/core/graph/contrib_ops/bert_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/bert_defs.cc
@@ -367,9 +367,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"past",
"past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)"
"When past_present_share_buffer is set, "
- "its shape is (2, batch_size, num_heads, max_sequence_length, head_size)",
- "T",
- OpSchema::Optional)
+ "its shape is (2, batch_size, num_heads, max_sequence_length, head_size). "
+ "The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys "
+ "and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. "
+ "The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape "
+ "(batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape "
+ "(batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to "
+ "become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.",
+ "T")
.Input(5,
"relative_position_bias",
"additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
@@ -378,6 +383,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(6,
"past_sequence_length",
"When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).",
+ "M")
+ .Input(7,
+ "beam_width",
+ "The beam width that is being used while decoding. "
+ "If not provided, the beam width will be assumed to be 1.",
+ "M",
+ OpSchema::Optional)
+ .Input(8,
+ "cache_indirection",
+ "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies "
+ "which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration",
"M",
OpSchema::Optional)
.Output(0,
@@ -390,8 +406,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"If past_present_share_buffer is set, "
"its shape is (2, batch_size, num_heads, max_sequence_length, head_size), "
"while effective_seq_length = (past_sequence_length + kv_sequence_length).",
- "T",
- OpSchema::Optional)
+ "T")
.TypeConstraint("T",
{"tensor(float)", "tensor(float16)"},
"Constrain input and output types to float tensors.")
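The re-ordered key layout described for the `past` input can be illustrated with a small index calculation. This is a sketch only, not code from the kernel: `reordered_key_offset` and its parameters are hypothetical names, and it assumes the re-ordering is applied independently to each (batch, head) slice, as the documented shapes suggest.

```python
def reordered_key_offset(s, h, max_sequence_length, head_size, element_size):
    """Offset of key element (s, h) inside one (batch, head) slice of the
    re-ordered keys buffer.

    Logical layout per slice:    (max_sequence_length, head_size)
    Re-ordered layout per slice: (head_size / x, max_sequence_length, x)
    where x = 16 / sizeof(T), i.e. the number of elements in 16 bytes.
    """
    x = 16 // element_size
    if head_size % x != 0:
        raise ValueError("head_size must be a multiple of 16 / sizeof(T)")
    chunk = h // x   # which group of x contiguous head elements
    lane = h % x     # position inside that group
    return chunk * max_sequence_length * x + s * x + lane


# float16 (2 bytes) gives x = 8; float32 (4 bytes) gives x = 4.
offset_fp16 = reordered_key_offset(2, 9, max_sequence_length=4, head_size=16, element_size=2)
```

Under this layout, the `x` elements of one chunk for consecutive sequence positions sit next to each other in memory, which is the property the (head_size / x, max_sequence_length, x) shape expresses.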
diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py
index 37855c8c02..a7ea3b4e05 100644
--- a/onnxruntime/python/tools/transformers/convert_generation.py
+++ b/onnxruntime/python/tools/transformers/convert_generation.py
@@ -1026,7 +1026,7 @@ def shape_of(vi):
return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
-def update_decoder_subgraph_past_present_share_buffer(subg):
+def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
input_past_0 = 3
output_past_0 = 1
new_inputs = []
@@ -1074,37 +1074,82 @@ def update_decoder_subgraph_past_present_share_buffer(subg):
return subg
-def update_decoder_subgraph_use_decoder_masked_multihead_attention(subg) -> bool:
+def update_decoder_subgraph_use_decoder_masked_multihead_attention(
+ subg: GraphProto, is_beam_search: bool, switch_attention: bool
+) -> bool:
"""Update the Attention nodes to DecoderMaskedMultiheadAttention.
Args:
subg (GraphProto): GraphProto of the decoder subgraph
+ is_beam_search (bool): True if the generation algorithm is BeamSearch
+ switch_attention (bool): True if `Attention` nodes are to be replaced with `DecoderMaskedMultiheadAttention`
"""
- decoder_masked_attention_supported_attr = [
- "past_present_share_buffer",
- "num_heads",
- "scale",
- "domain",
- ]
- new_nodes = []
- for node in subg.node:
- if node.op_type == "Attention":
- kwargs = kwargs_of(node)
- for k in kwargs.copy():
- # The Attention operator does not support different qkv hidden sizes when past/present
- # input/output exists (GPT2 model). Hence, we should never run into this.
- # But, if we do, do not go ahead with the optimization.
- if k == "qkv_hidden_sizes":
- return False
+ if is_beam_search:
+ new_inputs = []
+ for vi in subg.input:
+ new_inputs.extend([vi])
+
+ # Add 2 BeamSearch specific inputs
+ new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
+ new_inputs.extend(
+ [
+ onnx.helper.make_tensor_value_info(
+ "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
+ )
+ ]
+ )
+ subg.ClearField("input")
+ subg.input.extend(new_inputs)
+
+ if switch_attention:
+ decoder_masked_attention_supported_attr = [
+ "past_present_share_buffer",
+ "num_heads",
+ "scale",
+ "domain",
+ ]
+
+ new_nodes = []
+ for node in subg.node:
+ if node.op_type == "Attention":
+ kwargs = kwargs_of(node)
+ for k in kwargs.copy():
+ # The Attention operator does not support different qkv hidden sizes when past/present
+ # input/output exists (GPT2 model). Hence, we should never run into this.
+ # But, if we do, do not go ahead with the optimization.
+ if k == "qkv_hidden_sizes":
+ return False
+
+ if k not in decoder_masked_attention_supported_attr:
+ # Log the fact that we are removing certain attributes from the node
+ # We don't need to log it for "unidirectional" as we are aware that
+ # decoding attention kernels are unidirectional by definition.
+ if k != "unidirectional":
+ logger.warning(
+ f"Removing attribute: {k} from Attention node while switching to DecoderMaskedMultiheadAttention"
+ )
+
+ del kwargs[k]
+
+ nis = []
+ nis.extend(node.input)
+
+ # Add 2 BeamSearch specific inputs
+ if is_beam_search:
+ while len(nis) < 7:
+ nis.extend([""])
+ if len(nis) < 8:
+ nis.extend(["beam_width"])
+ if len(nis) < 9:
+ nis.extend(["cache_indirection"])
+
+ node = onnx.helper.make_node(
+ "DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs
+ )
+ new_nodes.extend([node])
+ subg.ClearField("node")
+ subg.node.extend(new_nodes)
- if k not in decoder_masked_attention_supported_attr:
- del kwargs[k]
- nis = []
- nis.extend(node.input)
- node = onnx.helper.make_node("DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs)
- new_nodes.extend([node])
- subg.ClearField("node")
- subg.node.extend(new_nodes)
return True
@@ -1452,9 +1497,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
is_sampling: bool = generation_type == GenerationType.SAMPLING
- past_present_share_buffer: bool = args.past_present_share_buffer and (is_greedysearch or is_sampling)
+ past_present_share_buffer: bool = args.past_present_share_buffer
- logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}")
+ logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")
if is_greedysearch or is_sampling:
if not is_gpt2:
@@ -1464,6 +1509,29 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
if args.output_token_scores:
raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
+ # For BeamSearch, sharing buffers for past and present states is only supported
+ # when using `use_decoder_masked_multihead_attention`
+ if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_multihead_attention:
+ raise ValueError(
+ "`use_decoder_masked_multihead_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
+ )
+
+ # For any generation type, using decoder masked multihead attention is only supported
+ # when using `past_present_share_buffer`
+ if args.use_decoder_masked_multihead_attention and not past_present_share_buffer:
+ raise ValueError(
+ "`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
+ )
+
+ # For any generation type, using decoder masked multihead attention is only supported
+ # on GPUs
+ if args.use_decoder_masked_multihead_attention and not args.use_gpu:
+ raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")
+
+ # Using decoder masked multihead attention is only supported for GPT2
+ if args.use_decoder_masked_multihead_attention and args.model_type in ["t5", "mt5"]:
+ raise ValueError("`use_decoder_masked_multihead_attention` option is only supported for GPT2")
+
if is_gpt2:
if args.decoder_onnx and os.path.exists(args.decoder_onnx):
logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
@@ -1699,6 +1767,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
node.attribute.extend(attr_to_extend)
initializers = []
+
if args.model_type in ["t5", "mt5"]:
if args.run_shape_inference:
logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
@@ -1745,37 +1814,43 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
logger.info(
f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph"
)
- # Update for init decoder subgraph
+
+ # Update init decoder subgraph in preparation to use past present share buffer
if past_present_share_buffer:
logger.info("*****update init decoder subgraph to make past and present share buffer******************")
update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph)
+
+ # Update init decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
+ # NOTE: Even though DecoderMaskedMultiheadAttention is not used in the init decoder subgraph,
+ # keeping the inputs of the init decoder and decoder subgraphs identical
+ # makes the runtime changes cleaner.
+ if (
+ args.use_decoder_masked_multihead_attention
+ and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
+ gpt2_init_decoder_model.graph, is_beamsearch, False
+ )
+ ):
+ raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedMultiheadAttention")
+
node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph))
else:
# Move initializer from subgraph to main graph could reduce memory usage in inference.
initializers = move_initializers(decoder_model.graph)
logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph")
- # Update for non-init decoder subgraph
+ # Update decoder subgraph in preparation to use past present share buffer
if past_present_share_buffer:
logger.info("*****update decoder subgraph to make past and present share buffer******************")
update_decoder_subgraph_past_present_share_buffer(decoder_model.graph)
- # If at all, the user requests the use of `use_decoder_masked_multihead_attention`,
- # we can only use it in the decoder subgraph - it cannot be used in the init_decoder subgraph
- # as the kernel only supports the decoding use-case (i.e.) when input sequence length is 1.
- if args.use_decoder_masked_multihead_attention:
- logger.info("Update decoder subgraph to use DecoderMaskedMultiheadAttention")
-
- if not past_present_share_buffer:
- raise ValueError(
- "`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
- )
-
- if not args.use_gpu:
- raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")
-
- if not update_decoder_subgraph_use_decoder_masked_multihead_attention(decoder_model.graph):
- raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
+ # Update decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
+ if (
+ args.use_decoder_masked_multihead_attention
+ and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
+ decoder_model.graph, is_beamsearch, True
+ )
+ ):
+ raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))
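The input padding that `update_decoder_subgraph_use_decoder_masked_multihead_attention` performs for BeamSearch can be sketched in isolation. The helper name `pad_attention_inputs` is hypothetical and only mirrors the logic in the hunk above: missing optional inputs are filled with empty names up to index 6, so that `beam_width` and `cache_indirection` land at input indices 7 and 8 of the schema.

```python
def pad_attention_inputs(node_inputs, is_beam_search):
    """Return the input name list for a DecoderMaskedMultiheadAttention node.

    Empty strings mark omitted optional inputs in ONNX, so padding with ""
    keeps the positions of the existing inputs unchanged.
    """
    inputs = list(node_inputs)
    if is_beam_search:
        # Fill indices up to 6 so the next two appends land on 7 and 8.
        while len(inputs) < 7:
            inputs.append("")
        if len(inputs) < 8:
            inputs.append("beam_width")
        if len(inputs) < 9:
            inputs.append("cache_indirection")
    return inputs


# An Attention node with 5 inputs (illustrative names) gains two "" placeholders
# and the two BeamSearch-specific inputs.
padded = pad_attention_inputs(["input", "weight", "bias", "mask_index", "past"], True)
```

Without BeamSearch the input list is returned unchanged, which matches the behaviour of the greedy search and sampling paths.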
diff --git a/onnxruntime/test/python/transformers/test_generation.py b/onnxruntime/test/python/transformers/test_generation.py
index 7bc8c1db45..33122f08b7 100644
--- a/onnxruntime/test/python/transformers/test_generation.py
+++ b/onnxruntime/test/python/transformers/test_generation.py
@@ -156,6 +156,18 @@ class TestBeamSearchGpt(unittest.TestCase):
if self.enable_cuda:
self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True)
+ @pytest.mark.slow
+ def test_beam_search_use_decoder_masked_multihead_attention(self):
+ if self.enable_cuda:
+ self.run_beam_search("--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu")
+
+ @pytest.mark.slow
+ def test_beam_search_use_decoder_masked_multihead_attention_fp16(self):
+ if self.enable_cuda:
+ self.run_beam_search(
+ "--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu -p fp16"
+ )
+
@pytest.mark.slow
def test_external_data(self):
self.run_beam_search(