[CUDA] Add option to use DecoderMaskedMultiheadAttention in BeamSearch (#14990)

This commit is contained in:
Hariharan Seshadri 2023-03-15 17:16:32 -07:00 committed by GitHub
parent da084b0fc1
commit ed7ab1660d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 645 additions and 174 deletions

View file

@ -1133,7 +1133,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
</dl>
#### Inputs (3 - 7)
#### Inputs (7 - 9)
<dl>
<dt><tt>input</tt> : T</dt>
@ -1144,20 +1144,24 @@ This version of the operator has been available since version 1 of the 'com.micr
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection</dd>
<dt><tt>mask_index</tt> (optional) : M</dt>
<dd>Mask values of shape (batch_size, total_sequence_length)</dd>
<dt><tt>past</tt> (optional) : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>past</tt> : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_sequence_length</tt> (optional) : M</dt>
<dt><tt>past_sequence_length</tt> : M</dt>
<dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
<dt><tt>beam_width</tt> (optional) : M</dt>
<dd>The beam width that is being used while decoding.If not provided, the beam width will be assumed to be 1.</dd>
<dt><tt>cache_indirection</tt> (optional) : M</dt>
<dd>A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifieswhich beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration</dd>
</dl>
#### Outputs (1 - 2)
#### Outputs
<dl>
<dt><tt>output</tt> : T</dt>
<dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
<dt><tt>present</tt> (optional) : T</dt>
<dt><tt>present</tt> : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
</dl>

View file

@ -798,7 +798,7 @@ Do not modify directly.*
|ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
|ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedMultiheadAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DecoderMaskedMultiheadAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
|DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
|DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
|EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|

View file

@ -163,6 +163,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
if (has_init_decoder_) {
ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
"past_present_share_buffer mode must be same for init decoder and decoder subgraphes");
}
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
@ -181,12 +183,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
thread_pool, ctx->GetComputeStream(), dumper_, parameters,
GenerationCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>,
cuda_device_prop_,
cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@ -200,12 +205,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
thread_pool, ctx->GetComputeStream(), dumper_, parameters,
GenerationCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
update_gpt_feeds_fp16_func_};
update_gpt_feeds_fp16_func_,
cuda_device_prop_,
cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);

View file

@ -44,6 +44,7 @@ class BeamSearch : public IControlFlowKernel {
// device helpers that is same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
@ -52,6 +53,7 @@ class BeamSearch : public IControlFlowKernel {
const GenerationDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
const GenerationDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
const GenerationDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
reorder_past_state_func_ = reorder_past_state_func;
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
device_copy_func_ = device_copy_func;
@ -83,8 +85,12 @@ class BeamSearch : public IControlFlowKernel {
expand_buffer_float16_func_ = expand_buffer_float16_func;
}
const void* cuda_device_prop_ = nullptr;
int cuda_device_arch_ = 0;
private:
// Device specific functions
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::TopkFunc topk_func_;
GenerationDeviceHelper::DeviceCopyFunc<float> device_copy_func_;

View file

@ -20,6 +20,9 @@ struct BeamSearchState : public IBeamSearchState<T> {
int vocab_size,
int sequence_length,
int max_length,
int num_heads,
int head_size,
int has_decoder_masked_multihead_attention,
bool output_scores,
bool use_position) {
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
@ -49,6 +52,21 @@ struct BeamSearchState : public IBeamSearchState<T> {
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
this->remaining_scores = this->scores;
}
if (has_decoder_masked_multihead_attention) {
// We need a temp staging buffer to do the past 'K' state re-ordering that is needed
// when using DecoderMaskedMultiheadAttention
TensorShape staging_for_past_state_reorder_buffer_shape = {static_cast<int64_t>(batch_beam_size), num_heads, max_length, head_size};
Tensor temp(DataTypeImpl::GetType<T>(), staging_for_past_state_reorder_buffer_shape, allocator);
this->staging_for_past_state_reorder = std::move(temp);
// We need a buffer on GPU to hold the final chosen indices after BeamScorer has finished processing
// TODO: This is a temporary work-around as BeamScorer currently only runs on CPU.
// We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually.
this->chosen_indices = AllocateBuffer<int32_t>(allocator, chosen_indices_buffer_, batch_beam_size);
}
}
private:
@ -61,6 +79,7 @@ struct BeamSearchState : public IBeamSearchState<T> {
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr scores_buffer_;
BufferUniquePtr topk_temp_buffer_;
BufferUniquePtr chosen_indices_buffer_;
};
struct BeamSearchCpuState : public IBeamSearchCpuState {
@ -124,7 +143,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {
// Base class of beam search implementation that is common for both GPT-2 and T5.
template <typename T>
class BeamSearchBase : public GenerateBase {
class BeamSearchBase : public GenerateBase {
public:
BeamSearchBase(OpKernelContextInternal& context,
const SessionState& decoder_session_state,
@ -136,13 +155,13 @@ class BeamSearchBase : public GenerateBase {
const GenerationDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const GenerationDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func)
: GenerateBase(context,
decoder_session_state,
thread_pool,
ort_stream,
cuda_dumper,
topk_func,
device_copy_func),
: GenerateBase(context,
decoder_session_state,
thread_pool,
ort_stream,
cuda_dumper,
topk_func,
device_copy_func),
parameters_(&params),
process_logits_func_(process_logits_func),
device_copy_int32_func_(device_copy_int32_func) {
@ -188,11 +207,11 @@ Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
ORT_RETURN_IF_ERROR(this->CheckInputsImpl(parameters_,
context.Input<Tensor>(0), // input_ids
context.Input<Tensor>(7), // vocab_mask
context.Input<Tensor>(8), // prefix_vocab_mask
context.Input<Tensor>(9), // attention_mask
nullptr)); // presence_mask
context.Input<Tensor>(0), // input_ids
context.Input<Tensor>(7), // vocab_mask
context.Input<Tensor>(8), // prefix_vocab_mask
context.Input<Tensor>(9), // attention_mask
nullptr)); // presence_mask
return Status::OK();
}

View file

@ -27,12 +27,15 @@ class BeamSearchGpt : public BeamSearchBase<T> {
BeamSearchParameters& params,
const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const GenerationDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const GenerationDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
const GenerationDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func)
const GenerationDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func,
const void* cuda_device_prop,
int cuda_device_arch)
: BeamSearchBase<T>(context, decoder_session_state, thread_pool,
ort_stream, cuda_dumper, params,
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
@ -42,7 +45,17 @@ class BeamSearchGpt : public BeamSearchBase<T> {
create_inputs_func_(create_inputs_func),
add_to_feeds_func_(add_to_feeds_func),
init_beam_state_func_(init_beam_state_func),
update_feeds_func_(update_feeds_func) {
reorder_past_state_func_(reorder_past_state_func),
update_feeds_func_(update_feeds_func),
cuda_device_prop_(cuda_device_prop),
cuda_device_arch_(cuda_device_arch) {
if (gpt_subgraph_.has_decoder_masked_multihead_attention_) {
ORT_ENFORCE(cuda_device_arch_ >= 530,
"Decoder masked multihead attention can only be used on "
"GPU cards of compute capability 5.3 or higher. "
"This card has compute capability ",
cuda_device_arch_);
}
}
// Execute beam search in iterations util stopping criteria is reached.
@ -55,7 +68,8 @@ class BeamSearchGpt : public BeamSearchBase<T> {
Status CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer);
IAllocatorUniquePtr<char>& buffer,
bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// Update the input for next iteration.
Status UpdateFeeds(
@ -65,7 +79,11 @@ class BeamSearchGpt : public BeamSearchBase<T> {
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices);
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int past_sequence_length,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
const SessionState* init_run_decoder_session_state_ = nullptr;
GptSubgraph* init_run_gpt_subgraph_ = nullptr;
@ -75,14 +93,19 @@ class BeamSearchGpt : public BeamSearchBase<T> {
GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
const void* cuda_device_prop_ = nullptr;
int cuda_device_arch_ = 0;
};
template <typename T>
Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer) {
IAllocatorUniquePtr<char>& buffer,
bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0);
const Tensor& input_ids = input_ids_value->Get<Tensor>();
const OrtValue* attn_mask_value = this->context_.GetInputOrtValue(9);
@ -99,7 +122,9 @@ Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths
this->create_inputs_func_,
this->add_to_feeds_func_,
buffer,
this->ort_stream_);
this->ort_stream_,
this->parameters_->max_length,
add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}
return gpt_subgraph_.CreateInitialFeeds(input_ids,
@ -113,7 +138,9 @@ Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths
this->create_inputs_func_,
this->add_to_feeds_func_,
buffer,
this->ort_stream_);
this->ort_stream_,
this->parameters_->max_length,
add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}
template <typename T>
@ -124,7 +151,11 @@ Status BeamSearchGpt<T>::UpdateFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices) {
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int past_sequence_length,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
return update_feeds_func_(this->temp_space_allocator_,
this->ort_stream_,
last_outputs,
@ -133,12 +164,15 @@ Status BeamSearchGpt<T>::UpdateFeeds(
position_ids,
increase_position,
beam_next_tokens,
beam_indices,
beam_indices_cpu,
beam_indices_gpu,
this->parameters_->num_beams,
gpt_subgraph_.GetFirstPastInputIndex(),
gpt_subgraph_.GetFirstPresentOutputIndex(),
false,
-1);
gpt_subgraph_.past_present_share_buffer_,
past_sequence_length,
input_sequence_len,
has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}
template <typename T>
@ -192,7 +226,22 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
// buffer in GPU for input_ids, position_ids and attention_mask
IAllocatorUniquePtr<char> buffer;
OrtValue expanded_input_ids_in_cpu;
ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer,
gpt_subgraph_.has_decoder_masked_multihead_attention_));
if (gpt_subgraph_.past_present_share_buffer_) { // Reuse past and present
fetches.reserve(static_cast<int64_t>(gpt_subgraph_.GetFirstPresentOutputIndex()) + gpt_subgraph_.num_layers);
fetches.resize(gpt_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
for (int layer = 0; layer < gpt_subgraph_.num_layers; layer++) {
int feed_idx = gpt_subgraph_.GetFirstPastInputIndex() + layer;
OrtValue& past_tensor_value = feeds[feed_idx];
Tensor* past_tensor = past_tensor_value.GetMutable<Tensor>();
OrtValue present_tensor_value;
Tensor::InitOrtValue(past_tensor->DataType(), past_tensor->Shape(), past_tensor->MutableData<T>(),
past_tensor->Location(), present_tensor_value);
fetches.push_back(present_tensor_value);
}
}
BeamSearchState<T> beam_state;
constexpr bool use_position = true;
@ -202,6 +251,9 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
parameters->vocab_size,
parameters->sequence_length,
parameters->max_length,
parameters->num_heads,
parameters->head_size,
gpt_subgraph_.has_decoder_masked_multihead_attention_,
parameters->output_scores,
use_position);
@ -297,17 +349,49 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
// Increase sequence length after a new token is generated.
++current_length;
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedMultiheadAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) {
size_t offset = static_cast<size_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
// We will use the same staging buffer while transposing all the layers' past state
// and this is okay because we use the same stream to do the staging copy and the transpose
// operations.
// If we ever do them in different streams, we must use different staging buffers to avoid data
// races.
for (size_t i = 0; i < static_cast<size_t>(gpt_subgraph_.num_layers); ++i) {
ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
*fetches[offset + i].GetMutable<Tensor>(),
beam_state.staging_for_past_state_reorder,
this->ort_stream_));
}
}
// Prepare inputs for next round of subgraph call.
if (current_length < parameters->max_length) {
gsl::span<const int32_t> place_holder;
// For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly.
// For the remaining iterations, we need increase position_ids first, then add it to feeds.
bool increase_position = (iteration_counter > 1);
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
position_ids, increase_position,
ReinterpretAsSpan<const int32_t>(beam_next_tokens),
ReinterpretAsSpan<const int32_t>(beam_indices)));
ReinterpretAsSpan<const int32_t>(beam_indices),
gpt_subgraph_.has_decoder_masked_multihead_attention_
? ReinterpretAsSpan<const int32_t>(beam_state.chosen_indices)
: place_holder,
current_length - 1,
parameters->sequence_length,
gpt_subgraph_.has_decoder_masked_multihead_attention_));
}
if (gpt_subgraph_.past_present_share_buffer_) {
// clear fetched values before presents[]
for (int idx = 0; idx < gpt_subgraph_.GetFirstPresentOutputIndex(); idx++) {
fetches[idx] = OrtValue();
}
} else {
fetches.clear();
}
fetches.clear();
}
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());

View file

@ -193,6 +193,9 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
parameters->vocab_size,
parameters->sequence_length,
parameters->max_length,
parameters->num_heads,
parameters->head_size,
false, // TODO: Support past/present state buffer re-use for T5
parameters->output_scores,
use_position);

View file

@ -94,7 +94,7 @@ class BeamSearchScorer : public IBeamScorer {
gsl::span<float>& GetNextScores() { return next_beam_scores_; }
gsl::span<int32_t>& GetNextTokens() { return next_beam_tokens_; }
gsl::span<int32_t>& GetNextIndices() { return next_beam_indices_; }
gsl::span<int32_t>& GetNextIndices() override { return next_beam_indices_; }
private:
size_t batch_size_;

View file

@ -407,18 +407,18 @@ Status ProcessLogits(const OrtValue& logits, //
template <typename T>
Status GreedySearchProcessLogits(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // sampling_state
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // sampling_state
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
int batch_size = parameters->batch_size;
int vocab_size = parameters->vocab_size;
@ -499,8 +499,8 @@ Status GreedySearchProcessLogits(
topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", topk_scores);
dumper->Print("topk_indices", topk_indices);
dumper->Print("topk_scores", topk_scores);
dumper->Print("topk_indices", topk_indices);
#endif
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
@ -574,15 +574,21 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len) {
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
// last_outputs: logits, present_0, present_1, ...
// next_inputs: input_ids, position_id, attention_mask, past_0, past_1
ORT_UNUSED_PARAMETER(stream);
ORT_UNUSED_PARAMETER(beam_indices_gpu);
ORT_UNUSED_PARAMETER(input_sequence_len);
ORT_UNUSED_PARAMETER(has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// The following updates inputs for subgraph
@ -631,14 +637,14 @@ Status UpdateGptFeeds(
return Status::OK();
}
if (num_beams == 1) { // Update past state
if (num_beams == 1) { // Update past state
// feed present_* output to past_* inputs one by one
const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
for (size_t i = gpt_subgraph_first_present_output_idx; i < last_outputs.size(); ++i) {
next_inputs[i + k] = last_outputs[i];
}
} else {
PickGptPastState<T>(last_outputs, next_inputs, beam_indices,
PickGptPastState<T>(last_outputs, next_inputs, beam_indices_cpu,
gpt_subgraph_first_past_input_idx,
gpt_subgraph_first_present_output_idx, allocator);
}
@ -890,12 +896,15 @@ template Status UpdateGptFeeds<float>(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len);
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
template Status UpdateDecoderFeeds<float>(
AllocatorPtr allocator,

View file

@ -128,12 +128,15 @@ using UpdateGptFeedsFunc = std::function<Status(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len)>;
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention)>;
// Create encoder inputs (for encoder-decoder model like T5).
using CreateEncoderInputsFunc = std::function<Status(
@ -262,12 +265,15 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len);
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5

View file

@ -38,6 +38,11 @@ struct IBeamSearchState {
// temp token: (batch_size * num_beams, 2 * num_beams)
// in total, it will be:
// 2 * (batch_size * num_beams * (parts_vocab + 1), 2 * num_beams)
// The final chosen indices after BeamScorer has finished processing
gsl::span<int32_t> chosen_indices; // shape (batch_size, num_beams)
Tensor staging_for_past_state_reorder; // Tensor of shape (batch_size * num_beams, num_heads, max_length, head_size)
};
struct IBeamSearchCpuState {
@ -117,6 +122,8 @@ class IBeamScorer {
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;
virtual gsl::span<int32_t>& GetNextIndices() = 0;
};
struct IGenerationParameters {

View file

@ -81,7 +81,7 @@ struct GreedySearchState : public IGreedySearchState<T> {
int max_length,
int num_heads,
int head_size,
bool allocate_staging_buffer_for_past_state_reorder,
bool has_decoder_masked_multihead_attention,
bool is_cuda) {
// below buffers are on cpu
this->sequences_space = AllocateBuffer<int32_t>(cpu_allocator,
@ -111,8 +111,9 @@ struct GreedySearchState : public IGreedySearchState<T> {
this->topk_scores_buffer,
this->topk_tokens_buffer);
// If at all we need to, we only need to re-order past state for CUDA
if (allocate_staging_buffer_for_past_state_reorder) {
// If at all we need to, we only need to re-order past state for CUDA as
//`DecoderMaskedMultiheadAttention` is only supported on CUDA
if (has_decoder_masked_multihead_attention) {
TensorShape staging_for_past_state_reorder_buffer_shape = {batch_size, num_heads, max_length, head_size};
Tensor temp(DataTypeImpl::GetType<T>(), staging_for_past_state_reorder_buffer_shape, allocator);

View file

@ -170,11 +170,14 @@ Status GreedySearchGpt<T, ParametersT>::UpdateFeeds(
increase_position,
next_tokens,
place_holder,
place_holder,
this->parameters_->num_beams,
gpt_subgraph_.GetFirstPastInputIndex(),
gpt_subgraph_.GetFirstPresentOutputIndex(),
gpt_subgraph_.past_present_share_buffer_,
past_sequence_length);
past_sequence_length,
-1, // Input sequence length needn't be passed in for GreedySearch
false);
}
template <typename T, typename ParametersT>
@ -329,7 +332,7 @@ Status GreedySearchGpt<T, ParametersT>::Execute(const FeedsFetchesManager* init_
// Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
// contains DecoderMaskedMultiheadAttention nodes
if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) {
size_t offset = static_cast<int64_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
size_t offset = static_cast<size_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
// We will use the same staging buffer while transposing all the layers' past state
// and this is okay because we use the same stream to do the staging copy and the transpose
// operations.

View file

@ -91,6 +91,9 @@ Status Subgraph::Setup(const SessionState& session_state,
past_present_share_buffer_ = true;
// past_sequence_length is on CPU memory
feed_locations.push_back(OrtDevice());
} else if (feed_names[i] == "beam_width") {
// beam_width is on CPU memory
feed_locations.push_back(OrtDevice());
} else {
feed_locations.push_back(default_location.device);
}

View file

@ -27,7 +27,8 @@ Status GptSubgraph::CreateInitialFeeds(
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer,
Stream* ort_stream,
int max_seq_len_past_present_share_buffer) {
int past_present_share_buffer_max_seq_len,
bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
const IExecutionProvider* provider = GetProvider();
@ -83,20 +84,48 @@ Status GptSubgraph::CreateInitialFeeds(
feeds.push_back(empty_past);
}
} else {
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, max_seq_len_past_present_share_buffer, head_size};
// Past state feeds
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, past_present_share_buffer_max_seq_len, head_size};
TensorShape past_shape(&past_state_dims[0], 5);
// The remaining inputs are past state except the last one
for (int i = first_past_input_index_; i < num_subgraph_inputs - 1; ++i) {
// The remaining inputs are past state except the last one or three (see below for details)
// If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is false, then the last input is `past_sequence_length`
// If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is true, then the last inputs are `past_sequence_length`,
// `beam_width`, and `cache_indirection`
auto past_end_iter = add_beam_search_specific_inputs_for_decoder_masked_multihead_attention ? num_subgraph_inputs - 3 : num_subgraph_inputs - 1;
for (int i = first_past_input_index_; i < past_end_iter; ++i) {
OrtValue past_tensor;
Tensor::InitOrtValue(past_type, past_shape, default_allocator, past_tensor);
feeds.push_back(past_tensor);
}
// Past sequence length feed
int64_t past_seq_len_dims[] = {1};
TensorShape past_seq_len_shape(&past_seq_len_dims[0], 1);
OrtValue past_seq_len_tensor_value;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), past_seq_len_shape, cpu_allocator, past_seq_len_tensor_value);
feeds.push_back(past_seq_len_tensor_value);
*past_seq_len_tensor_value.GetMutable<Tensor>()->MutableData<int32_t>() = 0;
// Add beam search specific inputs
if (add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
// Beam width feed
int64_t num_beams_dims[] = {1};
TensorShape num_beams_shape(&num_beams_dims[0], 1);
OrtValue num_beams_tensor_value;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), num_beams_shape, cpu_allocator, num_beams_tensor_value);
feeds.push_back(num_beams_tensor_value);
*num_beams_tensor_value.GetMutable<Tensor>()->MutableData<int32_t>() = static_cast<int32_t>(num_beams);
// Cache indirection feed
int64_t cache_indirection_dims[] = {batch_size, num_beams, past_present_share_buffer_max_seq_len};
TensorShape cache_indirection_shape(&cache_indirection_dims[0], 3);
OrtValue default_cache_indirection;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), cache_indirection_shape,
default_allocator, default_cache_indirection);
feeds.push_back(default_cache_indirection);
}
}
// Pass in implicit inputs
@ -112,8 +141,12 @@ Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
ORT_RETURN_IF(num_subgraph_outputs <= first_present_output_index_,
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs).");
ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) || (num_subgraph_inputs == num_subgraph_outputs + 3)),
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or 3 (if past_present_share_buffer)");
ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) ||
(num_subgraph_inputs == num_subgraph_outputs + 3) ||
(num_subgraph_inputs == num_subgraph_outputs + 5)),
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or "
"3 (if past_present_share_buffer) or "
"5 (if past_present_share_buffer and use_decoder_masked_multihead_attention for BeamSearch)");
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());

View file

@ -16,9 +16,9 @@ class GptSubgraph : public Subgraph {
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {
first_past_input_index_ = 3;
first_present_output_index_ = 1;
}
first_past_input_index_ = 3;
first_present_output_index_ = 1;
}
// Create inputs for first inference of subgraph.
Status CreateInitialFeeds(
@ -34,7 +34,8 @@ class GptSubgraph : public Subgraph {
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer,
Stream* ort_stream,
int max_seq_len_past_present_share_buffer = -1);
int past_present_share_buffer_max_seq_len = -1,
bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention = false);
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;

View file

@ -16,20 +16,23 @@ namespace contrib {
namespace cuda {
static constexpr int kPastSequenceLengthInputIndex = 6;
static constexpr int kBeamWidthInputIndex = 7;
static constexpr int kCacheIndirectionInputIndex = 8;
static constexpr int kPastInputIndex = 4;
static constexpr int kPresentOutputIndex = 1;
#define REGISTER_KERNEL_TYPED(T1, T2) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
DecoderMaskedMultiheadAttention, \
kMSDomain, \
1, \
T1, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T1>()) \
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
#define REGISTER_KERNEL_TYPED(T1, T2) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
DecoderMaskedMultiheadAttention, \
kMSDomain, \
1, \
T1, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.MayInplace(kPastInputIndex, kPresentOutputIndex) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T1>()) \
.InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex) \
.InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex), \
DecoderMaskedMultiheadAttention<T1, T2>);
REGISTER_KERNEL_TYPED(float, float)
@ -44,6 +47,8 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
const Tensor* past = context->Input<Tensor>(kPastInputIndex);
const Tensor* relative_position_bias = context->Input<Tensor>(5);
const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
const Tensor* beam_width = context->Input<Tensor>(kBeamWidthInputIndex);
const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);
auto& device_prop = GetDeviceProp();
DecoderMaskedMultiheadAttentionParams parameters;
@ -105,7 +110,7 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
auto* present_data = present->MutableData<T1>();
auto* past_data = past->Data<T1>();
// No production use-case will incure this copy cost as the implementation of
// No production use-case will incur this copy cost as the implementation of
// GreedySearch/BeamSearch is written in such a way that the past and present buffers
// will be shared.
// This is just to circumvent the OpTester's limitation of not being able to bind a specific
@ -142,13 +147,13 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
// Update the q, k, and v buffers
parameters.q = gemm_buffer.get();
parameters.k = reinterpret_cast<CudaT*>(gemm_buffer.get()) + parameters.hidden_size;
parameters.v = reinterpret_cast<CudaT*>(gemm_buffer.get()) + 2 * parameters.hidden_size;
parameters.v = reinterpret_cast<CudaT*>(gemm_buffer.get()) + 2 * static_cast<int64_t>(parameters.hidden_size);
// Update the q, k, and v bias
const T1* bias_data = bias->Data<T1>();
parameters.q_bias = const_cast<T1*>(bias_data);
parameters.k_bias = const_cast<T1*>(bias_data + parameters.hidden_size);
parameters.v_bias = const_cast<T1*>(bias_data + 2 * parameters.hidden_size);
parameters.v_bias = const_cast<T1*>(bias_data + 2 * static_cast<int64_t>(parameters.hidden_size));
// Half of the past/present buffer correspond to K - the other half is V.
auto k_size = present->Shape().Size() / 2;
@ -167,6 +172,22 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
parameters.mask = mask_index->Data<int32_t>();
}
// Beam width (in case we are using this op inside BeamSearch)
if (beam_width != nullptr) {
parameters.beam_width = static_cast<int>(*beam_width->Data<int32_t>());
}
// Cache indirection (in case we are using this op inside BeamSearch)
if (parameters.beam_width > 1) {
// If beam width > 1, then cache indirection buffer MUST be present
if (cache_indir == nullptr) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"If beam width is greater than 1, then cache indirection buffer MUST be present");
}
parameters.cache_indir = cache_indir->Data<int32_t>();
}
switch (parameters.head_size) {
case 64:
mmha_launch_kernel<T2, 64>(parameters, cuda_stream);
@ -182,7 +203,6 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
"Got head size: ",
parameters.head_size);
}
return Status::OK();
}

View file

@ -22,6 +22,8 @@
// Modifications:
// (1) Removed some code paths from the original implementation that had features which is not supported by
// corresponding ORT kernel - for example- CrossAttention support, FP8, INT8, supports, etc.
// (2) When dealing with masked tokens, this kernel implementation deviates from FasterTransformer by applying
// mask filter values. Appropriate commentary exists in the code below.
#include "decoder_masked_multihead_attention_impl.h"
#include "decoder_masked_multihead_attention_impl_utils.h"
@ -286,7 +288,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
bool has_beams = params.cache_indir != nullptr;
const int* beam_indices = has_beams ? &params.cache_indir[bi_max_seq_length] : nullptr;
for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
@ -294,16 +295,24 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// The keys loaded from the key cache.
K_vec_k k_vec[K_VECS_PER_THREAD];
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
if (ti < tlength) {
if (has_beams) {
if (has_beams) {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
if (ti < tlength) {
const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
} else {
}
}
} else {
#pragma unroll
for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
int jj = ii * params.max_sequence_length + ti;
if (ti < tlength) {
k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
(*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
}
@ -314,9 +323,16 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * inv_sqrt_dh;
// This is a deviation from FasterTransformer kernel implementation
// but this aligns with ORT's other Attention kernels which strives to
// mimic PyTorch when dealing with mask filter values
if (is_masked) {
qk += params.mask_filter_value;
}
// Store the product to shared memory. There's one qk value per timestep. Update the max.
if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
qk_max = is_masked ? qk_max : fmaxf(qk_max, qk);
qk_max = fmaxf(qk_max, qk);
qk_smem[ti] = qk;
}
}
@ -355,8 +371,10 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
// Compute the logits and start the sum.
float sum = 0.f;
for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0);
float logit = is_masked ? 0.f : __expf(qk_smem[ti] - qk_max);
// This is a deviation from FasterTransformer kernel implementation
// but this aligns with ORT's other Attention kernels which strives to
// mimic PyTorch when dealing with mask filter values
float logit = __expf(qk_smem[ti] - qk_max);
sum += logit;
qk_smem[ti] = logit;
}

View file

@ -35,7 +35,8 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper;
BeamSearch::BeamSearch(const OpKernelInfo& info)
: onnxruntime::contrib::transformers::BeamSearch(info) {
SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds,
SetDeviceHelpers(GenerationCudaDeviceHelper::ReorderPastState,
GenerationCudaDeviceHelper::AddToFeeds,
GenerationCudaDeviceHelper::TopK,
GenerationCudaDeviceHelper::DeviceCopy<float>,
GenerationCudaDeviceHelper::DeviceCopy<int32_t>,
@ -54,6 +55,11 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
GenerationCudaDeviceHelper::ExpandBuffer<MLFloat16>);
SetConsoleDumper(&g_cuda_dumper);
cuda_device_prop_ = &reinterpret_cast<const CUDAExecutionProvider*>(info.GetExecutionProvider())->GetDeviceProp();
cuda_device_arch_ = static_cast<const cudaDeviceProp*>(cuda_device_prop_)->major * 100 +
static_cast<const cudaDeviceProp*>(cuda_device_prop_)->minor * 10;
}
Status BeamSearch::ComputeInternal(OpKernelContext* context) const {

View file

@ -424,7 +424,7 @@ void LaunchSortPairs(void* d_temp_storage,
d_offsets + 1,
0,
sizeof(T) * 8,
stream));
stream));
} else {
CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
@ -469,7 +469,7 @@ template void LaunchSortPairs(void* d_temp_storage,
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
struct BlockPrefixCallbackOp {
float running_total; // running prefix
float running_total; // running prefix
__device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {}
// Callback operator to be entered by the first warp of threads in the block.
@ -746,6 +746,65 @@ void TorchMultinomialKernelLauncher(float* d_input,
d_presence_mask);
}
__global__ void UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel(int32_t* tgt_indir_cache,
const int32_t* src_indir_cache,
const int32_t* beam_ids,
int batch_size,
int beam_width,
int input_seq_length,
int max_seq_length,
int current_length) {
int time_step = threadIdx.x + blockIdx.x * blockDim.x;
int bb_id = threadIdx.y + blockIdx.y * blockDim.y;
const int batch_id = bb_id / beam_width;
const int beam_id = bb_id % beam_width;
if (bb_id >= beam_width * batch_size || time_step >= current_length) {
return;
}
const int src_beam = beam_ids[batch_id * beam_width + beam_id] % beam_width;
const int tgt_offset = batch_id * beam_width * max_seq_length + beam_id * max_seq_length + time_step;
if (time_step < input_seq_length) {
// For time steps that correspond to the input sequence,
// the beam that it comes from is always 0.
tgt_indir_cache[tgt_offset] = static_cast<int32_t>(0);
} else if (time_step == (current_length - 1)) {
// For the final (newly generated) time step,
// the beam that it comes from is always the beam that we
// are currently processing (i.e.) from this point on, these time-steps
// form the new beams.
tgt_indir_cache[tgt_offset] = static_cast<int32_t>(beam_id);
} else {
// For all other time-steps, we look up the source indirection, to
// see which beam it came from based on the `src_beam`.
const int src_offset = batch_id * beam_width * max_seq_length + src_beam * max_seq_length + time_step;
tgt_indir_cache[tgt_offset] = src_indir_cache[src_offset];
}
}
void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache,
const int32_t* src_indir_cache,
const int32_t* beam_ids,
int batch_size,
int beam_width,
int input_seq_length,
int max_seq_length,
int current_length,
cudaStream_t stream) {
const dim3 block(32);
const dim3 grid((current_length + block.x - 1) / block.x, batch_size * beam_width);
UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel<<<grid, block, 0, stream>>>(tgt_indir_cache,
src_indir_cache,
beam_ids,
batch_size,
beam_width,
input_seq_length,
max_seq_length,
current_length);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -61,7 +61,7 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data,
cudaStream_t stream);
template <typename T>
void GetTempStorageSize(const T *d_keys_in,
void GetTempStorageSize(const T* d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
@ -77,15 +77,15 @@ void LaunchSetupParamsKernel(int* d_values_in,
cudaStream_t stream);
template <typename T>
void LaunchSortPairs(void *d_temp_storage,
void LaunchSortPairs(void* d_temp_storage,
size_t temp_storage_bytes,
const T *d_keys_in,
T *d_keys_out,
const int *d_values_in,
int *d_values_out,
const T* d_keys_in,
T* d_keys_out,
const int* d_values_in,
int* d_values_out,
int num_items,
int num_segments,
int *d_offsets,
int* d_offsets,
cudaStream_t stream,
bool is_descending);
@ -109,6 +109,16 @@ void TorchMultinomialKernelLauncher(float* d_input,
int* d_presence_mask,
cudaStream_t stream);
void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache,
const int32_t* src_indir_cache,
const int32_t* beam_ids,
int batch_size,
int beam_width,
int input_seq_length,
int max_seq_length,
int current_length,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -481,6 +481,8 @@ Status ProcessLogits(const OrtValue& logits, //
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
#endif
// TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
beam_state->next_scores.data(),
beam_state->next_scores.size_bytes(),
@ -528,6 +530,7 @@ Status ProcessLogits(const OrtValue& logits, //
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
#endif
// TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
data,
topk_scores->SizeInBytes(),
@ -535,6 +538,7 @@ Status ProcessLogits(const OrtValue& logits, //
cuda_stream));
}
// TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
beam_state->next_tokens.data(),
beam_state->next_tokens.size_bytes(),
@ -551,13 +555,31 @@ Status ProcessLogits(const OrtValue& logits, //
gsl::span<const int32_t> next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size());
gsl::span<const int32_t> next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size());
// Limitation: beam scorer runs in CPU. It might be better to use CUDA kernel to replace it.
// TODO: Implement BeamScorer on CUDA
beam_scorer->Process(
sequences,
next_scores,
next_tokens,
next_indices);
// TODO: This is a temporary work-around as BeamScorer currently only runs on CPU.
// We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually.
auto chosen_indices = beam_scorer->GetNextIndices();
auto beam_state_chosen_indices = beam_state->chosen_indices;
if (!beam_state_chosen_indices.empty()) {
// If we have allocated `chosen_indices` in beam_state, it means that we
// will be needing the chosen indices from BeamScorer as we are using
// DecoderMaskedMultiheadAttention, so copy it over.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state_chosen_indices.data(),
chosen_indices.data(),
chosen_indices.size_bytes(),
cudaMemcpyHostToDevice,
cuda_stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
}
#ifdef ENABLE_NVTX_PROFILE
processLogitsRange.End();
#endif
@ -869,12 +891,15 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len) {
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
#ifdef ENABLE_NVTX_PROFILE
profile::NvtxNestedRangeCreator updateFeedsRange("UpdateGptFeeds", profile::Color::Yellow);
updateFeedsRange.Begin();
@ -889,6 +914,8 @@ Status UpdateGptFeeds(
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
// TODO(hasesh): When BeamScorer is implemented on CUDA, figure out a way to avoid this cross-device copy
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream));
next_inputs[0] = input_ids;
@ -914,8 +941,41 @@ Status UpdateGptFeeds(
next_inputs[2] = attention_mask;
if (past_present_share_buffer) {
const int k = (static_cast<int>(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
*(next_inputs[k].GetMutable<Tensor>()->MutableData<int32_t>()) = past_sequence_len;
// Update past sequence length input
const int past_sequence_length_idx = (static_cast<int>(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
*(next_inputs[past_sequence_length_idx].GetMutable<Tensor>()->MutableData<int32_t>()) = past_sequence_len;
// Update beam search specific input for DecoderMaskedMultiheadAttention (cache indirection) if present
// If the last input is not `past_sequence_length`, then the beam search specific inputs
// for `DecoderMaskedMultiheadAttention` is present
if (has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiheadAttention with BeamSearch");
// The cache indirection feed comes 2 feeds after the `past_sequence_length` feed
const OrtValue& old_cache_indirection = next_inputs[past_sequence_length_idx + 2];
// New cache indirection updated for next decoding run
OrtValue cache_indirection;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), old_cache_indirection.Get<Tensor>().Shape(), allocator, cache_indirection);
// The third index of the past/present tensor is the max_sequence_length
int max_sequence_length = static_cast<int>(last_outputs[gpt_subgraph_first_present_output_idx].Get<Tensor>().Shape()[3]);
// Launch kernel to update the cache indirection buffer
cuda::UpdateDecoderMaskedMultiheadAttentionCacheIndirection(cache_indirection.GetMutable<Tensor>()->MutableData<int32_t>(),
old_cache_indirection.Get<Tensor>().Data<int32_t>(),
reinterpret_cast<const int32_t*>(beam_indices_gpu.data()),
batch_beam_size / num_beams,
num_beams,
input_sequence_len,
max_sequence_length,
current_length,
cuda_stream);
// Update cache indirection for next decoding run
next_inputs[past_sequence_length_idx + 2] = cache_indirection;
}
} else {
if (num_beams == 1) {
const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
@ -924,7 +984,7 @@ Status UpdateGptFeeds(
next_inputs[i + k] = last_outputs[i];
}
} else {
ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices, allocator,
ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices_cpu, allocator,
gpt_subgraph_first_past_input_idx,
gpt_subgraph_first_present_output_idx, ort_stream));
}
@ -1114,12 +1174,15 @@ template Status UpdateGptFeeds<float>(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len);
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// Float16
template void InitBeamState<MLFloat16>(
@ -1171,12 +1234,15 @@ template Status UpdateGptFeeds<MLFloat16>(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len);
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
template Status UpdateDecoderFeeds<float>(
AllocatorPtr allocator,

View file

@ -97,12 +97,15 @@ Status UpdateGptFeeds(
OrtValue& position_ids,
bool increase_position,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
gsl::span<const int32_t> beam_indices_cpu,
gsl::span<const int32_t> beam_indices_gpu,
int num_beams,
int gpt_subgraph_first_past_input_idx,
int gpt_subgraph_first_present_output_idx,
bool past_present_share_buffer,
int past_sequence_len);
int past_sequence_len,
int input_sequence_len,
bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5

View file

@ -367,9 +367,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"past",
"past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)"
"When past_present_share_buffer is set, "
"its shape is (2, batch_size, num_heads, max_sequence_length, head_size)",
"T",
OpSchema::Optional)
"its shape is (2, batch_size, num_heads, max_sequence_length, head_size). "
"The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys "
"and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. "
"The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape "
"(batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape "
"(batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to "
"become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.",
"T")
.Input(5,
"relative_position_bias",
"additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
@ -378,6 +383,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
.Input(6,
"past_sequence_length",
"When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).",
"M")
.Input(7,
"beam_width",
"The beam width that is being used while decoding."
"If not provided, the beam width will be assumed to be 1.",
"M",
OpSchema::Optional)
.Input(8,
"cache_indirection",
"A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies"
"which beam the 'k' th token came from for the 'j' th beam for batch 'i' in the current iteration",
"M",
OpSchema::Optional)
.Output(0,
@ -390,8 +406,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"If past_present_share_buffer is set, "
"its shape is (2, batch_size, num_heads, max_sequence_length, head_size), "
"while effective_seq_length = (past_sequence_length + kv_sequence_length).",
"T",
OpSchema::Optional)
"T")
.TypeConstraint("T",
{"tensor(float)", "tensor(float16)"},
"Constrain input and output types to float tensors.")

View file

@ -1026,7 +1026,7 @@ def shape_of(vi):
return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])
def update_decoder_subgraph_past_present_share_buffer(subg):
def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
input_past_0 = 3
output_past_0 = 1
new_inputs = []
@ -1074,37 +1074,82 @@ def update_decoder_subgraph_past_present_share_buffer(subg):
return subg
def update_decoder_subgraph_use_decoder_masked_multihead_attention(subg) -> bool:
def update_decoder_subgraph_use_decoder_masked_multihead_attention(
subg: GraphProto, is_beam_search: bool, switch_attention: bool
) -> bool:
"""Update the Attention nodes to DecoderMaskedMultiheadAttention.
Args:
subg (GraphProto): GraphProto of the decoder subgraph
is_beam_search (bool): Boolean specifying if the sampling algo is BeamSearch
switch_attention (bool): Boolean specifying if `Attention` is to be switched with `DecoderMaskedMultiheadAttention`
"""
decoder_masked_attention_supported_attr = [
"past_present_share_buffer",
"num_heads",
"scale",
"domain",
]
new_nodes = []
for node in subg.node:
if node.op_type == "Attention":
kwargs = kwargs_of(node)
for k in kwargs.copy():
# The Attention operator does not support different qkv hidden sizes when past/present
# input/output exists (GPT2 model). Hence, we should never run into this.
# But, if we do, do not go ahead with the optimization.
if k == "qkv_hidden_sizes":
return False
if is_beam_search:
new_inputs = []
for i, vi in enumerate(subg.input):
new_inputs.extend([vi])
# Add 2 BeamSearch specific inputs
new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
new_inputs.extend(
[
onnx.helper.make_tensor_value_info(
"cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
)
]
)
subg.ClearField("input")
subg.input.extend(new_inputs)
if switch_attention:
decoder_masked_attention_supported_attr = [
"past_present_share_buffer",
"num_heads",
"scale",
"domain",
]
new_nodes = []
for node in subg.node:
if node.op_type == "Attention":
kwargs = kwargs_of(node)
for k in kwargs.copy():
# The Attention operator does not support different qkv hidden sizes when past/present
# input/output exists (GPT2 model). Hence, we should never run into this.
# But, if we do, do not go ahead with the optimization.
if k == "qkv_hidden_sizes":
return False
if k not in decoder_masked_attention_supported_attr:
# Log the fact that we are removing certain attributes from the node
# We don't need to log it for "unidirectional" as we are aware that
# decoding attention kernels are unidirectional by definition.
if k != "unidirectional":
logger.warning(
f"Removing attribute: {k} from Attention node while switching to DecoderMaskedMultiheadAttention"
)
del kwargs[k]
nis = []
nis.extend(node.input)
# Add 2 BeamSearch specific inputs
if is_beam_search:
while len(nis) < 7:
nis.extend([""])
if len(nis) < 8:
nis.extend(["beam_width"])
if len(nis) < 9:
nis.extend(["cache_indirection"])
node = onnx.helper.make_node(
"DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs
)
new_nodes.extend([node])
subg.ClearField("node")
subg.node.extend(new_nodes)
if k not in decoder_masked_attention_supported_attr:
del kwargs[k]
nis = []
nis.extend(node.input)
node = onnx.helper.make_node("DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs)
new_nodes.extend([node])
subg.ClearField("node")
subg.node.extend(new_nodes)
return True
@ -1452,9 +1497,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
is_sampling: bool = generation_type == GenerationType.SAMPLING
past_present_share_buffer: bool = args.past_present_share_buffer and (is_greedysearch or is_sampling)
past_present_share_buffer: bool = args.past_present_share_buffer
logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}")
logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")
if is_greedysearch or is_sampling:
if not is_gpt2:
@ -1464,6 +1509,29 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
if args.output_token_scores:
raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
# For BeamSearch, sharing buffers for past and present states is only supported
# when using `use_decoder_masked_multihead_attention`
if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_multihead_attention:
raise ValueError(
"`use_decoder_masked_multihead_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
)
# For any kind of sampling, using decoder masked multihead attention is only supported
# when using `past_present_share_buffer`
if args.use_decoder_masked_multihead_attention and not past_present_share_buffer:
raise ValueError(
"`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
)
# For any kind of sampling, using decoder masked multihead attention is only supported
# on GPUs
if args.use_decoder_masked_multihead_attention and not args.use_gpu:
raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")
# Using decoder masked multihead attention is only supported for GPT2
if args.use_decoder_masked_multihead_attention and args.model_type in ["t5", "mt5"]:
raise ValueError("`use_decoder_masked_multihead_attention` option is only supported for GPT2")
if is_gpt2:
if args.decoder_onnx and os.path.exists(args.decoder_onnx):
logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
@ -1699,6 +1767,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
node.attribute.extend(attr_to_extend)
initializers = []
if args.model_type in ["t5", "mt5"]:
if args.run_shape_inference:
logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
@ -1745,37 +1814,43 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
logger.info(
f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph"
)
# Update for init decoder subgraph
# Update init decoder subgraph in preparation to use past present share buffer
if past_present_share_buffer:
logger.info("*****update init decoder subgraph to make past and present share buffer******************")
update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph)
# Update init decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
# NOTE: Even if we will not use DecoderMaskedMultiheadAttention in the init decoder subgraph
# it makes the runtime changes cleaner if we keep both the init decoder and decoder subgraphs
# same in terms of the subgraph inputs.
if (
args.use_decoder_masked_multihead_attention
and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
gpt2_init_decoder_model.graph, is_beamsearch, False
)
):
raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedMultiheadAttention")
node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph))
else:
# Move initializer from subgraph to main graph could reduce memory usage in inference.
initializers = move_initializers(decoder_model.graph)
logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph")
# Update for non-init decoder subgraph
# Update decoder subgraph in preparation to use past present share buffer
if past_present_share_buffer:
logger.info("*****update decoder subgraph to make past and present share buffer******************")
update_decoder_subgraph_past_present_share_buffer(decoder_model.graph)
# If at all, the user requests the use of `use_decoder_masked_multihead_attention`,
# we can only use it in the decoder subgraph - it cannot be used in the init_decoder subgraph
# as the kernel only supports the decoding use-case (i.e.) when input sequence length is 1.
if args.use_decoder_masked_multihead_attention:
logger.info("Update decoder subgraph to use DecoderMaskedMultiheadAttention")
if not past_present_share_buffer:
raise ValueError(
"`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
)
if not args.use_gpu:
raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")
if not update_decoder_subgraph_use_decoder_masked_multihead_attention(decoder_model.graph):
raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
# Update decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
if (
args.use_decoder_masked_multihead_attention
and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
decoder_model.graph, is_beamsearch, True
)
):
raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))

View file

@ -156,6 +156,18 @@ class TestBeamSearchGpt(unittest.TestCase):
if self.enable_cuda:
self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True)
@pytest.mark.slow
def test_beam_search_use_decoder_masked_multihead_attention(self):
if self.enable_cuda:
self.run_beam_search(f"--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu")
@pytest.mark.slow
def test_beam_search_use_decoder_masked_multihead_attention_fp16(self):
if self.enable_cuda:
self.run_beam_search(
f"--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu -p fp16"
)
@pytest.mark.slow
def test_external_data(self):
self.run_beam_search(