BeamScorer to use contiguous arrays for BeamHypotheses (#15923)

### Description
Change BeamHypotheses to not use a stl::priority_queue and instead all
BeamHypotheses use a single buffer that they each get a small slice of.

As the beam count is really small (typically 4,8, max of 32) and the
array size fixed, the BeamHypotheses just does a sorted insert into an
array.

This also allows for the BeamHypotheses inside of the BeamSearchScorer
to be a single fixed allocation vs an onnxruntime::FastAllocVector.

### Motivation and Context
The goal is to simplify the memory usage and make the code more easily
ported to CUDA.
This commit is contained in:
Ryan Hill 2023-05-13 14:17:45 -07:00 committed by GitHub
parent 896a963492
commit 310273cbe4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 70 additions and 117 deletions

View file

@ -202,12 +202,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
std::vector<OrtValue> fetches;
// Initialize resources
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(this->cpu_allocator_);
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(this->cpu_allocator_);
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(*parameters,
hypothesis_score_allocator,
beam_hyps_allocator,
this->cpu_allocator_);
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(*parameters, this->cpu_allocator_);
BeamSearchCpuState cpu_state{*parameters,
this->cpu_allocator_,
@ -233,11 +228,10 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
}
}
constexpr bool use_position = true;
BeamSearchState<T> beam_state{*parameters,
this->temp_space_allocator_,
gpt_subgraph_.has_decoder_masked_attention_,
use_position};
true /* use_position */};
init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
@ -245,8 +239,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
parameters->num_beams,
this->ort_stream_);
gsl::span<const int32_t> input_ids = expanded_input_ids_in_cpu.Get<Tensor>().DataAsSpan<int32_t>();
cpu_state.SetExpandedSequence(input_ids);
cpu_state.SetExpandedSequence(expanded_input_ids_in_cpu.Get<Tensor>().DataAsSpan<int32_t>());
#ifdef DEBUG_GENERATION
const IConsoleDumper* dumper = this->GetConsoleDumper();
@ -382,7 +375,7 @@ Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetch
final_beam_scores = cpu_state.final_beam_scores;
}
this->beam_scorer_->Finalize(&(cpu_state.sequences),
this->beam_scorer_->Finalize(cpu_state.sequences,
final_beam_scores,
output_sequences,
output_sequences_scores);

View file

@ -183,18 +183,12 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
// Copy decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
cpu_state.SetUnexpandedSequence(decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>());
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(this->cpu_allocator_);
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(this->cpu_allocator_);
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(*parameters,
hypothesis_score_allocator,
beam_hyps_allocator,
this->cpu_allocator_);
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(*parameters, this->cpu_allocator_);
constexpr bool use_position = false;
BeamSearchState<T> beam_state{*parameters,
this->temp_space_allocator_,
decoder_subgraph_.has_decoder_masked_attention_,
use_position};
false /* use_position */};
init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
@ -373,7 +367,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
final_beam_scores = cpu_state.final_beam_scores;
}
this->beam_scorer_->Finalize(&(cpu_state.sequences),
this->beam_scorer_->Finalize(cpu_state.sequences,
final_beam_scores,
output_sequences,
output_sequences_scores);

View file

@ -18,43 +18,44 @@ namespace contrib {
namespace transformers {
using ::onnxruntime::rnn::detail::Allocate;
BeamHypotheses::BeamHypotheses(int num_beams,
float length_penalty,
bool early_stopping,
onnxruntime::OrtStlAllocator<HypothesisScore>& hypothesis_score_allocator)
: num_beams_(num_beams),
length_penalty_(length_penalty),
early_stopping_(early_stopping),
worst_score_(1e9),
beams_(hypothesis_score_allocator) {
void BeamHypotheses::Init(const IGenerationParameters& parameters, gsl::span<HypothesisScore> beams) {
length_penalty_ = parameters.length_penalty;
early_stopping_ = parameters.early_stopping;
beams_ = beams;
beams_used_ = 0;
}
void BeamHypotheses::Add(gsl::span<const int32_t>& hypothesis, float sum_logprobs) {
auto length = hypothesis.size();
float score = sum_logprobs / pow(static_cast<float>(length), length_penalty_);
if (this->Size() < num_beams_ || score > worst_score_) {
HypothesisScore item(hypothesis, score);
beams_.push(item);
if (this->Size() > num_beams_) {
beams_.pop();
}
worst_score_ = beams_.top().score;
}
size_t index = beams_used_;
// If the array is full, don't add unless it's better than the worst element
if (index == beams_.size()) {
if (score <= beams_[--index].score)
return;
} else
beams_used_++;
// Rotate existing elements over while the new element scores higher
for (; index > 0 && score > beams_[index - 1].score; index--)
beams_[index] = beams_[index - 1];
beams_[index] = HypothesisScore{hypothesis, score};
}
bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) {
bool BeamHypotheses::IsDone(float best_sum_logprobs, int current_length) const {
// If there are enough hypotheses and that none of the hypotheses being generated can become better
// than the worst one in the heap, then we are done with this sentence.
if (Size() < num_beams_)
if (static_cast<size_t>(beams_used_) < beams_.size())
return false;
if (early_stopping_)
return true;
float current_score = best_sum_logprobs / pow(static_cast<float>(current_length), length_penalty_);
return worst_score_ >= current_score;
return beams_.back().score >= current_score;
}
void BeamHypotheses::Output(
@ -63,63 +64,49 @@ void BeamHypotheses::Output(
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
gsl::span<float>& sequences_scores) // buffer of shape (num_return_sequences) or empty
{
ORT_ENFORCE(top_k <= Size());
int remove_count = Size() - top_k;
for (int i = 0; i < remove_count; i++) {
beams_.pop();
}
// Since pop get the worst sequence, so output it in the reverse order.
// The first (worst) beam shall be put at the last position among top_k sequences.
int index = top_k - 1;
while (!beams_.empty()) {
auto item = beams_.top();
gsl::span<const int32_t>& source = item.hypothesis;
// Copy the top_k beams into the sequences
ORT_ENFORCE(top_k <= beams_used_);
for (int index = 0; index < top_k; index++) {
auto& item = beams_[index];
gsl::span<int32_t> target = sequences.subspan(static_cast<gsl::index>(index) * max_length, max_length);
// Note that word_ids might be less than max_length.
// Since the sequences has been filled with pad token ID, so padding is not needed here.
gsl::copy(source, target);
gsl::copy(item.hypothesis, target);
if (!sequences_scores.empty())
sequences_scores[index] = item.score;
beams_.pop();
index--;
}
}
BeamSearchScorer::BeamSearchScorer(const IGenerationParameters& parameters,
onnxruntime::OrtStlAllocator<HypothesisScore>& hypothesis_score_allocator,
onnxruntime::OrtStlAllocator<BeamHypotheses>& beam_hyps_allocator,
AllocatorPtr& allocator)
: batch_size_{static_cast<size_t>(parameters.batch_size)},
num_beams_{static_cast<size_t>(parameters.num_beams)},
max_length_{static_cast<size_t>(parameters.max_length)},
num_beam_hyps_to_keep_{static_cast<size_t>(parameters.num_return_sequences)},
pad_token_id_{parameters.pad_token_id},
eos_token_id_{parameters.eos_token_id},
beam_hyps_(beam_hyps_allocator) {
for (size_t i = 0; i < batch_size_; i++) {
beam_hyps_.push_back(BeamHypotheses(num_beams_, parameters.length_penalty, parameters.early_stopping, hypothesis_score_allocator));
}
eos_token_id_{parameters.eos_token_id} {
size_t batch_beam_size = batch_size_ * num_beams_;
auto beams = Allocate<HypothesisScore>(allocator, batch_beam_size, hypothesis_scores_ptr_);
beam_hyps_ = Allocate<BeamHypotheses>(allocator, batch_size_, beam_hyps_ptr_);
for (size_t i = 0; i < batch_size_; i++)
beam_hyps_[i].Init(parameters, beams.subspan(i * num_beams_, num_beams_));
done_ = Allocate<bool>(allocator, batch_size_, done_ptr_, true /* fill allocated array */, false /* fill with false */);
constexpr bool no_fill = false; // Do not fill values after allocation
next_beam_scores_ = Allocate<float>(allocator, batch_beam_size, next_beam_scores_ptr_, no_fill);
next_beam_tokens_ = Allocate<int32_t>(allocator, batch_beam_size, next_beam_tokens_ptr_, no_fill);
next_beam_indices_ = Allocate<int32_t>(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill);
next_beam_scores_ = Allocate<float>(allocator, batch_beam_size, next_beam_scores_ptr_);
next_beam_tokens_ = Allocate<int32_t>(allocator, batch_beam_size, next_beam_tokens_ptr_);
next_beam_indices_ = Allocate<int32_t>(allocator, batch_beam_size, next_beam_indices_ptr_);
// Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
size_t per_beam = (SafeInt<size_t>(max_length_) * (max_length_ + 1) - (parameters.sequence_length - 1) * parameters.sequence_length) / 2;
hypothesis_buffer_length_ = batch_beam_size * per_beam;
hypothesis_buffer_ = Allocate<int32_t>(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill);
hypothesis_buffer_ = Allocate<int32_t>(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_);
}
bool BeamSearchScorer::IsDone() {
bool BeamSearchScorer::IsDone() const {
for (auto done : done_) {
if (!done)
return false;
@ -127,7 +114,7 @@ bool BeamSearchScorer::IsDone() {
return true;
}
void BeamSearchScorer::Process(ISequences* sequences,
void BeamSearchScorer::Process(ISequences& sequences,
gsl::span<const float>& next_scores,
gsl::span<const int32_t>& next_tokens,
gsl::span<const int32_t>& next_indices) {
@ -135,7 +122,7 @@ void BeamSearchScorer::Process(ISequences* sequences,
// It contains word ID of whole sequence generated so far.
// It is different from subgraph input_ids, which only need one word when past state is not empty.
const int sequence_length = sequences->GetSequenceLength();
const int sequence_length = sequences.GetSequenceLength();
ORT_ENFORCE(next_scores.size() == next_tokens.size());
ORT_ENFORCE(next_scores.size() == next_indices.size());
@ -143,7 +130,7 @@ void BeamSearchScorer::Process(ISequences* sequences,
for (size_t batch = 0; batch < batch_size_; batch++) {
BeamHypotheses& beam_hyp = beam_hyps_[batch];
if (done_[batch]) {
ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast<int>(num_beams_),
ORT_ENFORCE(beam_hyp.Size() == gsl::narrow_cast<int>(num_beams_),
"Batch can only be done if all beams have been generated");
// Pad the batch.
@ -172,7 +159,7 @@ void BeamSearchScorer::Process(ISequences* sequences,
}
// Clone the sequence and append to buffer.
gsl::span<const int32_t> src = sequences->GetSequence(batch_beam_idx);
gsl::span<const int32_t> src = sequences.GetSequence(batch_beam_idx);
auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length);
gsl::copy(src, clone);
hypothesis_buffer_offset_ += static_cast<size_t>(sequence_length);
@ -205,11 +192,10 @@ void BeamSearchScorer::Process(ISequences* sequences,
}
}
void BeamSearchScorer::Finalize(ISequences* sequences,
void BeamSearchScorer::Finalize(ISequences& sequences,
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) {
ORT_ENFORCE(sequences != nullptr);
ORT_ENFORCE(output_sequences != nullptr);
// Finalize all open beam hypotheses and add to generated hypotheses.
@ -222,7 +208,7 @@ void BeamSearchScorer::Finalize(ISequences* sequences,
for (size_t beam_index = 0; beam_index < num_beams_; beam_index++) {
size_t batch_beam_index = batch_index * num_beams_ + beam_index;
float final_score = final_beam_scores[batch_beam_index];
auto final_tokens = sequences->GetSequence(batch_beam_index);
auto final_tokens = sequences.GetSequence(batch_beam_index);
beam_hyp.Add(final_tokens, final_score);
}
}

View file

@ -20,69 +20,50 @@ namespace contrib {
namespace transformers {
struct HypothesisScore {
HypothesisScore(gsl::span<const int32_t>& _hypothesis, float _score)
: hypothesis(_hypothesis), score(_score) {}
gsl::span<const int32_t> hypothesis;
float score;
};
class HypothesisScoreCompare {
public:
bool operator()(const HypothesisScore& a, const HypothesisScore& b) {
return a.score > b.score;
}
};
class BeamHypotheses {
public:
BeamHypotheses(int num_beams,
float length_penalty,
bool early_stopping,
onnxruntime::OrtStlAllocator<HypothesisScore>& hypothesis_score_allocator);
struct BeamHypotheses {
void Init(const IGenerationParameters& parameters, gsl::span<HypothesisScore> beams);
// Number of hypotheses
int Size() { return static_cast<int>(beams_.size()); }
int Size() const { return beams_used_; }
// Add a new hypothesis
void Add(gsl::span<const int32_t>& hypothesis, float sum_logprobs);
bool IsDone(float best_sum_logprobs, int current_length);
bool IsDone(float best_sum_logprobs, int current_length) const;
// Output results. Note that it will clear all beams.
// Output results
void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
gsl::span<float>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
private:
int num_beams_;
float length_penalty_;
bool early_stopping_;
float worst_score_;
// Min-heap for top k
std::priority_queue<HypothesisScore, onnxruntime::FastAllocVector<HypothesisScore>, HypothesisScoreCompare> beams_;
gsl::span<HypothesisScore> beams_; // Beam width sized array of hypotheses, sorted by highest scoring
int beams_used_; // Number of elements used in beams_
};
class BeamSearchScorer : public IBeamScorer {
public:
BeamSearchScorer(const IGenerationParameters& parameters,
onnxruntime::OrtStlAllocator<HypothesisScore>& hypothesis_score_allocator,
onnxruntime::OrtStlAllocator<BeamHypotheses>& beam_hyps_allocator,
AllocatorPtr& allocator);
void Process(ISequences* sequences,
void Process(ISequences& sequences,
gsl::span<const float>& next_scores,
gsl::span<const int32_t>& next_tokens,
gsl::span<const int32_t>& next_indices) override;
void Finalize(ISequences* sequences,
void Finalize(ISequences& sequences,
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) override;
bool IsDone();
bool IsDone() const;
gsl::span<float>& GetNextScores() { return next_beam_scores_; }
gsl::span<int32_t>& GetNextTokens() { return next_beam_tokens_; }
@ -113,7 +94,9 @@ class BeamSearchScorer : public IBeamScorer {
size_t hypothesis_buffer_length_{}; // Total number of elements
size_t hypothesis_buffer_offset_{}; // Offset of available buffer, or length of used buffer.
onnxruntime::FastAllocVector<BeamHypotheses> beam_hyps_;
IAllocatorUniquePtr<HypothesisScore> hypothesis_scores_ptr_; // num_beams_ * batch_size_, divided into num_beams_ chunks per BeamHypothesis in beam_hyps_
IAllocatorUniquePtr<BeamHypotheses> beam_hyps_ptr_;
gsl::span<BeamHypotheses> beam_hyps_; // batch_size_ count
};
} // namespace transformers

View file

@ -429,7 +429,7 @@ Status ProcessLogits(const OrtValue& logits, //
#endif
beam_scorer->Process(
sequences,
*sequences,
next_scores,
next_tokens,
next_indices);

View file

@ -93,30 +93,27 @@ struct ISamplingState {
gsl::span<T> cumulative_probs;
};
class ISequences {
public:
struct ISequences {
virtual ~ISequences() {}
virtual gsl::span<const int32_t> GetSequence(int beam_index) const = 0;
virtual int GetSequenceLength() const = 0;
};
class ILogitsProcessorList {
public:
struct ILogitsProcessorList {
virtual ~ILogitsProcessorList() {}
virtual void Process(const ISequences* sequences, gsl::span<float>& next_token_scores, int step) = 0;
};
// Interface for all scorers for beam search or beam sample.
class IBeamScorer {
public:
struct IBeamScorer {
virtual ~IBeamScorer() {}
virtual void Process(ISequences* sequences,
virtual void Process(ISequences& sequences,
gsl::span<const float>& next_scores,
gsl::span<const int32_t>& next_tokens,
gsl::span<const int32_t>& next_indices) = 0;
virtual void Finalize(ISequences* sequences,
virtual void Finalize(ISequences& sequences,
gsl::span<const float>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;

View file

@ -586,7 +586,7 @@ Status ProcessLogits(const OrtValue& logits, //
// TODO: Implement BeamScorer on CUDA
beam_scorer->Process(
sequences,
*sequences,
next_scores,
next_tokens,
next_indices);