[CUDA] Add option to use DecoderMaskedMultiheadAttention in BeamSearch (#14990)
Parent: da084b0fc1
Commit: ed7ab1660d

26 changed files with 645 additions and 174 deletions
@@ -1133,7 +1133,7 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Custom scale will be used if specified. Default value is 1/sqrt(head_size)</dd>
 </dl>

-#### Inputs (3 - 7)
+#### Inputs (7 - 9)

 <dl>
 <dt><tt>input</tt> : T</dt>
@@ -1144,20 +1144,24 @@ This version of the operator has been available since version 1 of the 'com.micr
 <dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection</dd>
 <dt><tt>mask_index</tt> (optional) : M</dt>
 <dd>Mask values of shape (batch_size, total_sequence_length)</dd>
-<dt><tt>past</tt> (optional) : T</dt>
-<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
+<dt><tt>past</tt> : T</dt>
+<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size). The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape (batch_size, num_heads, max_sequence_length, head_size), which may be perceived as being of shape (batch_size, num_heads, max_sequence_length, head_size / x, x), is reordered to become (batch_size, num_heads, head_size / x, max_sequence_length, x), where `x = 16 / sizeof(T)`.</dd>
 <dt><tt>relative_position_bias</tt> (optional) : T</dt>
 <dd>additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)</dd>
-<dt><tt>past_sequence_length</tt> (optional) : M</dt>
+<dt><tt>past_sequence_length</tt> : M</dt>
 <dd>When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).</dd>
+<dt><tt>beam_width</tt> (optional) : M</dt>
+<dd>The beam width that is being used while decoding. If not provided, the beam width will be assumed to be 1.</dd>
+<dt><tt>cache_indirection</tt> (optional) : M</dt>
+<dd>A buffer of shape [batch_size, beam_width, max_output_length], where an [i, j, k] entry specifies which beam the k-th token came from for the j-th beam of batch i in the current iteration</dd>
 </dl>

-#### Outputs (1 - 2)
+#### Outputs

 <dl>
 <dt><tt>output</tt> : T</dt>
 <dd>3D output tensor with shape (batch_size, sequence_length, v_hidden_size)</dd>
-<dt><tt>present</tt> (optional) : T</dt>
+<dt><tt>present</tt> : T</dt>
 <dd>past state for key and value with shape (2, batch_size, num_heads, total_sequence_length, head_size). If past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size), while effective_seq_length = (past_sequence_length + kv_sequence_length).</dd>
 </dl>
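The re-ordered key layout described above exists so that one 16-byte vector load can fetch a whole chunk of a key per timestep: head_size is split into chunks of x = 16 / sizeof(T) elements (x = 8 for float16, x = 4 for float), and the sequence dimension is moved ahead of the chunk contents. A CPU sketch of the mapping follows; the function name and loop structure are illustrative only, since the actual re-ordering runs on CUDA.

```cpp
#include <cstddef>

// CPU sketch of the K-cache re-ordering described above:
// (batch, num_heads, max_seq_len, head_size), viewed as
// (batch, num_heads, max_seq_len, head_size / x, x), becomes
// (batch, num_heads, head_size / x, max_seq_len, x), with x = 16 / sizeof(T).
template <typename T>
void ReorderKCache(const T* src, T* dst, size_t batch, size_t num_heads,
                   size_t max_seq_len, size_t head_size) {
  const size_t x = 16 / sizeof(T);  // 8 for float16, 4 for float
  for (size_t b = 0; b < batch; ++b)
    for (size_t n = 0; n < num_heads; ++n)
      for (size_t s = 0; s < max_seq_len; ++s)
        for (size_t h = 0; h < head_size; ++h) {
          size_t from = ((b * num_heads + n) * max_seq_len + s) * head_size + h;
          size_t to = (((b * num_heads + n) * (head_size / x) + h / x) * max_seq_len + s) * x + h % x;
          dst[to] = src[from];
        }
}
```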
@@ -798,7 +798,7 @@ Do not modify directly.*
 |ComplexMulConj|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float), tensor(float16)|
 |ConvTransposeWithDynamicPads|*in* X:**T**<br> *in* W:**T**<br> *in* Pads:**tensor(int64)**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
 |DecoderAttention|*in* query:**T**<br> *in* key:**T**<br> *in* q_weight:**T**<br> *in* kv_weight:**T**<br> *in* bias:**T**<br> *in* key_padding_mask:**B**<br> *in* key_cache:**T**<br> *in* value_cache:**T**<br> *in* static_kv:**B**<br> *in* use_past:**B**<br> *in* has_layer_state:**B**<br> *in* has_key_padding_mask:**B**<br> *out* output:**T**<br> *out* new_key_cache:**T**<br> *out* new_value_cache:**T**|1+|**T** = tensor(float), tensor(float16)|
-|DecoderMaskedMultiheadAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
+|DecoderMaskedMultiheadAttention|*in* input:**T**<br> *in* weights:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* relative_position_bias:**T**<br> *in* past_sequence_length:**M**<br> *in* beam_width:**M**<br> *in* cache_indirection:**M**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float), tensor(float16)|
 |DequantizeLinear|*in* x:**T1**<br> *in* x_scale:**T2**<br> *in* x_zero_point:**T1**<br> *out* y:**T2**|1+|**T1** = tensor(int8), tensor(uint8)<br/> **T2** = tensor(float16)|
 |DequantizeWithOrder|*in* input:**Q**<br> *in* scale_input:**S**<br> *out* output:**F**|1+|**F** = tensor(float), tensor(float16)<br/> **Q** = tensor(int8)<br/> **S** = tensor(float)|
 |EmbedLayerNormalization|*in* input_ids:**T1**<br> *in* segment_ids:**T1**<br> *in* word_embedding:**T**<br> *in* position_embedding:**T**<br> *in* segment_embedding:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* mask:**T1**<br> *in* position_ids:**T1**<br> *out* output:**T**<br> *out* mask_index:**T1**<br> *out* embedding_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
@@ -163,6 +163,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
   if (has_init_decoder_) {
     ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
     ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
+    ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
+                "past_present_share_buffer mode must be the same for the init decoder and decoder subgraphs");
   }

   concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
@@ -181,12 +183,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         thread_pool, ctx->GetComputeStream(), dumper_, parameters,
         GenerationCpuDeviceHelper::CreateGptInputs,
         add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+        reorder_past_state_func_ ? reorder_past_state_func_ : nullptr,  // Only CUDA implementation needs the reorder helper for now
         topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
         process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
         init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
         device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
         device_copy_int32_func_ ? device_copy_int32_func_ : GenerationCpuDeviceHelper::DeviceCopy<int32_t>,
-        update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
+        update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>,
+        cuda_device_prop_,
+        cuda_device_arch_};
     ORT_RETURN_IF_ERROR(impl.Initialize());

     return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -200,12 +205,15 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
         thread_pool, ctx->GetComputeStream(), dumper_, parameters,
         GenerationCpuDeviceHelper::CreateGptInputs,
         add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
+        reorder_past_state_func_ ? reorder_past_state_func_ : nullptr,  // Only CUDA implementation needs the reorder helper for now
         topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
         process_logits_fp16_func_,
         init_beam_state_fp16_func_,
         device_copy_func_,
         device_copy_int32_func_,
-        update_gpt_feeds_fp16_func_};
+        update_gpt_feeds_fp16_func_,
+        cuda_device_prop_,
+        cuda_device_arch_};
     ORT_RETURN_IF_ERROR(impl.Initialize());

     return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@@ -44,6 +44,7 @@ class BeamSearch : public IControlFlowKernel {

  // device helpers that are the same for both GPT and encoder-decoder models.
  void SetDeviceHelpers(
+     const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
      const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
      const GenerationDeviceHelper::TopkFunc& topk_func,
      const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
@@ -52,6 +53,7 @@ class BeamSearch : public IControlFlowKernel {
      const GenerationDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
      const GenerationDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
      const GenerationDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
+    reorder_past_state_func_ = reorder_past_state_func;
    add_to_feeds_func_ = add_to_feeds_func;
    topk_func_ = topk_func;
    device_copy_func_ = device_copy_func;
@@ -83,8 +85,12 @@ class BeamSearch : public IControlFlowKernel {
    expand_buffer_float16_func_ = expand_buffer_float16_func;
  }

+  const void* cuda_device_prop_ = nullptr;
+  int cuda_device_arch_ = 0;
+
 private:
  // Device specific functions
+ GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
  GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
  GenerationDeviceHelper::TopkFunc topk_func_;
  GenerationDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
@@ -20,6 +20,9 @@ struct BeamSearchState : public IBeamSearchState<T> {
             int vocab_size,
             int sequence_length,
             int max_length,
+            int num_heads,
+            int head_size,
+            int has_decoder_masked_multihead_attention,
             bool output_scores,
             bool use_position) {
    size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
@@ -49,6 +52,21 @@ struct BeamSearchState : public IBeamSearchState<T> {
      this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
      this->remaining_scores = this->scores;
    }
+
+    if (has_decoder_masked_multihead_attention) {
+      // We need a temp staging buffer to do the past 'K' state re-ordering that is needed
+      // when using DecoderMaskedMultiheadAttention
+      TensorShape staging_for_past_state_reorder_buffer_shape = {static_cast<int64_t>(batch_beam_size), num_heads, max_length, head_size};
+
+      Tensor temp(DataTypeImpl::GetType<T>(), staging_for_past_state_reorder_buffer_shape, allocator);
+
+      this->staging_for_past_state_reorder = std::move(temp);
+
+      // We need a buffer on GPU to hold the final chosen indices after BeamScorer has finished processing
+      // TODO: This is a temporary work-around as BeamScorer currently only runs on CPU.
+      // We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually.
+      this->chosen_indices = AllocateBuffer<int32_t>(allocator, chosen_indices_buffer_, batch_beam_size);
+    }
  }

 private:
@@ -61,6 +79,7 @@ struct BeamSearchState : public IBeamSearchState<T> {
  BufferUniquePtr beam_scores_buffer_;
  BufferUniquePtr scores_buffer_;
  BufferUniquePtr topk_temp_buffer_;
+ BufferUniquePtr chosen_indices_buffer_;
};

struct BeamSearchCpuState : public IBeamSearchCpuState {
@@ -124,7 +143,7 @@ struct BeamSearchCpuState : public IBeamSearchCpuState {

// Base class of beam search implementation that is common for both GPT-2 and T5.
template <typename T>
-class BeamSearchBase : public GenerateBase {
+class BeamSearchBase : public GenerateBase {
 public:
  BeamSearchBase(OpKernelContextInternal& context,
                 const SessionState& decoder_session_state,
@@ -136,13 +155,13 @@ class BeamSearchBase : public GenerateBase {
                 const GenerationDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
                 const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
                 const GenerationDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func)
-      : GenerateBase(context,
-                     decoder_session_state,
-                     thread_pool,
-                     ort_stream,
-                     cuda_dumper,
-                     topk_func,
-                     device_copy_func),
+      : GenerateBase(context,
+                     decoder_session_state,
+                     thread_pool,
+                     ort_stream,
+                     cuda_dumper,
+                     topk_func,
+                     device_copy_func),
        parameters_(&params),
        process_logits_func_(process_logits_func),
        device_copy_int32_func_(device_copy_int32_func) {
@@ -188,11 +207,11 @@ Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
  // input_ids : (batch_size, sequence_length)
  // vocab_mask : (vocab_size) or nullptr
  ORT_RETURN_IF_ERROR(this->CheckInputsImpl(parameters_,
-                                           context.Input<Tensor>(0),  // input_ids
-                                           context.Input<Tensor>(7),  // vocab_mask
-                                           context.Input<Tensor>(8),  // prefix_vocab_mask
-                                           context.Input<Tensor>(9),  // attention_mask
-                                           nullptr));                 // presence_mask
+                                           context.Input<Tensor>(0),  // input_ids
+                                           context.Input<Tensor>(7),  // vocab_mask
+                                           context.Input<Tensor>(8),  // prefix_vocab_mask
+                                           context.Input<Tensor>(9),  // attention_mask
+                                           nullptr));                 // presence_mask

  return Status::OK();
}
@@ -27,12 +27,15 @@ class BeamSearchGpt : public BeamSearchBase<T> {
                BeamSearchParameters& params,
                const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func,
                const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
+               const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
                const GenerationDeviceHelper::TopkFunc& topk_func,
                const GenerationDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
                const GenerationDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
                const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
                const GenerationDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
-               const GenerationDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func)
+               const GenerationDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func,
+               const void* cuda_device_prop,
+               int cuda_device_arch)
      : BeamSearchBase<T>(context, decoder_session_state, thread_pool,
                          ort_stream, cuda_dumper, params,
                          topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
@@ -42,7 +45,17 @@ class BeamSearchGpt : public BeamSearchBase<T> {
        create_inputs_func_(create_inputs_func),
        add_to_feeds_func_(add_to_feeds_func),
        init_beam_state_func_(init_beam_state_func),
-       update_feeds_func_(update_feeds_func) {
+       reorder_past_state_func_(reorder_past_state_func),
+       update_feeds_func_(update_feeds_func),
+       cuda_device_prop_(cuda_device_prop),
+       cuda_device_arch_(cuda_device_arch) {
+    if (gpt_subgraph_.has_decoder_masked_multihead_attention_) {
+      ORT_ENFORCE(cuda_device_arch_ >= 530,
+                  "Decoder masked multihead attention can only be used on "
+                  "GPU cards of compute capability 5.3 or higher. "
+                  "This card has compute capability ",
+                  cuda_device_arch_);
+    }
  }

  // Execute beam search in iterations until stopping criteria are reached.
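The `>= 530` gate compares against the compute capability encoded as major * 100 + minor * 10, the same encoding the CUDA BeamSearch constructor derives from cudaDeviceProp further down in this diff; 5.3 therefore maps to 530. A one-line sketch of the encoding:

```cpp
// Compute capability encoded as in the CUDA BeamSearch constructor below:
// EncodeArch(5, 3) == 530 (the minimum for DecoderMaskedMultiheadAttention),
// EncodeArch(7, 0) == 700, EncodeArch(8, 6) == 860.
int EncodeArch(int major, int minor) {
  return major * 100 + minor * 10;
}
```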
@@ -55,7 +68,8 @@ class BeamSearchGpt : public BeamSearchBase<T> {
  Status CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
                            OrtValue& expanded_input_ids,
                            std::vector<OrtValue>& feeds,
-                           IAllocatorUniquePtr<char>& buffer);
+                           IAllocatorUniquePtr<char>& buffer,
+                           bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

  // Update the input for next iteration.
  Status UpdateFeeds(
@@ -65,7 +79,11 @@ class BeamSearchGpt : public BeamSearchBase<T> {
      OrtValue& position_ids,
      bool increase_position,
      gsl::span<const int32_t> beam_next_tokens,
-     gsl::span<const int32_t> beam_indices);
+     gsl::span<const int32_t> beam_indices_cpu,
+     gsl::span<const int32_t> beam_indices_gpu,
+     int past_sequence_length,
+     int input_sequence_len,
+     bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

  const SessionState* init_run_decoder_session_state_ = nullptr;
  GptSubgraph* init_run_gpt_subgraph_ = nullptr;
@@ -75,14 +93,19 @@ class BeamSearchGpt : public BeamSearchBase<T> {
  GenerationDeviceHelper::CreateGptInputsFunc create_inputs_func_;
  GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
  GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
+ GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
  GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
+
+ const void* cuda_device_prop_ = nullptr;
+ int cuda_device_arch_ = 0;
};

template <typename T>
Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
                                            OrtValue& expanded_input_ids,
                                            std::vector<OrtValue>& feeds,
-                                           IAllocatorUniquePtr<char>& buffer) {
+                                           IAllocatorUniquePtr<char>& buffer,
+                                           bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
  const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0);
  const Tensor& input_ids = input_ids_value->Get<Tensor>();
  const OrtValue* attn_mask_value = this->context_.GetInputOrtValue(9);
@@ -99,7 +122,9 @@ Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths
                                                     this->create_inputs_func_,
                                                     this->add_to_feeds_func_,
                                                     buffer,
-                                                    this->ort_stream_);
+                                                    this->ort_stream_,
+                                                    this->parameters_->max_length,
+                                                    add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
  }

  return gpt_subgraph_.CreateInitialFeeds(input_ids,
@@ -113,7 +138,9 @@ Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths
                                          this->create_inputs_func_,
                                          this->add_to_feeds_func_,
                                          buffer,
-                                         this->ort_stream_);
+                                         this->ort_stream_,
+                                         this->parameters_->max_length,
+                                         add_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}

template <typename T>
@@ -124,7 +151,11 @@ Status BeamSearchGpt<T>::UpdateFeeds(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
-   gsl::span<const int32_t> beam_indices) {
+   gsl::span<const int32_t> beam_indices_cpu,
+   gsl::span<const int32_t> beam_indices_gpu,
+   int past_sequence_length,
+   int input_sequence_len,
+   bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
  return update_feeds_func_(this->temp_space_allocator_,
                            this->ort_stream_,
                            last_outputs,
@@ -133,12 +164,15 @@ Status BeamSearchGpt<T>::UpdateFeeds(
                            position_ids,
                            increase_position,
                            beam_next_tokens,
-                           beam_indices,
+                           beam_indices_cpu,
+                           beam_indices_gpu,
                            this->parameters_->num_beams,
                            gpt_subgraph_.GetFirstPastInputIndex(),
                            gpt_subgraph_.GetFirstPresentOutputIndex(),
-                           false,
-                           -1);
+                           gpt_subgraph_.past_present_share_buffer_,
+                           past_sequence_length,
+                           input_sequence_len,
+                           has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);
}

template <typename T>
@@ -192,7 +226,22 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
  // buffer in GPU for input_ids, position_ids and attention_mask
  IAllocatorUniquePtr<char> buffer;
  OrtValue expanded_input_ids_in_cpu;
- ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
+ ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer,
+                                        gpt_subgraph_.has_decoder_masked_multihead_attention_));
+
+ if (gpt_subgraph_.past_present_share_buffer_) {  // Reuse past and present
+   fetches.reserve(static_cast<int64_t>(gpt_subgraph_.GetFirstPresentOutputIndex()) + gpt_subgraph_.num_layers);
+   fetches.resize(gpt_subgraph_.GetFirstPresentOutputIndex(), OrtValue());
+   for (int layer = 0; layer < gpt_subgraph_.num_layers; layer++) {
+     int feed_idx = gpt_subgraph_.GetFirstPastInputIndex() + layer;
+     OrtValue& past_tensor_value = feeds[feed_idx];
+     Tensor* past_tensor = past_tensor_value.GetMutable<Tensor>();
+     OrtValue present_tensor_value;
+     Tensor::InitOrtValue(past_tensor->DataType(), past_tensor->Shape(), past_tensor->MutableData<T>(),
+                          past_tensor->Location(), present_tensor_value);
+     fetches.push_back(present_tensor_value);
+   }
+ }

  BeamSearchState<T> beam_state;
  constexpr bool use_position = true;
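The fetches set up above make past/present buffer sharing concrete: each present fetch is an OrtValue wrapping the corresponding past tensor's own memory, so the subgraph writes its present state in place and no per-iteration copy is needed. A toy sketch of the aliasing, with plain buffers standing in for ORT tensors:

```cpp
#include <cassert>
#include <vector>

// Toy model of past/present buffer sharing: the "present" fetch is a view over
// the "past" feed's storage, so writes through one are visible through the other.
int main() {
  std::vector<float> past(8, 0.0f);  // past state feed owns the buffer
  float* present = past.data();      // present fetch wraps the same memory
  present[0] = 1.5f;                 // subgraph writes its present state
  assert(past[0] == 1.5f);           // next iteration reads it as past, copy-free
  return 0;
}
```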
@@ -202,6 +251,9 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
                  parameters->vocab_size,
                  parameters->sequence_length,
                  parameters->max_length,
+                 parameters->num_heads,
+                 parameters->head_size,
+                 gpt_subgraph_.has_decoder_masked_multihead_attention_,
                  parameters->output_scores,
                  use_position);
@@ -297,17 +349,49 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
    // Increase sequence length after a new token is generated.
    ++current_length;

+   // Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
+   // contains DecoderMaskedMultiheadAttention nodes
+   if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) {
+     size_t offset = static_cast<size_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
+     // We will use the same staging buffer while transposing all the layers' past state
+     // and this is okay because we use the same stream to do the staging copy and the transpose
+     // operations.
+     // If we ever do them in different streams, we must use different staging buffers to avoid data
+     // races.
+     for (size_t i = 0; i < static_cast<size_t>(gpt_subgraph_.num_layers); ++i) {
+       ORT_RETURN_IF_ERROR(reorder_past_state_func_(cuda_device_prop_,
+                                                    *fetches[offset + i].GetMutable<Tensor>(),
+                                                    beam_state.staging_for_past_state_reorder,
+                                                    this->ort_stream_));
+     }
+   }
+
    // Prepare inputs for next round of subgraph call.
    if (current_length < parameters->max_length) {
+     gsl::span<const int32_t> place_holder;
      // For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly.
      // For the remaining iterations, we need to increase position_ids first, then add it to feeds.
      bool increase_position = (iteration_counter > 1);
      ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
                                      position_ids, increase_position,
                                      ReinterpretAsSpan<const int32_t>(beam_next_tokens),
-                                     ReinterpretAsSpan<const int32_t>(beam_indices)));
+                                     ReinterpretAsSpan<const int32_t>(beam_indices),
+                                     gpt_subgraph_.has_decoder_masked_multihead_attention_
+                                         ? ReinterpretAsSpan<const int32_t>(beam_state.chosen_indices)
+                                         : place_holder,
+                                     current_length - 1,
+                                     parameters->sequence_length,
+                                     gpt_subgraph_.has_decoder_masked_multihead_attention_));
    }

-   fetches.clear();
+   if (gpt_subgraph_.past_present_share_buffer_) {
+     // clear fetched values before presents[]
+     for (int idx = 0; idx < gpt_subgraph_.GetFirstPresentOutputIndex(); idx++) {
+       fetches[idx] = OrtValue();
+     }
+   } else {
+     fetches.clear();
+   }
  }

  gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
@@ -193,6 +193,9 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
                  parameters->vocab_size,
                  parameters->sequence_length,
                  parameters->max_length,
+                 parameters->num_heads,
+                 parameters->head_size,
+                 false,  // TODO: Support past/present state buffer re-use for T5
                  parameters->output_scores,
                  use_position);
@@ -94,7 +94,7 @@ class BeamSearchScorer : public IBeamScorer {

  gsl::span<float>& GetNextScores() { return next_beam_scores_; }
  gsl::span<int32_t>& GetNextTokens() { return next_beam_tokens_; }
- gsl::span<int32_t>& GetNextIndices() { return next_beam_indices_; }
+ gsl::span<int32_t>& GetNextIndices() override { return next_beam_indices_; }

 private:
  size_t batch_size_;
@@ -407,18 +407,18 @@ Status ProcessLogits(const OrtValue& logits, //

template <typename T>
Status GreedySearchProcessLogits(
-   const OrtValue& logits,                                 // logits output of subgraph
-   transformers::IGreedySearchState<T>* greedy_state,      // state
-   transformers::ISamplingState<T>* sampling_state,        // sampling_state
-   transformers::ISequences* sequences,                    // sequences
-   AllocatorPtr& allocator,                                // default allocator
-   onnxruntime::concurrency::ThreadPool* thread_pool,      // thread pool (for CPU only)
-   transformers::ILogitsProcessorList* logits_processors,  // logits processors
-   const transformers::IGenerationParameters* parameters,  // parameters
-   bool do_sampling,                                       // whether to do sampling
-   int step,                                               // iteration counter
-   Stream* stream,                                         // cuda stream (for CUDA only)
-   const transformers::IConsoleDumper* dumper) {           // tensor dumper
+   const OrtValue& logits,                                 // logits output of subgraph
+   transformers::IGreedySearchState<T>* greedy_state,      // state
+   transformers::ISamplingState<T>* sampling_state,        // sampling_state
+   transformers::ISequences* sequences,                    // sequences
+   AllocatorPtr& allocator,                                // default allocator
+   onnxruntime::concurrency::ThreadPool* thread_pool,      // thread pool (for CPU only)
+   transformers::ILogitsProcessorList* logits_processors,  // logits processors
+   const transformers::IGenerationParameters* parameters,  // parameters
+   bool do_sampling,                                       // whether to do sampling
+   int step,                                               // iteration counter
+   Stream* stream,                                         // cuda stream (for CUDA only)
+   const transformers::IConsoleDumper* dumper) {           // tensor dumper

  int batch_size = parameters->batch_size;
  int vocab_size = parameters->vocab_size;
@@ -499,8 +499,8 @@ Status GreedySearchProcessLogits(
                                      topk_indices));

#ifdef DEBUG_GENERATION
-   dumper->Print("topk_scores", topk_scores);
-   dumper->Print("topk_indices", topk_indices);
+ dumper->Print("topk_scores", topk_scores);
+ dumper->Print("topk_indices", topk_indices);
#endif

  gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
@@ -574,15 +574,21 @@ Status UpdateGptFeeds(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
-   gsl::span<const int32_t> beam_indices,
+   gsl::span<const int32_t> beam_indices_cpu,
+   gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
-   int past_sequence_len) {
+   int past_sequence_len,
+   int input_sequence_len,
+   bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
  // last_outputs: logits, present_0, present_1, ...
  // next_inputs: input_ids, position_id, attention_mask, past_0, past_1
  ORT_UNUSED_PARAMETER(stream);
+ ORT_UNUSED_PARAMETER(beam_indices_gpu);
+ ORT_UNUSED_PARAMETER(input_sequence_len);
+ ORT_UNUSED_PARAMETER(has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

  // The following updates inputs for subgraph
@@ -631,14 +637,14 @@ Status UpdateGptFeeds(
    return Status::OK();
  }

- if (num_beams == 1) { // Update past state
+ if (num_beams == 1) {  // Update past state
    // feed present_* output to past_* inputs one by one
    const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
    for (size_t i = gpt_subgraph_first_present_output_idx; i < last_outputs.size(); ++i) {
      next_inputs[i + k] = last_outputs[i];
    }
  } else {
-   PickGptPastState<T>(last_outputs, next_inputs, beam_indices,
+   PickGptPastState<T>(last_outputs, next_inputs, beam_indices_cpu,
                        gpt_subgraph_first_past_input_idx,
                        gpt_subgraph_first_present_output_idx, allocator);
  }
@@ -890,12 +896,15 @@ template Status UpdateGptFeeds<float>(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
-   gsl::span<const int32_t> beam_indices,
+   gsl::span<const int32_t> beam_indices_cpu,
+   gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
-   int past_sequence_len);
+   int past_sequence_len,
+   int input_sequence_len,
+   bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

template Status UpdateDecoderFeeds<float>(
    AllocatorPtr allocator,
@@ -128,12 +128,15 @@ using UpdateGptFeedsFunc = std::function<Status(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
-   gsl::span<const int32_t> beam_indices,
+   gsl::span<const int32_t> beam_indices_cpu,
+   gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
-   int past_sequence_len)>;
+   int past_sequence_len,
+   int input_sequence_len,
+   bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention)>;

// Create encoder inputs (for encoder-decoder model like T5).
using CreateEncoderInputsFunc = std::function<Status(
@@ -262,12 +265,15 @@ Status UpdateGptFeeds(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
-   gsl::span<const int32_t> beam_indices,
+   gsl::span<const int32_t> beam_indices_cpu,
+   gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
-   int past_sequence_len);
+   int past_sequence_len,
+   int input_sequence_len,
+   bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5
@@ -38,6 +38,11 @@ struct IBeamSearchState {
  // temp token: (batch_size * num_beams, 2 * num_beams)
  // in total, it will be:
  // 2 * (batch_size * num_beams * (parts_vocab + 1), 2 * num_beams)
+
+ // The final chosen indices after BeamScorer has finished processing
+ gsl::span<int32_t> chosen_indices;  // shape (batch_size, num_beams)
+
+ Tensor staging_for_past_state_reorder;  // Tensor of shape (batch_size * num_beams, num_heads, max_length, head_size)
};

struct IBeamSearchCpuState {
@@ -117,6 +122,8 @@ class IBeamScorer {
                          gsl::span<const float>& final_beam_scores,
                          Tensor* output_sequences,
                          Tensor* output_sequence_scores) = 0;
+
+ virtual gsl::span<int32_t>& GetNextIndices() = 0;
};

struct IGenerationParameters {
@@ -81,7 +81,7 @@ struct GreedySearchState : public IGreedySearchState<T> {
            int max_length,
            int num_heads,
            int head_size,
-           bool allocate_staging_buffer_for_past_state_reorder,
+           bool has_decoder_masked_multihead_attention,
            bool is_cuda) {
    // below buffers are on cpu
    this->sequences_space = AllocateBuffer<int32_t>(cpu_allocator,
@@ -111,8 +111,9 @@ struct GreedySearchState : public IGreedySearchState<T> {
                                   this->topk_scores_buffer,
                                   this->topk_tokens_buffer);

-   // If at all we need to, we only need to re-order past state for CUDA
-   if (allocate_staging_buffer_for_past_state_reorder) {
+   // If at all we need to, we only need to re-order past state for CUDA, as
+   // `DecoderMaskedMultiheadAttention` is only supported on CUDA
+   if (has_decoder_masked_multihead_attention) {
      TensorShape staging_for_past_state_reorder_buffer_shape = {batch_size, num_heads, max_length, head_size};

      Tensor temp(DataTypeImpl::GetType<T>(), staging_for_past_state_reorder_buffer_shape, allocator);
@@ -170,11 +170,14 @@ Status GreedySearchGpt<T, ParametersT>::UpdateFeeds(
                            increase_position,
                            next_tokens,
                            place_holder,
+                           place_holder,
                            this->parameters_->num_beams,
                            gpt_subgraph_.GetFirstPastInputIndex(),
                            gpt_subgraph_.GetFirstPresentOutputIndex(),
                            gpt_subgraph_.past_present_share_buffer_,
-                           past_sequence_length);
+                           past_sequence_length,
+                           -1,  // Input sequence length needn't be passed in for GreedySearch
+                           false);
}

template <typename T, typename ParametersT>
@@ -329,7 +332,7 @@ Status GreedySearchGpt<T, ParametersT>::Execute(const FeedsFetchesManager* init_
    // Reorder past state after first run if the GPT subgraph (the one used after the first iteration)
    // contains DecoderMaskedMultiheadAttention nodes
    if (iteration_counter == 1 && gpt_subgraph_.has_decoder_masked_multihead_attention_) {
-     size_t offset = static_cast<int64_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
+     size_t offset = static_cast<size_t>(gpt_subgraph_.GetFirstPresentOutputIndex());
      // We will use the same staging buffer while transposing all the layers' past state
      // and this is okay because we use the same stream to do the staging copy and the transpose
      // operations.
@@ -91,6 +91,9 @@ Status Subgraph::Setup(const SessionState& session_state,
      past_present_share_buffer_ = true;
      // past_sequence_length is on CPU memory
      feed_locations.push_back(OrtDevice());
+   } else if (feed_names[i] == "beam_width") {
+     // beam_width is on CPU memory
+     feed_locations.push_back(OrtDevice());
    } else {
      feed_locations.push_back(default_location.device);
    }
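`beam_width` joins `past_sequence_length` on CPU memory because the CUDA kernel dereferences these scalar inputs on the host while populating its parameters (see `*beam_width->Data<int32_t>()` in the DecoderMaskedMultiheadAttention hunk below), and the kernel registration accordingly pins both inputs with `InputMemoryType(OrtMemTypeCPUInput, ...)`. A minimal sketch of that host-side read, with a hypothetical helper name:

```cpp
#include <cstdint>

// Host-side read of a CPU-pinned scalar input: safe precisely because the
// kernel definition forces this input into CPU memory, so no cudaMemcpy is
// needed. Per the operator doc, a missing beam_width defaults to 1.
int ReadBeamWidth(const int32_t* beam_width_data) {
  return beam_width_data != nullptr ? static_cast<int>(*beam_width_data) : 1;
}
```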
@@ -27,7 +27,8 @@ Status GptSubgraph::CreateInitialFeeds(
    const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
    IAllocatorUniquePtr<char>& buffer,
    Stream* ort_stream,
-   int max_seq_len_past_present_share_buffer) {
+   int past_present_share_buffer_max_seq_len,
+   bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
  ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");

  const IExecutionProvider* provider = GetProvider();
@@ -83,20 +84,48 @@ Status GptSubgraph::CreateInitialFeeds(
      feeds.push_back(empty_past);
    }
  } else {
-   int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, max_seq_len_past_present_share_buffer, head_size};
+   // Past state feeds
+   int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, past_present_share_buffer_max_seq_len, head_size};
    TensorShape past_shape(&past_state_dims[0], 5);
-   // The remaining inputs are past state except the last one
-   for (int i = first_past_input_index_; i < num_subgraph_inputs - 1; ++i) {
+
+   // The remaining inputs are past state except the last one or three (see below for details).
+   // If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is false, then the last input is `past_sequence_length`.
+   // If `add_beam_search_specific_inputs_for_decoder_masked_multihead_attention` is true, then the last inputs are `past_sequence_length`,
+   // `beam_width`, and `cache_indirection`.
+   auto past_end_iter = add_beam_search_specific_inputs_for_decoder_masked_multihead_attention ? num_subgraph_inputs - 3 : num_subgraph_inputs - 1;
+   for (int i = first_past_input_index_; i < past_end_iter; ++i) {
      OrtValue past_tensor;
      Tensor::InitOrtValue(past_type, past_shape, default_allocator, past_tensor);
      feeds.push_back(past_tensor);
    }

+   // Past sequence length feed
    int64_t past_seq_len_dims[] = {1};
    TensorShape past_seq_len_shape(&past_seq_len_dims[0], 1);
    OrtValue past_seq_len_tensor_value;
    Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), past_seq_len_shape, cpu_allocator, past_seq_len_tensor_value);
    feeds.push_back(past_seq_len_tensor_value);
    *past_seq_len_tensor_value.GetMutable<Tensor>()->MutableData<int32_t>() = 0;
+
+   // Add beam search specific inputs
+   if (add_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
+     // Beam width feed
+     int64_t num_beams_dims[] = {1};
+     TensorShape num_beams_shape(&num_beams_dims[0], 1);
+     OrtValue num_beams_tensor_value;
+     Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), num_beams_shape, cpu_allocator, num_beams_tensor_value);
+     feeds.push_back(num_beams_tensor_value);
+     *num_beams_tensor_value.GetMutable<Tensor>()->MutableData<int32_t>() = static_cast<int32_t>(num_beams);
+
+     // Cache indirection feed
+     int64_t cache_indirection_dims[] = {batch_size, num_beams, past_present_share_buffer_max_seq_len};
+     TensorShape cache_indirection_shape(&cache_indirection_dims[0], 3);
+     OrtValue default_cache_indirection;
+     Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), cache_indirection_shape,
+                          default_allocator, default_cache_indirection);
+     feeds.push_back(default_cache_indirection);
+   }
  }

  // Pass in implicit inputs
@@ -112,8 +141,12 @@ Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
  ORT_RETURN_IF(num_subgraph_outputs <= first_present_output_index_,
                "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs).");

- ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) || (num_subgraph_inputs == num_subgraph_outputs + 3)),
-               "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or 3 (if past_present_share_buffer)");
+ ORT_RETURN_IF(!((num_subgraph_inputs == num_subgraph_outputs + 2) ||
+                 (num_subgraph_inputs == num_subgraph_outputs + 3) ||
+                 (num_subgraph_inputs == num_subgraph_outputs + 5)),
+               "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2 or "
+               "3 (if past_present_share_buffer) or "
+               "5 (if past_present_share_buffer and use_decoder_masked_multihead_attention for BeamSearch)");

  ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
                "subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
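The accepted counts follow from the subgraph layout (first_past_input_index_ = 3, first_present_output_index_ = 1; see the GptSubgraph constructor hunk below): with L layers there are 1 + L outputs and 3 + L base inputs, i.e. outputs + 2; `past_sequence_length` makes it outputs + 3; adding `beam_width` and `cache_indirection` makes it outputs + 5. A sketch of the arithmetic, with a hypothetical helper name:

```cpp
// Expected GPT-2 subgraph I/O counts for `layers` transformer layers.
// Inputs: input_ids, position_ids, attention_mask, past_0..past_{L-1},
//         [past_sequence_length], [beam_width, cache_indirection].
// Outputs: logits, present_0..present_{L-1}.
struct IoCounts {
  int inputs;
  int outputs;
};

IoCounts ExpectedGptIo(int layers, bool past_present_share_buffer,
                       bool beam_search_with_dmmha) {
  IoCounts c{3 + layers, 1 + layers};            // inputs == outputs + 2
  if (past_present_share_buffer) c.inputs += 1;  // + past_sequence_length -> outputs + 3
  if (beam_search_with_dmmha) c.inputs += 2;     // + beam_width, cache_indirection -> outputs + 5
  return c;
}
```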
@@ -16,9 +16,9 @@ class GptSubgraph : public Subgraph {
               const onnxruntime::Node& node_in,
               const std::string& attribute_name,
               const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {
-   first_past_input_index_ = 3;
-   first_present_output_index_ = 1;
- }
+   first_past_input_index_ = 3;
+   first_present_output_index_ = 1;
+ }

  // Create inputs for first inference of subgraph.
  Status CreateInitialFeeds(
@@ -34,7 +34,8 @@ class GptSubgraph : public Subgraph {
      const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
      IAllocatorUniquePtr<char>& buffer,
      Stream* ort_stream,
-     int max_seq_len_past_present_share_buffer = -1);
+     int past_present_share_buffer_max_seq_len = -1,
+     bool add_beam_search_specific_inputs_for_decoder_masked_multihead_attention = false);

  Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
                  const std::vector<const NodeArg*>& subgraph_outputs) override;
@@ -16,20 +16,23 @@ namespace contrib {
namespace cuda {

static constexpr int kPastSequenceLengthInputIndex = 6;
+static constexpr int kBeamWidthInputIndex = 7;
+static constexpr int kCacheIndirectionInputIndex = 8;
static constexpr int kPastInputIndex = 4;
static constexpr int kPresentOutputIndex = 1;

-#define REGISTER_KERNEL_TYPED(T1, T2)                                          \
-  ONNX_OPERATOR_TYPED_KERNEL_EX(                                               \
-      DecoderMaskedMultiheadAttention,                                         \
-      kMSDomain,                                                               \
-      1,                                                                       \
-      T1,                                                                      \
-      kCudaExecutionProvider,                                                  \
-      (*KernelDefBuilder::Create())                                            \
-          .MayInplace(kPastInputIndex, kPresentOutputIndex)                    \
-          .TypeConstraint("T", DataTypeImpl::GetTensorType<T1>())              \
-          .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex), \
+#define REGISTER_KERNEL_TYPED(T1, T2)                                          \
+  ONNX_OPERATOR_TYPED_KERNEL_EX(                                               \
+      DecoderMaskedMultiheadAttention,                                         \
+      kMSDomain,                                                               \
+      1,                                                                       \
+      T1,                                                                      \
+      kCudaExecutionProvider,                                                  \
+      (*KernelDefBuilder::Create())                                            \
+          .MayInplace(kPastInputIndex, kPresentOutputIndex)                    \
+          .TypeConstraint("T", DataTypeImpl::GetTensorType<T1>())              \
+          .InputMemoryType(OrtMemTypeCPUInput, kPastSequenceLengthInputIndex)  \
+          .InputMemoryType(OrtMemTypeCPUInput, kBeamWidthInputIndex),          \
      DecoderMaskedMultiheadAttention<T1, T2>);

REGISTER_KERNEL_TYPED(float, float)
@@ -44,6 +47,8 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
  const Tensor* past = context->Input<Tensor>(kPastInputIndex);
  const Tensor* relative_position_bias = context->Input<Tensor>(5);
  const Tensor* past_seq_len = context->Input<Tensor>(kPastSequenceLengthInputIndex);
+ const Tensor* beam_width = context->Input<Tensor>(kBeamWidthInputIndex);
+ const Tensor* cache_indir = context->Input<Tensor>(kCacheIndirectionInputIndex);

  auto& device_prop = GetDeviceProp();
  DecoderMaskedMultiheadAttentionParams parameters;
@@ -105,7 +110,7 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
  auto* present_data = present->MutableData<T1>();
  auto* past_data = past->Data<T1>();

- // No production use-case will incure this copy cost as the implementation of
+ // No production use-case will incur this copy cost as the implementation of
  // GreedySearch/BeamSearch is written in such a way that the past and present buffers
  // will be shared.
  // This is just to circumvent the OpTester's limitation of not being able to bind a specific
@@ -142,13 +147,13 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
  // Update the q, k, and v buffers
  parameters.q = gemm_buffer.get();
  parameters.k = reinterpret_cast<CudaT*>(gemm_buffer.get()) + parameters.hidden_size;
- parameters.v = reinterpret_cast<CudaT*>(gemm_buffer.get()) + 2 * parameters.hidden_size;
+ parameters.v = reinterpret_cast<CudaT*>(gemm_buffer.get()) + 2 * static_cast<int64_t>(parameters.hidden_size);

  // Update the q, k, and v bias
  const T1* bias_data = bias->Data<T1>();
  parameters.q_bias = const_cast<T1*>(bias_data);
  parameters.k_bias = const_cast<T1*>(bias_data + parameters.hidden_size);
- parameters.v_bias = const_cast<T1*>(bias_data + 2 * parameters.hidden_size);
+ parameters.v_bias = const_cast<T1*>(bias_data + 2 * static_cast<int64_t>(parameters.hidden_size));

  // Half of the past/present buffer corresponds to K - the other half is V.
  auto k_size = present->Shape().Size() / 2;
@@ -167,6 +172,22 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
    parameters.mask = mask_index->Data<int32_t>();
  }

+ // Beam width (in case we are using this op inside BeamSearch)
+ if (beam_width != nullptr) {
+   parameters.beam_width = static_cast<int>(*beam_width->Data<int32_t>());
+ }
+
+ // Cache indirection (in case we are using this op inside BeamSearch)
+ if (parameters.beam_width > 1) {
+   // If beam width > 1, then cache indirection buffer MUST be present
+   if (cache_indir == nullptr) {
+     return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+                            "If beam width is greater than 1, then cache indirection buffer MUST be present");
+   }
+
+   parameters.cache_indir = cache_indir->Data<int32_t>();
+ }
+
  switch (parameters.head_size) {
    case 64:
      mmha_launch_kernel<T2, 64>(parameters, cuda_stream);
@@ -182,7 +203,6 @@ Status DecoderMaskedMultiheadAttention<T1, T2>::ComputeInternal(OpKernelContext*
                           "Got head size: ",
                           parameters.head_size);
  }

  return Status::OK();
}
-
@@ -22,6 +22,8 @@
// Modifications:
// (1) Removed some code paths from the original implementation that had features which are not supported by
//     the corresponding ORT kernel - for example, CrossAttention support, FP8/INT8 support, etc.
+// (2) When dealing with masked tokens, this kernel implementation deviates from FasterTransformer by applying
+//     mask filter values. Appropriate commentary exists in the code below.

#include "decoder_masked_multihead_attention_impl.h"
#include "decoder_masked_multihead_attention_impl_utils.h"
@@ -286,7 +288,6 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio

  // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
  bool has_beams = params.cache_indir != nullptr;
-
  const int* beam_indices = has_beams ? &params.cache_indir[bi_max_seq_length] : nullptr;

  for (int ti = ko; ti < ti_end; ti += K_PER_ITER) {
@@ -294,16 +295,24 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio

    // The keys loaded from the key cache.
    K_vec_k k_vec[K_VECS_PER_THREAD];
-#pragma unroll
-   for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
-     int jj = ii * params.max_sequence_length + ti;
-
-     if (ti < tlength) {
-       if (has_beams) {
-         const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
-         k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
-             (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
-       } else {
+   if (has_beams) {
+#pragma unroll
+     for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+       int jj = ii * params.max_sequence_length + ti;
+
+       if (ti < tlength) {
+         const int beam_offset = beam_indices[ti] * params.num_heads * params.max_sequence_length * head_size;
+         k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
+             (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
+       }
+     }
+   } else {
+#pragma unroll
+     for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
+       int jj = ii * params.max_sequence_length + ti;
+
+       if (ti < tlength) {
         k_vec[ii] = vec_conversion<K_vec_k, K_vec_m>(
             (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[jj * QK_ELTS_IN_16B])));
       }
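In the beam-aware branch above, each timestep's key is read from whichever beam the cache indirection recorded for that timestep: `beam_indices[ti]` selects the source beam, whose K cache slice is `num_heads * max_sequence_length * head_size` elements long. A scalar sketch of the offset arithmetic, with variable names mirroring the kernel's:

```cpp
#include <cstddef>

// Scalar sketch of the beam-aware K-cache offset: timestep `ti` is fetched
// from the beam recorded in the cache indirection, then `jj * QK_ELTS_IN_16B`
// addresses the 16-byte chunk within that beam's slice (as in the kernel).
size_t BeamAwareKOffset(const int* beam_indices, int ti, size_t num_heads,
                        size_t max_sequence_length, size_t head_size,
                        size_t jj, size_t qk_elts_in_16b) {
  const size_t beam_offset =
      static_cast<size_t>(beam_indices[ti]) * num_heads * max_sequence_length * head_size;
  return beam_offset + jj * qk_elts_in_16b;
}
```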
@@ -314,9 +323,16 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
      // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
      float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k_vec) * inv_sqrt_dh;

+     // This is a deviation from the FasterTransformer kernel implementation,
+     // but it aligns with ORT's other Attention kernels, which strive to
+     // mimic PyTorch when dealing with mask filter values
+     if (is_masked) {
+       qk += params.mask_filter_value;
+     }
+
      // Store the product to shared memory. There's one qk value per timestep. Update the max.
      if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
-       qk_max = is_masked ? qk_max : fmaxf(qk_max, qk);
+       qk_max = fmaxf(qk_max, qk);
        qk_smem[ti] = qk;
      }
    }
@@ -355,8 +371,10 @@ __global__ void masked_multihead_attention_kernel(DecoderMaskedMultiheadAttentio
  // Compute the logits and start the sum.
  float sum = 0.f;
  for (int ti = tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
-   bool is_masked = (params.mask != nullptr) && (params.mask[bi_total_seq_length + ti] == 0);
-   float logit = is_masked ? 0.f : __expf(qk_smem[ti] - qk_max);
+   // This is a deviation from the FasterTransformer kernel implementation,
+   // but it aligns with ORT's other Attention kernels, which strive to
+   // mimic PyTorch when dealing with mask filter values
+   float logit = __expf(qk_smem[ti] - qk_max);
    sum += logit;
    qk_smem[ti] = logit;
  }
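Taken together, the two hunks above switch masked positions from being zeroed after the fact to receiving an additive `mask_filter_value` on the raw QK score; the max and exp/sum passes then treat every position uniformly, and a large negative addend drives a masked position's softmax weight toward zero. A CPU sketch of the resulting behavior; the -10000 default mirrors ORT's usual mask filter value and is an assumption here, not taken from this diff:

```cpp
#include <cmath>
#include <vector>

// Additive masking as in ORT's attention kernels: masked scores get a large
// negative offset before softmax, so exp(score - max) underflows to ~0.
// mask[i] == 0 marks a masked position; -10000.0f is an assumed default.
std::vector<float> MaskedSoftmax(std::vector<float> scores,
                                 const std::vector<int>& mask,
                                 float mask_filter_value = -10000.0f) {
  float mx = -INFINITY;
  for (size_t i = 0; i < scores.size(); ++i) {
    if (mask[i] == 0) scores[i] += mask_filter_value;
    mx = std::fmax(mx, scores[i]);
  }
  float sum = 0.0f;
  for (float& s : scores) {
    s = std::exp(s - mx);
    sum += s;
  }
  for (float& s : scores) s /= sum;
  return scores;
}
```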
@@ -35,7 +35,8 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper;

BeamSearch::BeamSearch(const OpKernelInfo& info)
    : onnxruntime::contrib::transformers::BeamSearch(info) {
- SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds,
+ SetDeviceHelpers(GenerationCudaDeviceHelper::ReorderPastState,
+                  GenerationCudaDeviceHelper::AddToFeeds,
                   GenerationCudaDeviceHelper::TopK,
                   GenerationCudaDeviceHelper::DeviceCopy<float>,
                   GenerationCudaDeviceHelper::DeviceCopy<int32_t>,
@@ -54,6 +55,11 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
                   GenerationCudaDeviceHelper::ExpandBuffer<MLFloat16>);

  SetConsoleDumper(&g_cuda_dumper);
+
+ cuda_device_prop_ = &reinterpret_cast<const CUDAExecutionProvider*>(info.GetExecutionProvider())->GetDeviceProp();
+
+ cuda_device_arch_ = static_cast<const cudaDeviceProp*>(cuda_device_prop_)->major * 100 +
+                     static_cast<const cudaDeviceProp*>(cuda_device_prop_)->minor * 10;
}

Status BeamSearch::ComputeInternal(OpKernelContext* context) const {
@@ -424,7 +424,7 @@ void LaunchSortPairs(void* d_temp_storage,
                                                             d_offsets + 1,
                                                             0,
                                                             sizeof(T) * 8,
-                                                            stream));
+                                                            stream));
  } else {
    CUDA_CALL_THROW(cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
                                                             temp_storage_bytes,
@@ -469,7 +469,7 @@ template void LaunchSortPairs(void* d_temp_storage,
// A stateful callback functor that maintains a running prefix to be applied
// during consecutive scan operations.
struct BlockPrefixCallbackOp {
- float running_total; // running prefix
+ float running_total;  // running prefix

  __device__ BlockPrefixCallbackOp(float running_total) : running_total(running_total) {}
  // Callback operator to be entered by the first warp of threads in the block.
@@ -746,6 +746,65 @@ void TorchMultinomialKernelLauncher(float* d_input,
                                     d_presence_mask);
}

+__global__ void UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel(int32_t* tgt_indir_cache,
+                                                                            const int32_t* src_indir_cache,
+                                                                            const int32_t* beam_ids,
+                                                                            int batch_size,
+                                                                            int beam_width,
+                                                                            int input_seq_length,
+                                                                            int max_seq_length,
+                                                                            int current_length) {
+  int time_step = threadIdx.x + blockIdx.x * blockDim.x;
+  int bb_id = threadIdx.y + blockIdx.y * blockDim.y;
+  const int batch_id = bb_id / beam_width;
+  const int beam_id = bb_id % beam_width;
+
+  if (bb_id >= beam_width * batch_size || time_step >= current_length) {
+    return;
+  }
+
+  const int src_beam = beam_ids[batch_id * beam_width + beam_id] % beam_width;
+
+  const int tgt_offset = batch_id * beam_width * max_seq_length + beam_id * max_seq_length + time_step;
+
+  if (time_step < input_seq_length) {
+    // For time steps that correspond to the input sequence,
+    // the beam that it comes from is always 0.
+    tgt_indir_cache[tgt_offset] = static_cast<int32_t>(0);
+  } else if (time_step == (current_length - 1)) {
+    // For the final (newly generated) time step,
+    // the beam that it comes from is always the beam that we
+    // are currently processing (i.e.) from this point on, these time-steps
+    // form the new beams.
+    tgt_indir_cache[tgt_offset] = static_cast<int32_t>(beam_id);
+  } else {
+    // For all other time-steps, we look up the source indirection to
+    // see which beam it came from, based on the `src_beam`.
+    const int src_offset = batch_id * beam_width * max_seq_length + src_beam * max_seq_length + time_step;
+    tgt_indir_cache[tgt_offset] = src_indir_cache[src_offset];
+  }
+}
+
+void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache,
+                                                           const int32_t* src_indir_cache,
+                                                           const int32_t* beam_ids,
+                                                           int batch_size,
+                                                           int beam_width,
+                                                           int input_seq_length,
+                                                           int max_seq_length,
+                                                           int current_length,
+                                                           cudaStream_t stream) {
+  const dim3 block(32);
+  const dim3 grid((current_length + block.x - 1) / block.x, batch_size * beam_width);
+  UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel<<<grid, block, 0, stream>>>(tgt_indir_cache,
+                                                                                          src_indir_cache,
+                                                                                          beam_ids,
+                                                                                          batch_size,
+                                                                                          beam_width,
+                                                                                          input_seq_length,
+                                                                                          max_seq_length,
+                                                                                          current_length);
+}
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
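A CPU rendering of the kernel's per-element rule may help when following the indexing: prompt timesteps always resolve to beam 0, the newest timestep resolves to the beam currently being processed, and every timestep in between is gathered from the source beam chosen by `beam_ids`. This reference loop is illustrative, not part of the commit:

```cpp
#include <cstdint>

// CPU reference for UpdateDecoderMaskedMultiheadAttentionCacheIndirectionKernel.
// tgt/src are laid out as [batch_size, beam_width, max_seq_length];
// beam_ids has batch_size * beam_width entries.
void UpdateCacheIndirectionReference(int32_t* tgt, const int32_t* src,
                                     const int32_t* beam_ids, int batch_size,
                                     int beam_width, int input_seq_length,
                                     int max_seq_length, int current_length) {
  for (int b = 0; b < batch_size; ++b) {
    for (int j = 0; j < beam_width; ++j) {
      const int src_beam = beam_ids[b * beam_width + j] % beam_width;
      for (int t = 0; t < current_length; ++t) {
        const int tgt_off = (b * beam_width + j) * max_seq_length + t;
        if (t < input_seq_length) {
          tgt[tgt_off] = 0;                  // prompt token: always beam 0
        } else if (t == current_length - 1) {
          tgt[tgt_off] = j;                  // newest token: this beam itself
        } else {                             // older generated token: follow the source beam
          tgt[tgt_off] = src[(b * beam_width + src_beam) * max_seq_length + t];
        }
      }
    }
  }
}
```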
@@ -61,7 +61,7 @@ void LaunchUpdateGptKernel(const int32_t* old_mask_data,
                           cudaStream_t stream);

template <typename T>
-void GetTempStorageSize(const T *d_keys_in,
+void GetTempStorageSize(const T* d_keys_in,
                        const int* d_values_in,
                        int* d_offsets,
                        int num_items,
@@ -77,15 +77,15 @@ void LaunchSetupParamsKernel(int* d_values_in,
                             cudaStream_t stream);

template <typename T>
-void LaunchSortPairs(void *d_temp_storage,
+void LaunchSortPairs(void* d_temp_storage,
                     size_t temp_storage_bytes,
-                    const T *d_keys_in,
-                    T *d_keys_out,
-                    const int *d_values_in,
-                    int *d_values_out,
+                    const T* d_keys_in,
+                    T* d_keys_out,
+                    const int* d_values_in,
+                    int* d_values_out,
                     int num_items,
                     int num_segments,
-                    int *d_offsets,
+                    int* d_offsets,
                     cudaStream_t stream,
                     bool is_descending);
@@ -109,6 +109,16 @@ void TorchMultinomialKernelLauncher(float* d_input,
                                    int* d_presence_mask,
                                    cudaStream_t stream);

+void UpdateDecoderMaskedMultiheadAttentionCacheIndirection(int32_t* tgt_indir_cache,
+                                                           const int32_t* src_indir_cache,
+                                                           const int32_t* beam_ids,
+                                                           int batch_size,
+                                                           int beam_width,
+                                                           int input_seq_length,
+                                                           int max_seq_length,
+                                                           int current_length,
+                                                           cudaStream_t stream);
+
 }  // namespace cuda
 }  // namespace contrib
 }  // namespace onnxruntime
@@ -481,6 +481,8 @@ Status ProcessLogits(const OrtValue& logits, //
  dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, 2 * num_beams);
  dumper->Print("next_scores before scorer", beam_state->next_scores.data(), batch_size, 2 * num_beams);
#endif
+
+ // TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
                                       beam_state->next_scores.data(),
                                       beam_state->next_scores.size_bytes(),
@ -528,6 +530,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
|
||||
#endif
|
||||
|
||||
// TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
|
||||
data,
|
||||
topk_scores->SizeInBytes(),
|
||||
|
|
@ -535,6 +538,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
cuda_stream));
|
||||
}
|
||||
|
||||
// TODO: Remove these kinds of cross-device copies once BeamScorer runs on CUDA
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
|
||||
beam_state->next_tokens.data(),
|
||||
beam_state->next_tokens.size_bytes(),
|
||||
|
|
@ -551,13 +555,31 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
gsl::span<const int32_t> next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size());
|
||||
gsl::span<const int32_t> next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size());
|
||||
|
||||
// Limitation: beam scorer runs in CPU. It might be better to use CUDA kernel to replace it.
|
||||
// TODO: Implement BeamScorer on CUDA
|
||||
beam_scorer->Process(
|
||||
sequences,
|
||||
next_scores,
|
||||
next_tokens,
|
||||
next_indices);
|
||||
|
||||
// TODO: This is a temporary work-around as BeamScorer currently only runs on CPU.
|
||||
// We can remove these kinds of work-arounds once BeamScorer runs on CUDA eventually.
|
||||
auto chosen_indices = beam_scorer->GetNextIndices();
|
||||
auto beam_state_chosen_indices = beam_state->chosen_indices;
|
||||
|
||||
if (!beam_state_chosen_indices.empty()) {
|
||||
// If we have allocated `chosen_indices` in beam_state, it means that we
|
||||
// will be needing the chosen indices from BeamScorer as we are using
|
||||
// DecoderMaskedMultiheadAttention, so copy it over.
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state_chosen_indices.data(),
|
||||
chosen_indices.data(),
|
||||
chosen_indices.size_bytes(),
|
||||
cudaMemcpyHostToDevice,
|
||||
cuda_stream));
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
|
||||
}
|
||||
|
||||
#ifdef ENABLE_NVTX_PROFILE
|
||||
processLogitsRange.End();
|
||||
#endif
|
||||
|
|
@@ -869,12 +891,15 @@ Status UpdateGptFeeds(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
    gsl::span<const int32_t> beam_indices,
    gsl::span<const int32_t> beam_indices_cpu,
    gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
    int past_sequence_len) {
    int past_sequence_len,
    int input_sequence_len,
    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
#ifdef ENABLE_NVTX_PROFILE
  profile::NvtxNestedRangeCreator updateFeedsRange("UpdateGptFeeds", profile::Color::Yellow);
  updateFeedsRange.Begin();
@@ -889,6 +914,8 @@ Status UpdateGptFeeds(
  Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
  int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
  cudaStream_t cuda_stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;

  // TODO(hasesh): When BeamScorer is implemented on CUDA, figure out a way to avoid this cross-device copy
  CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
                                       cudaMemcpyHostToDevice, cuda_stream));
  next_inputs[0] = input_ids;
@@ -914,8 +941,41 @@ Status UpdateGptFeeds(
  next_inputs[2] = attention_mask;

  if (past_present_share_buffer) {
    const int k = (static_cast<int>(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
    *(next_inputs[k].GetMutable<Tensor>()->MutableData<int32_t>()) = past_sequence_len;
    // Update past sequence length input
    const int past_sequence_length_idx = (static_cast<int>(last_outputs.size()) - gpt_subgraph_first_present_output_idx) + gpt_subgraph_first_past_input_idx;
    *(next_inputs[past_sequence_length_idx].GetMutable<Tensor>()->MutableData<int32_t>()) = past_sequence_len;

    // Update beam search specific input for DecoderMaskedMultiheadAttention (cache indirection) if present

    // If the last input is not `past_sequence_length`, then the beam search specific inputs
    // for `DecoderMaskedMultiheadAttention` are present
    if (has_beam_search_specific_inputs_for_decoder_masked_multihead_attention) {
      ORT_ENFORCE(!beam_indices_gpu.empty(), "Beam indices must be present on CUDA while using DecoderMaskedMultiheadAttention with BeamSearch");

      // The cache indirection feed comes 2 feeds after the `past_sequence_length` feed
      const OrtValue& old_cache_indirection = next_inputs[past_sequence_length_idx + 2];

      // New cache indirection updated for the next decoding run
      OrtValue cache_indirection;

      Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), old_cache_indirection.Get<Tensor>().Shape(), allocator, cache_indirection);

      // Dimension 3 (index 3) of the past/present tensor is max_sequence_length
      int max_sequence_length = static_cast<int>(last_outputs[gpt_subgraph_first_present_output_idx].Get<Tensor>().Shape()[3]);

      // Launch kernel to update the cache indirection buffer
      cuda::UpdateDecoderMaskedMultiheadAttentionCacheIndirection(cache_indirection.GetMutable<Tensor>()->MutableData<int32_t>(),
                                                                  old_cache_indirection.Get<Tensor>().Data<int32_t>(),
                                                                  reinterpret_cast<const int32_t*>(beam_indices_gpu.data()),
                                                                  batch_beam_size / num_beams,
                                                                  num_beams,
                                                                  input_sequence_len,
                                                                  max_sequence_length,
                                                                  current_length,
                                                                  cuda_stream);
      // Update cache indirection for the next decoding run
      next_inputs[past_sequence_length_idx + 2] = cache_indirection;
    }
  } else {
    if (num_beams == 1) {
      const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx;
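A quick sketch of the feed-index arithmetic above, using the subgraph layout constants from the conversion script further below (`input_past_0 = 3`, `output_past_0 = 1`); the 12-layer count is a made-up example, not something this diff pins down:

```python
# Assumed example: a 12-layer GPT-2 decoder subgraph.
num_outputs = 1 + 12             # logits + present_0 .. present_11
first_present_output_idx = 1     # output_past_0 in the conversion script
first_past_input_idx = 3         # input_past_0: [input_ids, position_ids, attention_mask, past_0, ...]

num_layers = num_outputs - first_present_output_idx           # 12
past_sequence_length_idx = num_layers + first_past_input_idx  # 15: the feed right after past_11
beam_width_idx = past_sequence_length_idx + 1                 # 16
cache_indirection_idx = past_sequence_length_idx + 2          # 17: "2 feeds after past_sequence_length"
```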
@@ -924,7 +984,7 @@ Status UpdateGptFeeds(
        next_inputs[i + k] = last_outputs[i];
      }
    } else {
      ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices, allocator,
      ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices_cpu, allocator,
                                              gpt_subgraph_first_past_input_idx,
                                              gpt_subgraph_first_present_output_idx, ort_stream));
    }
@@ -1114,12 +1174,15 @@ template Status UpdateGptFeeds<float>(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
    gsl::span<const int32_t> beam_indices,
    gsl::span<const int32_t> beam_indices_cpu,
    gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
    int past_sequence_len);
    int past_sequence_len,
    int input_sequence_len,
    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

// Float16
template void InitBeamState<MLFloat16>(
@@ -1171,12 +1234,15 @@ template Status UpdateGptFeeds<MLFloat16>(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
    gsl::span<const int32_t> beam_indices,
    gsl::span<const int32_t> beam_indices_cpu,
    gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
    int past_sequence_len);
    int past_sequence_len,
    int input_sequence_len,
    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

template Status UpdateDecoderFeeds<float>(
    AllocatorPtr allocator,
@@ -97,12 +97,15 @@ Status UpdateGptFeeds(
    OrtValue& position_ids,
    bool increase_position,
    gsl::span<const int32_t> beam_next_tokens,
    gsl::span<const int32_t> beam_indices,
    gsl::span<const int32_t> beam_indices_cpu,
    gsl::span<const int32_t> beam_indices_gpu,
    int num_beams,
    int gpt_subgraph_first_past_input_idx,
    int gpt_subgraph_first_present_output_idx,
    bool past_present_share_buffer,
    int past_sequence_len);
    int past_sequence_len,
    int input_sequence_len,
    bool has_beam_search_specific_inputs_for_decoder_masked_multihead_attention);

// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5
@@ -367,9 +367,14 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
               "past",
               "past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). "
               "When past_present_share_buffer is set, "
               "its shape is (2, batch_size, num_heads, max_sequence_length, head_size)",
               "T",
               OpSchema::Optional)
               "its shape is (2, batch_size, num_heads, max_sequence_length, head_size). "
               "The first `batch_size * num_heads * max_sequence_length * head_size` elements correspond to keys "
               "and the next `batch_size * num_heads * max_sequence_length * head_size` elements correspond to values. "
               "The keys buffer is re-ordered in such a way that its virtual sub-tensor of shape "
               "(batch_size, num_heads, max_sequence_length, head_size) which may be perceived as being of shape "
               "(batch_size, num_heads, max_sequence_length, head_size / x, x) is reordered to "
               "become (batch_size, num_heads, head_size / x, max_sequence_length, x) where `x = 16 / sizeof(T)`.",
               "T")
        .Input(5,
               "relative_position_bias",
               "additional add to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length)",
@@ -378,6 +383,17 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
        .Input(6,
               "past_sequence_length",
               "When past_present_share_buffer is used, it is required to specify past_sequence_length (could be 0).",
               "M")
        .Input(7,
               "beam_width",
               "The beam width that is being used while decoding. "
               "If not provided, the beam width will be assumed to be 1.",
               "M",
               OpSchema::Optional)
        .Input(8,
               "cache_indirection",
               "A buffer of shape [batch_size, beam_width, max_output_length] where an [i, j, k] entry specifies "
               "which beam the k-th token came from for the j-th beam for batch i in the current iteration",
               "M",
               OpSchema::Optional)
        .Output(0,
@@ -390,8 +406,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
                "If past_present_share_buffer is set, "
                "its shape is (2, batch_size, num_heads, max_sequence_length, head_size), "
                "while effective_seq_length = (past_sequence_length + kv_sequence_length).",
                "T",
                OpSchema::Optional)
                "T")
        .TypeConstraint("T",
                        {"tensor(float)", "tensor(float16)"},
                        "Constrain input and output types to float tensors.")
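The key-cache reordering described for the `past` input is just a reshape plus a transpose. A NumPy sketch, assuming fp16 so that `x = 16 / sizeof(T) = 8` (the concrete dimensions are illustrative):

```python
import numpy as np

batch_size, num_heads, max_sequence_length, head_size = 2, 12, 32, 64
x = 16 // np.dtype(np.float16).itemsize  # x = 8 for fp16

keys = np.random.rand(batch_size, num_heads, max_sequence_length, head_size).astype(np.float16)

# (B, H, S, h) -> (B, H, S, h/x, x) -> (B, H, h/x, S, x)
reordered = keys.reshape(batch_size, num_heads, max_sequence_length, head_size // x, x).transpose(0, 1, 3, 2, 4)
assert reordered.shape == (batch_size, num_heads, head_size // x, max_sequence_length, x)
```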
@@ -1026,7 +1026,7 @@ def shape_of(vi):
    return tuple([d.dim_param if (d.dim_param) else d.dim_value for d in vi.type.tensor_type.shape.dim])


def update_decoder_subgraph_past_present_share_buffer(subg):
def update_decoder_subgraph_past_present_share_buffer(subg: GraphProto):
    input_past_0 = 3
    output_past_0 = 1
    new_inputs = []
@@ -1074,37 +1074,82 @@ def update_decoder_subgraph_past_present_share_buffer(subg):
    return subg


def update_decoder_subgraph_use_decoder_masked_multihead_attention(subg) -> bool:
def update_decoder_subgraph_use_decoder_masked_multihead_attention(
    subg: GraphProto, is_beam_search: bool, switch_attention: bool
) -> bool:
    """Update the Attention nodes to DecoderMaskedMultiheadAttention.

    Args:
        subg (GraphProto): GraphProto of the decoder subgraph
        is_beam_search (bool): Boolean specifying if the sampling algo is BeamSearch
        switch_attention (bool): Boolean specifying if `Attention` is to be switched with `DecoderMaskedMultiheadAttention`
    """
    decoder_masked_attention_supported_attr = [
        "past_present_share_buffer",
        "num_heads",
        "scale",
        "domain",
    ]
    new_nodes = []
    for node in subg.node:
        if node.op_type == "Attention":
            kwargs = kwargs_of(node)
            for k in kwargs.copy():
                # The Attention operator does not support different qkv hidden sizes when past/present
                # input/output exists (GPT2 model). Hence, we should never run into this.
                # But, if we do, do not go ahead with the optimization.
                if k == "qkv_hidden_sizes":
                    return False
    if is_beam_search:
        new_inputs = []
        for i, vi in enumerate(subg.input):
            new_inputs.extend([vi])

        # Add 2 BeamSearch specific inputs
        new_inputs.extend([onnx.helper.make_tensor_value_info("beam_width", onnx.TensorProto.INT32, shape=[1])])
        new_inputs.extend(
            [
                onnx.helper.make_tensor_value_info(
                    "cache_indirection", onnx.TensorProto.INT32, shape=["batch_size", "beam_width", "max_seq_len"]
                )
            ]
        )
        subg.ClearField("input")
        subg.input.extend(new_inputs)

    if switch_attention:
        decoder_masked_attention_supported_attr = [
            "past_present_share_buffer",
            "num_heads",
            "scale",
            "domain",
        ]

        new_nodes = []
        for node in subg.node:
            if node.op_type == "Attention":
                kwargs = kwargs_of(node)
                for k in kwargs.copy():
                    # The Attention operator does not support different qkv hidden sizes when past/present
                    # input/output exists (GPT2 model). Hence, we should never run into this.
                    # But, if we do, do not go ahead with the optimization.
                    if k == "qkv_hidden_sizes":
                        return False

                    if k not in decoder_masked_attention_supported_attr:
                        # Log the fact that we are removing certain attributes from the node.
                        # We don't need to log it for "unidirectional" as we are aware that
                        # decoding attention kernels are unidirectional by definition.
                        if k != "unidirectional":
                            logger.warning(
                                f"Removing attribute: {k} from Attention node while switching to DecoderMaskedMultiheadAttention"
                            )

                        del kwargs[k]

                nis = []
                nis.extend(node.input)

                # Add 2 BeamSearch specific inputs
                if is_beam_search:
                    while len(nis) < 7:
                        nis.extend([""])
                    if len(nis) < 8:
                        nis.extend(["beam_width"])
                    if len(nis) < 9:
                        nis.extend(["cache_indirection"])

                node = onnx.helper.make_node(
                    "DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs
                )
                new_nodes.extend([node])
        subg.ClearField("node")
        subg.node.extend(new_nodes)

                if k not in decoder_masked_attention_supported_attr:
                    del kwargs[k]
            nis = []
            nis.extend(node.input)
            node = onnx.helper.make_node("DecoderMaskedMultiheadAttention", nis, node.output, name=node.name, **kwargs)
            new_nodes.extend([node])
    subg.ClearField("node")
    subg.node.extend(new_nodes)
    return True
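A hypothetical usage sketch of the updated helper, assuming a GPT-2 decoder subgraph saved at `gpt2_decoder.onnx` (the path is illustrative):

```python
import onnx

model = onnx.load("gpt2_decoder.onnx")  # assumed path, for illustration only
ok = update_decoder_subgraph_use_decoder_masked_multihead_attention(
    model.graph,
    is_beam_search=True,    # add beam_width / cache_indirection graph inputs
    switch_attention=True,  # rewrite Attention -> DecoderMaskedMultiheadAttention
)
if not ok:
    raise ValueError("Could not switch Attention to DecoderMaskedMultiheadAttention")
```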
@@ -1452,9 +1497,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
    is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
    is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
    is_sampling: bool = generation_type == GenerationType.SAMPLING
    past_present_share_buffer: bool = args.past_present_share_buffer and (is_greedysearch or is_sampling)
    past_present_share_buffer: bool = args.past_present_share_buffer

    logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}")
    logger.info(f"**** past_present_share_buffer={past_present_share_buffer}")

    if is_greedysearch or is_sampling:
        if not is_gpt2:
@@ -1464,6 +1509,29 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
        if args.output_token_scores:
            raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")

    # For BeamSearch, sharing buffers for past and present states is only supported
    # when using `use_decoder_masked_multihead_attention`
    if past_present_share_buffer and is_beamsearch and not args.use_decoder_masked_multihead_attention:
        raise ValueError(
            "`use_decoder_masked_multihead_attention` MUST be turned on to use `past_present_share_buffer` in case of BeamSearch"
        )

    # For any kind of sampling, using decoder masked multihead attention is only supported
    # when using `past_present_share_buffer`
    if args.use_decoder_masked_multihead_attention and not past_present_share_buffer:
        raise ValueError(
            "`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
        )

    # For any kind of sampling, using decoder masked multihead attention is only supported
    # on GPUs
    if args.use_decoder_masked_multihead_attention and not args.use_gpu:
        raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")

    # Using decoder masked multihead attention is only supported for GPT2
    if args.use_decoder_masked_multihead_attention and args.model_type in ["t5", "mt5"]:
        raise ValueError("`use_decoder_masked_multihead_attention` option is only supported for GPT2")

    if is_gpt2:
        if args.decoder_onnx and os.path.exists(args.decoder_onnx):
            logger.info(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
@@ -1699,6 +1767,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
            node.attribute.extend(attr_to_extend)

    initializers = []

    if args.model_type in ["t5", "mt5"]:
        if args.run_shape_inference:
            logger.info(f"Symbolic shape inference on {args.encoder_decoder_init_onnx}. The file will be overwritten.")
@@ -1745,37 +1814,43 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
        logger.info(
            f"{len(initializers)} shared initializers ({[i.name for i in initializers]}) in decoder and init decoder subgraphs are moved to the main graph"
        )
        # Update for init decoder subgraph

        # Update init decoder subgraph in preparation to use past present share buffer
        if past_present_share_buffer:
            logger.info("*****update init decoder subgraph to make past and present share buffer******************")
            update_decoder_subgraph_past_present_share_buffer(gpt2_init_decoder_model.graph)

        # Update init decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
        # NOTE: Even if we will not use DecoderMaskedMultiheadAttention in the init decoder subgraph,
        # it makes the runtime changes cleaner if we keep both the init decoder and decoder subgraphs
        # the same in terms of the subgraph inputs.
        if (
            args.use_decoder_masked_multihead_attention
            and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
                gpt2_init_decoder_model.graph, is_beamsearch, False
            )
        ):
            raise ValueError("Could not update the init decoder subgraph to use DecoderMaskedMultiheadAttention")

        node.attribute.append(onnx.helper.make_attribute("init_decoder", gpt2_init_decoder_model.graph))
    else:
        # Moving initializers from the subgraph to the main graph can reduce memory usage in inference.
        initializers = move_initializers(decoder_model.graph)
        logger.info(f"{len(initializers)} initializers from the decoder are moved to the main graph")

    # Update for non-init decoder subgraph
    # Update decoder subgraph in preparation to use past present share buffer
    if past_present_share_buffer:
        logger.info("*****update decoder subgraph to make past and present share buffer******************")
        update_decoder_subgraph_past_present_share_buffer(decoder_model.graph)

    # If the user requests `use_decoder_masked_multihead_attention` at all, we can only use it
    # in the decoder subgraph - it cannot be used in the init_decoder subgraph,
    # as the kernel only supports the decoding use-case (i.e., when the input sequence length is 1).
    if args.use_decoder_masked_multihead_attention:
        logger.info("Update decoder subgraph to use DecoderMaskedMultiheadAttention")

        if not past_present_share_buffer:
            raise ValueError(
                "`past_present_share_buffer` MUST be turned on to use `use_decoder_masked_multihead_attention`"
            )

        if not args.use_gpu:
            raise ValueError("`use_decoder_masked_multihead_attention` option is only supported on GPUs")

        if not update_decoder_subgraph_use_decoder_masked_multihead_attention(decoder_model.graph):
            raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")
    # Update decoder subgraph in preparation to use DecoderMaskedMultiheadAttention
    if (
        args.use_decoder_masked_multihead_attention
        and not update_decoder_subgraph_use_decoder_masked_multihead_attention(
            decoder_model.graph, is_beamsearch, True
        )
    ):
        raise ValueError("Could not update the decoder subgraph to use DecoderMaskedMultiheadAttention")

    node.attribute.append(onnx.helper.make_attribute("decoder", decoder_model.graph))
@@ -156,6 +156,18 @@ class TestBeamSearchGpt(unittest.TestCase):
        if self.enable_cuda:
            self.run_beam_search("--repetition_penalty 1.0 --use_gpu -p fp16", is_greedy=True)

    @pytest.mark.slow
    def test_beam_search_use_decoder_masked_multihead_attention(self):
        if self.enable_cuda:
            self.run_beam_search("--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu")

    @pytest.mark.slow
    def test_beam_search_use_decoder_masked_multihead_attention_fp16(self):
        if self.enable_cuda:
            self.run_beam_search(
                "--past_present_share_buffer --use_decoder_masked_multihead_attention --use_gpu -p fp16"
            )

    @pytest.mark.slow
    def test_external_data(self):
        self.run_beam_search(