From 05c31a036d5a2932ac09fe367cfcfbb125e32fb2 Mon Sep 17 00:00:00 2001 From: Viswanath Boga <44417868+viboga@users.noreply.github.com> Date: Thu, 14 Jul 2022 13:31:59 -0700 Subject: [PATCH] fixing positions for beam search gpt2 (#12156) * fixing positions for beam search gpt2 Co-authored-by: Tianlei Wu --- .../transformers/beam_search_device_helper.cc | 77 +++++++++---------- .../transformers/beam_search_device_helper.h | 9 +-- .../cpu/transformers/beam_search_impl_gpt.h | 25 ++++-- .../transformers/beam_search_device_helper.cc | 39 ++++------ .../transformers/beam_search_device_helper.h | 4 +- .../cuda/transformers/beam_search_impl.cu | 8 +- 6 files changed, 79 insertions(+), 83 deletions(-) diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc index 7b163dd923..68b1aab919 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc @@ -197,7 +197,6 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, // T5 does not need position, so next_positions is empty for T5. if (!beam_state->next_positions.empty()) { - memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes()); gsl::copy(sequence_lengths, beam_state->next_positions); } @@ -274,13 +273,13 @@ Status ProcessLogits(const OrtValue& logits, // // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) gsl::span& next_token_scores = beam_state->next_token_scores; ORT_RETURN_IF_ERROR( - SoftmaxCPU( - batch_beam_size, // rows - vocab_size, // elements per row - (input_length == 1 && logits_batch_size == batch_beam_size) ? 
logits_data : next_token_logits.data(), - next_token_scores.data(), - true, - thread_pool)); + SoftmaxCPU( + batch_beam_size, // rows + vocab_size, // elements per row + (input_length == 1 && logits_batch_size == batch_beam_size) ? logits_data : next_token_logits.data(), + next_token_scores.data(), + true, + thread_pool)); #ifdef DEBUG_BEAM_SEARCH dumper->Print("next_token_scores after softmax", next_token_scores.data(), batch_size, num_beams, vocab_size); @@ -428,12 +427,12 @@ Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper) { + int gpt_subgraph_first_present_output_idx) { // last_outputs: logits, present_0, present_1, ... // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 ORT_UNUSED_PARAMETER(stream); @@ -454,10 +453,12 @@ Status UpdateGptFeeds( } next_inputs[0] = input_ids; - // Update position IDs - int32_t* position_data = position_ids.GetMutable()->MutableData(); - for (int i = 0; i < batch_beam_size; i++) { - position_data[i]++; + if (increase_position) { + // Update position IDs + int32_t* position_data = position_ids.GetMutable()->MutableData(); + for (int i = 0; i < batch_beam_size; i++) { + position_data[i]++; + } } next_inputs[1] = position_ids; @@ -477,14 +478,6 @@ Status UpdateGptFeeds( } next_inputs[2] = attention_mask; -#ifdef DEBUG_BEAM_SEARCH - dumper->Print("input_ids", input_ids); - dumper->Print("position_ids", position_ids); - dumper->Print("attention_mask", attention_mask); -#else - ORT_UNUSED_PARAMETER(dumper); -#endif - // Update past state if (num_beams == 1) { // feed present_* output to past_* inputs one by one @@ -725,12 +718,12 @@ template Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool 
increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper); + int gpt_subgraph_first_present_output_idx); template Status UpdateDecoderFeeds( AllocatorPtr allocator, @@ -751,28 +744,28 @@ template Status UpdateDecoderFeeds( template void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded); template Status ExpandBuffer( - void* stream, - const OrtValue& input, - int num_beams, - AllocatorPtr allocator, - OrtValue& expanded, - bool only_copy_shape); + void* stream, + const OrtValue& input, + int num_beams, + AllocatorPtr allocator, + OrtValue& expanded, + bool only_copy_shape); template Status ExpandBuffer( - void* stream, - const OrtValue& input, - int num_beams, - AllocatorPtr allocator, - OrtValue& expanded, - bool only_copy_shape); + void* stream, + const OrtValue& input, + int num_beams, + AllocatorPtr allocator, + OrtValue& expanded, + bool only_copy_shape); template Status ExpandBuffer( - void* stream, - const OrtValue& input, - int num_beams, - AllocatorPtr allocator, - OrtValue& expanded, - bool only_copy_shape); + void* stream, + const OrtValue& input, + int num_beams, + AllocatorPtr allocator, + OrtValue& expanded, + bool only_copy_shape); } // namespace BeamSearchCpuDeviceHelper } // namespace contrib diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h index ab18eec25c..36ab8d8e93 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h @@ -96,12 +96,12 @@ using UpdateGptFeedsFunc = std::function& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int 
gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper)>; + int gpt_subgraph_first_present_output_idx)>; // Create encoder inputs (for encoder-decoder model like T5). using CreateEncoderInputsFunc = std::function; } // namespace BeamSearchDeviceHelper - // These are CPU specific device helper implementations namespace BeamSearchCpuDeviceHelper { Status TopK( @@ -208,12 +207,12 @@ Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper); + int gpt_subgraph_first_present_output_idx); // --------------------------------------------------------------- // Functions for encoder-decoder model like T5 diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index 9cf5daeba9..7674c2a781 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -56,6 +56,7 @@ class BeamSearchGpt : public BeamSearchBase { std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices); @@ -93,6 +94,7 @@ Status BeamSearchGpt::UpdateFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices) { return update_feeds_func_(this->temp_space_allocator_, @@ -101,12 +103,12 @@ Status BeamSearchGpt::UpdateFeeds( next_inputs, current_length, position_ids, + increase_position, beam_next_tokens, beam_indices, this->parameters_->num_beams, gpt_subgraph_.GetFirstPastInputIndex(), - gpt_subgraph_.GetFirstPresentOutputIndex(), - 
this->GetConsoleDumper()); + gpt_subgraph_.GetFirstPresentOutputIndex()); } template @@ -186,11 +188,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_manage #ifdef DEBUG_BEAM_SEARCH const IConsoleDumper* dumper = this->GetConsoleDumper(); - dumper->Print("input_ids", feeds[0]); - dumper->Print("position_ids", feeds[1]); - dumper->Print("attention_mask", feeds[2]); #endif - // Position ids for all iterations except the first. It uses memory buffer owned by next_positions. OrtValue position_ids; int64_t dims[] = {parameters->BatchBeamSize(), 1}; @@ -205,9 +203,19 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_manage int iteration_counter = 0; while (current_length < parameters->max_length) { iteration_counter++; + #ifdef DEBUG_BEAM_SEARCH auto cur_len = std::to_string(current_length); dumper->Print("***CurrentLength", cur_len, true); + dumper->Print("iteration", iteration_counter, true); + + dumper->Print("input_ids", feeds[0]); + dumper->Print("position_ids", feeds[1]); + dumper->Print("attention_mask", feeds[2]); + for (size_t i = 3; i < feeds.size(); i++) { + dumper->Print("past", static_cast(i) - 3, true); + dumper->Print("", feeds[i]); + } #endif status = utils::ExecuteSubgraph(this->decoder_session_state_, @@ -241,8 +249,11 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_manage // Prepare inputs for next round of subgraph call. if (current_length < parameters->max_length) { + // For the first iteration, position_ids is initialized as sequence lengths. We can add it to feeds directly. + // For the remaining iterations, we need to increase position_ids first, then add it to feeds. 
+ bool increase_position = (iteration_counter > 1); ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, - position_ids, + position_ids, increase_position, beam_next_tokens.as_span(), beam_indices.as_span())); } diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc index b712908259..780e98909c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc @@ -258,9 +258,8 @@ Status ProcessLogits(const OrtValue& logits, // // The output will be float for consideration of precision and easy integration with remaining parts. float* Y_data = next_token_scores.data(); - const CudaT* X_data = (input_length == 1 && logits_batch_size == batch_beam_size) ? - logits_data : - reinterpret_cast(next_token_logits.data()); + bool is_single_token = (input_length == 1 && logits_batch_size == batch_beam_size); + const CudaT* X_data = is_single_token ? logits_data : reinterpret_cast(next_token_logits.data()); dispatch_blockwise_softmax_forward( cuda_stream, Y_data, X_data, vocab_size, vocab_size, batch_size * num_beams); @@ -500,12 +499,12 @@ Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper) { + int gpt_subgraph_first_present_output_idx) { // Update input_ids with next tokens. int batch_beam_size = static_cast(beam_next_tokens.length()); int64_t dims[] = {batch_beam_size, 1}; @@ -519,7 +518,7 @@ Status UpdateGptFeeds( next_inputs[0] = input_ids; // Update position IDs - int32_t* position_data = position_ids.GetMutable()->MutableData(); + int32_t* position_data = increase_position ? 
position_ids.GetMutable()->MutableData() : nullptr; next_inputs[1] = position_ids; // Update attention mask @@ -538,14 +537,6 @@ Status UpdateGptFeeds( next_inputs[2] = attention_mask; -#ifdef DEBUG_BEAM_SEARCH - dumper->Print("input_ids", input_ids); - dumper->Print("position_ids", position_ids); - dumper->Print("attention_mask", attention_mask); -#else - ORT_UNUSED_PARAMETER(dumper); -#endif - // Update past state if (num_beams == 1) { const int k = gpt_subgraph_first_past_input_idx - gpt_subgraph_first_present_output_idx; @@ -662,12 +653,12 @@ Status ExpandBuffer(void* stream, for (int i = 0; i < batch_size; i++) { for (int j = 0; j < num_beams; j++) { CUDA_RETURN_IF_ERROR( - cudaMemcpyAsync( - target, - input_data + i * chunk_size, - sizeof(T) * chunk_size, - cudaMemcpyDeviceToDevice, - cuda_stream)); + cudaMemcpyAsync( + target, + input_data + i * chunk_size, + sizeof(T) * chunk_size, + cudaMemcpyDeviceToDevice, + cuda_stream)); target += chunk_size; } } @@ -714,12 +705,12 @@ template Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper); + int gpt_subgraph_first_present_output_idx); // Float16 template void InitBeamState(transformers::IBeamSearchState* beam_state, @@ -748,12 +739,12 @@ template Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper); + int gpt_subgraph_first_present_output_idx); template Status UpdateDecoderFeeds( AllocatorPtr allocator, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h 
b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h index 14f64e923e..4424fee6d5 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h @@ -68,12 +68,12 @@ Status UpdateGptFeeds( std::vector& next_inputs, int current_length, OrtValue& position_ids, + bool increase_position, gsl::span beam_next_tokens, gsl::span beam_indices, int num_beams, int gpt_subgraph_first_past_input_idx, - int gpt_subgraph_first_present_output_idx, - const transformers::IConsoleDumper* dumper); + int gpt_subgraph_first_present_output_idx); // --------------------------------------------------------------- // Functions for encoder-decoder model like T5 diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu index 4f93b1dded..6bc52758c7 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu @@ -248,9 +248,11 @@ __global__ void UpdateGptInputsKernel(const T* old_mask_data, int j = index % current_length; mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast(1); - // Update sequence length (or next positions). - if (index < batch_beam_size) { - next_positions[index]++; + if (next_positions != nullptr) { + // Update sequence length (or next positions). + if (index < batch_beam_size) { + next_positions[index]++; + } } } }