diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h index e6cba6be1a..5f8ab335c3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -202,12 +202,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch std::vector fetches; // Initialize resources - onnxruntime::OrtStlAllocator hypothesis_score_allocator(this->cpu_allocator_); - onnxruntime::OrtStlAllocator beam_hyps_allocator(this->cpu_allocator_); - this->beam_scorer_ = std::make_unique(*parameters, - hypothesis_score_allocator, - beam_hyps_allocator, - this->cpu_allocator_); + this->beam_scorer_ = std::make_unique(*parameters, this->cpu_allocator_); BeamSearchCpuState cpu_state{*parameters, this->cpu_allocator_, @@ -233,11 +228,10 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch } } - constexpr bool use_position = true; BeamSearchState beam_state{*parameters, this->temp_space_allocator_, gpt_subgraph_.has_decoder_masked_attention_, - use_position}; + true /* use_position */}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, @@ -245,8 +239,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch parameters->num_beams, this->ort_stream_); - gsl::span input_ids = expanded_input_ids_in_cpu.Get().DataAsSpan(); - cpu_state.SetExpandedSequence(input_ids); + cpu_state.SetExpandedSequence(expanded_input_ids_in_cpu.Get().DataAsSpan()); #ifdef DEBUG_GENERATION const IConsoleDumper* dumper = this->GetConsoleDumper(); @@ -382,7 +375,7 @@ Status BeamSearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetch final_beam_scores = cpu_state.final_beam_scores; } - this->beam_scorer_->Finalize(&(cpu_state.sequences), + this->beam_scorer_->Finalize(cpu_state.sequences, final_beam_scores, output_sequences, output_sequences_scores); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index f63da7bc34..662424dc58 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -183,18 +183,12 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches // Copy decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam. cpu_state.SetUnexpandedSequence(decoder_input_ids.Get().DataAsSpan()); - onnxruntime::OrtStlAllocator hypothesis_score_allocator(this->cpu_allocator_); - onnxruntime::OrtStlAllocator beam_hyps_allocator(this->cpu_allocator_); - this->beam_scorer_ = std::make_unique(*parameters, - hypothesis_score_allocator, - beam_hyps_allocator, - this->cpu_allocator_); + this->beam_scorer_ = std::make_unique(*parameters, this->cpu_allocator_); - constexpr bool use_position = false; BeamSearchState beam_state{*parameters, this->temp_space_allocator_, decoder_subgraph_.has_decoder_masked_attention_, - use_position}; + false /* use_position */}; init_beam_state_func_(&beam_state, cpu_state.sequence_lengths, @@ -373,7 +367,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches final_beam_scores = cpu_state.final_beam_scores; } - this->beam_scorer_->Finalize(&(cpu_state.sequences), + this->beam_scorer_->Finalize(cpu_state.sequences, final_beam_scores, output_sequences, output_sequences_scores); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 8227db44e7..402de05ebb 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -18,43 +18,44 @@ namespace contrib { namespace transformers { using ::onnxruntime::rnn::detail::Allocate; -BeamHypotheses::BeamHypotheses(int num_beams, - float length_penalty, - bool early_stopping, - onnxruntime::OrtStlAllocator& hypothesis_score_allocator) - : num_beams_(num_beams), - length_penalty_(length_penalty), - early_stopping_(early_stopping), - worst_score_(1e9), - beams_(hypothesis_score_allocator) { +void BeamHypotheses::Init(const IGenerationParameters& parameters, gsl::span beams) { + length_penalty_ = parameters.length_penalty; + early_stopping_ = parameters.early_stopping; + beams_ = beams; + beams_used_ = 0; } void BeamHypotheses::Add(gsl::span& hypothesis, float sum_logprobs) { auto length = hypothesis.size(); float score = sum_logprobs / pow(static_cast(length), length_penalty_); - if (this->Size() < num_beams_ || score > worst_score_) { - HypothesisScore item(hypothesis, score); - beams_.push(item); - if (this->Size() > num_beams_) { - beams_.pop(); - } - worst_score_ = beams_.top().score; - } + size_t index = beams_used_; + // If the array is full, don't add unless it's better than the worst element + if (index == beams_.size()) { + if (score <= beams_[--index].score) + return; + } else + beams_used_++; + + // Rotate existing elements over while the new element scores higher + for (; index > 0 && score > beams_[index - 1].score; index--) + beams_[index] = beams_[index - 1]; + + beams_[index] = HypothesisScore{hypothesis, score}; } -bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) { +bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) const { // If there are enough hypotheses and that none of the hypotheses being generated can become better // than the worst one in the heap, then we are done with this sentence. - if (Size() < num_beams_) + if (static_cast(beams_used_) < beams_.size()) return false; if (early_stopping_) return true; float current_score = best_sum_logprobs / pow(static_cast(current_length), length_penalty_); - return worst_score_ >= current_score; + return beams_.back().score >= current_score; } void BeamHypotheses::Output( @@ -63,63 +64,49 @@ void BeamHypotheses::Output( gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length) gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty { - ORT_ENFORCE(top_k <= Size()); - int remove_count = Size() - top_k; - for (int i = 0; i < remove_count; i++) { - beams_.pop(); - } - - // Since pop get the worst sequence, so output it in the reverse order. - // The first (worst) beam shall be put at the last position among top_k sequences. - int index = top_k - 1; - while (!beams_.empty()) { - auto item = beams_.top(); - gsl::span& source = item.hypothesis; + // Copy the top_k beams into the sequences + ORT_ENFORCE(top_k <= beams_used_); + for (int index = 0; index < top_k; index++) { + auto& item = beams_[index]; gsl::span target = sequences.subspan(static_cast(index) * max_length, max_length); // Note that word_ids might be less than max_length. // Since the sequences has been filled with pad token ID, so padding is not needed here. - gsl::copy(source, target); + gsl::copy(item.hypothesis, target); if (!sequences_scores.empty()) sequences_scores[index] = item.score; - - beams_.pop(); - index--; } } BeamSearchScorer::BeamSearchScorer(const IGenerationParameters& parameters, - onnxruntime::OrtStlAllocator& hypothesis_score_allocator, - onnxruntime::OrtStlAllocator& beam_hyps_allocator, AllocatorPtr& allocator) : batch_size_{static_cast(parameters.batch_size)}, num_beams_{static_cast(parameters.num_beams)}, max_length_{static_cast(parameters.max_length)}, num_beam_hyps_to_keep_{static_cast(parameters.num_return_sequences)}, pad_token_id_{parameters.pad_token_id}, - eos_token_id_{parameters.eos_token_id}, - beam_hyps_(beam_hyps_allocator) { - for (size_t i = 0; i < batch_size_; i++) { - beam_hyps_.push_back(BeamHypotheses(num_beams_, parameters.length_penalty, parameters.early_stopping, hypothesis_score_allocator)); - } - + eos_token_id_{parameters.eos_token_id} { size_t batch_beam_size = batch_size_ * num_beams_; + auto beams = Allocate(allocator, batch_beam_size, hypothesis_scores_ptr_); + beam_hyps_ = Allocate(allocator, batch_size_, beam_hyps_ptr_); + for (size_t i = 0; i < batch_size_; i++) + beam_hyps_[i].Init(parameters, beams.subspan(i * num_beams_, num_beams_)); + done_ = Allocate(allocator, batch_size_, done_ptr_, true /* fill allocated array */, false /* fill with false */); - constexpr bool no_fill = false; // Do not fill values after allocation - next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, no_fill); - next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, no_fill); - next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill); + next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_); + next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_); + next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_); // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. size_t per_beam = (SafeInt(max_length_) * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2; hypothesis_buffer_length_ = batch_beam_size * per_beam; - hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill); + hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_); } -bool BeamSearchScorer::IsDone() { +bool BeamSearchScorer::IsDone() const { for (auto done : done_) { if (!done) return false; @@ -127,7 +114,7 @@ bool BeamSearchScorer::IsDone() { return true; } -void BeamSearchScorer::Process(ISequences* sequences, +void BeamSearchScorer::Process(ISequences& sequences, gsl::span& next_scores, gsl::span& next_tokens, gsl::span& next_indices) { @@ -135,7 +122,7 @@ void BeamSearchScorer::Process(ISequences* sequences, // It contains word ID of whole sequence generated so far. // It is different from subgraph input_ids, which only need one word when past state is not empty. - const int sequence_length = sequences->GetSequenceLength(); + const int sequence_length = sequences.GetSequenceLength(); ORT_ENFORCE(next_scores.size() == next_tokens.size()); ORT_ENFORCE(next_scores.size() == next_indices.size()); @@ -143,7 +130,7 @@ void BeamSearchScorer::Process(ISequences* sequences, for (size_t batch = 0; batch < batch_size_; batch++) { BeamHypotheses& beam_hyp = beam_hyps_[batch]; if (done_[batch]) { - ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast(num_beams_), + ORT_ENFORCE(beam_hyp.Size() == gsl::narrow_cast(num_beams_), "Batch can only be done if all beams have been generated"); // Pad the batch. @@ -172,7 +159,7 @@ void BeamSearchScorer::Process(ISequences* sequences, } // Clone the sequence and append to buffer. - gsl::span src = sequences->GetSequence(batch_beam_idx); + gsl::span src = sequences.GetSequence(batch_beam_idx); auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length); gsl::copy(src, clone); hypothesis_buffer_offset_ += static_cast(sequence_length); @@ -205,11 +192,10 @@ void BeamSearchScorer::Process(ISequences* sequences, } } -void BeamSearchScorer::Finalize(ISequences* sequences, +void BeamSearchScorer::Finalize(ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) { - ORT_ENFORCE(sequences != nullptr); ORT_ENFORCE(output_sequences != nullptr); // Finalize all open beam hypotheses and add to generated hypotheses. @@ -222,7 +208,7 @@ void BeamSearchScorer::Finalize(ISequences* sequences, for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) { size_t batch_beam_index = batch_index * num_beams_ + beam_index; float final_score = final_beam_scores[batch_beam_index]; - auto final_tokens = sequences->GetSequence(batch_beam_index); + auto final_tokens = sequences.GetSequence(batch_beam_index); beam_hyp.Add(final_tokens, final_score); } } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 237328d2db..a60f37a67a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -20,69 +20,50 @@ namespace contrib { namespace transformers { struct HypothesisScore { - HypothesisScore(gsl::span& _hypothesis, float _score) - : hypothesis(_hypothesis), score(_score) {} - gsl::span hypothesis; float score; }; -class HypothesisScoreCompare { - public: - bool operator()(const HypothesisScore& a, const HypothesisScore& b) { - return a.score > b.score; - } -}; - -class BeamHypotheses { - public: - BeamHypotheses(int num_beams, - float length_penalty, - bool early_stopping, - onnxruntime::OrtStlAllocator& hypothesis_score_allocator); +struct BeamHypotheses { + void Init(const IGenerationParameters& parameters, gsl::span beams); // Number of hypotheses - int Size() { return static_cast(beams_.size()); } + int Size() const { return beams_used_; } // Add a new hypothesis void Add(gsl::span& hypothesis, float sum_logprobs); - bool IsDone(float best_sum_logprobs, int current_length); + bool IsDone(float best_sum_logprobs, int current_length) const; - // Output results. Note that it will clear all beams. + // Output results void Output(int top_k, // number of sequences to return int max_length, // max sequence length gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) private: - int num_beams_; float length_penalty_; bool early_stopping_; - float worst_score_; - - // Min-heap for top k - std::priority_queue, HypothesisScoreCompare> beams_; + gsl::span beams_; // Beam width sized array of hypotheses, sorted by highest scoring + int beams_used_; // Number of elements used in beams_ }; class BeamSearchScorer : public IBeamScorer { public: BeamSearchScorer(const IGenerationParameters& parameters, - onnxruntime::OrtStlAllocator& hypothesis_score_allocator, - onnxruntime::OrtStlAllocator& beam_hyps_allocator, AllocatorPtr& allocator); - void Process(ISequences* sequences, + void Process(ISequences& sequences, gsl::span& next_scores, gsl::span& next_tokens, gsl::span& next_indices) override; - void Finalize(ISequences* sequences, + void Finalize(ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) override; - bool IsDone(); + bool IsDone() const; gsl::span& GetNextScores() { return next_beam_scores_; } gsl::span& GetNextTokens() { return next_beam_tokens_; } @@ -113,7 +94,9 @@ class BeamSearchScorer : public IBeamScorer { size_t hypothesis_buffer_length_{}; // Total number of elements size_t hypothesis_buffer_offset_{}; // Offset of available buffer, or length of used buffer. - onnxruntime::FastAllocVector beam_hyps_; + IAllocatorUniquePtr hypothesis_scores_ptr_; // num_beams_ * batch_size_, divided into num_beams_ chunks per BeamHypothesis in beam_hyps_ + IAllocatorUniquePtr beam_hyps_ptr_; + gsl::span beam_hyps_; // batch_size_ count }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index 76630a30df..3e5a795401 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -429,7 +429,7 @@ Status ProcessLogits(const OrtValue& logits, // #endif beam_scorer->Process( - sequences, + *sequences, next_scores, next_tokens, next_indices); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index bb105942bc..93774d0b6b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -93,30 +93,27 @@ struct ISamplingState { gsl::span cumulative_probs; }; -class ISequences { - public: +struct ISequences { virtual ~ISequences() {} virtual gsl::span GetSequence(int beam_index) const = 0; virtual int GetSequenceLength() const = 0; }; -class ILogitsProcessorList { - public: +struct ILogitsProcessorList { virtual ~ILogitsProcessorList() {} virtual void Process(const ISequences* sequences, gsl::span& next_token_scores, int step) = 0; }; // Interface for all scorers for beam search or beam sample. -class IBeamScorer { - public: +struct IBeamScorer { virtual ~IBeamScorer() {} - virtual void Process(ISequences* sequences, + virtual void Process(ISequences& sequences, gsl::span& next_scores, gsl::span& next_tokens, gsl::span& next_indices) = 0; - virtual void Finalize(ISequences* sequences, + virtual void Finalize(ISequences& sequences, gsl::span& final_beam_scores, Tensor* output_sequences, Tensor* output_sequence_scores) = 0; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 2e4bdc5cdf..efe5dbe497 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -586,7 +586,7 @@ Status ProcessLogits(const OrtValue& logits, // // TODO: Implement BeamScorer on CUDA beam_scorer->Process( - sequences, + *sequences, next_scores, next_tokens, next_indices);