diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 3a0171a064..a982999f0a 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -5,6 +5,7 @@ Do not modify directly.*
* com.microsoft
* com.microsoft.Attention
* com.microsoft.AttnLSTM
+ * com.microsoft.BeamSearch
* com.microsoft.BiasDropout
* com.microsoft.BiasGelu
* com.microsoft.BiasSoftmax
@@ -337,6 +338,75 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.BeamSearch**
+
+ Beam Search for text generation. Supports GPT-2 decoder.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- body : graph (required)
+- The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output
+- early_stopping : int
+- early stop or not
+- eos_token_id : int (required)
+- The id of the end-of-sequence token
+- no_repeat_ngram_size : int
+- no repeat ngrams size
+- pad_token_id : int (required)
+- The id of the padding token
+
+
+#### Inputs (6 - 9)
+
+
+- input_ids : I
+- The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)
+- max_length : I
+- The maximum length of the sequence to be generated. Shape is (1)
+- min_length (optional) : I
+- The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+- num_beams : I
+- Number of beams for beam search. 1 means no beam search. Shape is (1)
+- num_return_sequences : I
+- The number of returned sequences in the batch. Shape is (1)
+- temperature : T
+- The value used to module the next token probabilities. Accepts value > 0.0. Shape is (1)
+- length_penalty (optional) : T
+- Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)
+- repetition_penalty (optional) : T
+- The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+- vocab_mask (optional) : M
+- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
+
+#### Outputs (1 - 3)
+
+
+- sequences : I
+- Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)
+- sequences_scores (optional) : T
+- Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)
+- scores (optional) : T
+- Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)
+
+
+#### Type Constraints
+
+
+- T : tensor(float), tensor(float16)
+- Constrain input and output types to float tensors.
+- I : tensor(int32)
+- Constrain to integer types
+- M : tensor(int32)
+- Constrain mask to integer types
+
+
+
### **com.microsoft.BiasDropout**
output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models.
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 052367130e..6f09fb5026 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -377,6 +377,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**
*in* weight:**T**
*in* bias:**T**
*in* mask_index:**M**
*in* past:**T**
*in* extra_add:**T**
*out* output:**T**
*out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**
*in* W:**T**
*in* R:**T**
*in* B:**T**
*in* sequence_lens:**T1**
*in* initial_h:**T**
*in* initial_c:**T**
*in* P:**T**
*in* QW:**T**
*in* MW:**T**
*in* V:**T**
*in* M:**T**
*in* memory_seq_lens:**T1**
*in* AW:**T**
*out* Y:**T**
*out* Y_h:**T**
*out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)
**T1** = tensor(int32)|
+|BeamSearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* num_beams:**I**
*in* num_return_sequences:**I**
*in* temperature:**T**
*in* length_penalty:**T**
*in* repetition_penalty:**T**
*in* vocab_mask:**M**
*out* sequences:**I**
*out* sequences_scores:**T**
*out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**
*in* cur_tokens:**T**
*in* prev_suffix_match_idx:**T**
*in* pred_tokens:**T**
*out* tokens:**T**
*out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**
*in* B:**T**
*out* C:**T**|1+|**T** = tensor(double), tensor(float)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 77669f74ea..1f0ee6d17e 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -22,6 +22,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
@@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
// add more kernels here
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
new file mode 100644
index 0000000000..1225442192
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -0,0 +1,650 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// there's no way to use a raw pointer as the copy destination with std::copy_n
+// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
+// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+
+#include
+#include "core/providers/cpu/controlflow/utils.h"
+#include "core/providers/cpu/math/top_k.h"
+#include "core/framework/allocator.h"
+#include "core/framework/framework_common.h"
+#include "core/framework/op_kernel_context_internal.h"
+#include "core/framework/session_state.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/utils.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/framework/session_options.h"
+#include "core/framework/TensorSeq.h"
+#include "gsl/gsl"
+#include "core/providers/cpu/math/softmax_shared.h"
+#include "beam_search.h"
+#include "logits_processor.h"
+#include "sequences.h"
+#include "dump_tensor.h"
+
+#ifdef _MSC_VER
+#pragma warning(pop)
+#endif
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ BeamSearch, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ transformers::BeamSearch);
+
+REGISTER_KERNEL_TYPED(float)
+
+namespace transformers {
+
+template
+struct BeamSearchState {
+ gsl::span beam_scores; // shape (batch_size, num_beams)
+ gsl::span next_token_logits; // shape (batch_size * num_beams, vocab_size)
+ gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size)
+ gsl::span next_tokens; // shape (batch_size, 2 * num_beams)
+ gsl::span next_indices; // shape (batch_size, 2 * num_beams)
+ gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
+
+ gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
+ gsl::span remaining_scores; // subspan that is avaiable for appending next token scores.
+
+ Sequences sequences;
+
+ void Init(AllocatorPtr allocator,
+ int batch_size,
+ int num_beams,
+ int vocab_size,
+ int sequence_length,
+ int max_length,
+ bool output_scores) {
+ size_t batch_beam_size = SafeInt(batch_size) * num_beams;
+ beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast(0));
+
+ // Initialize score of first beam of each group with 0 and the rest with -1e9.
+ // This ensures that the beams in the same group don't produce same tokens every time.
+ for (int i = 0; i < batch_size; i++) {
+ for (int j = 1; j < num_beams; j++) {
+ beam_scores[i * num_beams + j] = -1e9;
+ }
+ }
+
+ size_t next_token_size = SafeInt(batch_beam_size) * vocab_size;
+ next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size, true, static_cast(0));
+ next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size, true, static_cast(0));
+
+ next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0));
+
+ next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size, true, static_cast(0));
+
+ next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size, true, static_cast(0));
+
+ if (output_scores) {
+ size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size;
+ scores = AllocateBuffer(allocator, scores_buffer_, elements);
+ remaining_scores = scores;
+ }
+
+ // sequences will be initialized later since it has dependency on input_ids
+ }
+
+ private:
+ BufferUniquePtr beam_scores_buffer_;
+ BufferUniquePtr next_token_logits_buffer_;
+ BufferUniquePtr next_token_scores_buffer_;
+ BufferUniquePtr next_tokens_buffer_;
+ BufferUniquePtr next_indices_buffer_;
+ BufferUniquePtr next_positions_buffer_;
+ BufferUniquePtr scores_buffer_;
+};
+
+template
+class BeamSearchImpl {
+ public:
+ BeamSearchImpl(OpKernelContextInternal& context,
+ const SessionState& session_state,
+ GptSubgraph& gpt_subgraph,
+ concurrency::ThreadPool* thread_pool,
+ void* stream,
+ BeamSearchParameters& params);
+
+ // Initialize by validating all the inputs, and allocating the output tensors.
+ Status Initialize();
+
+ // Execute beam search in iterations util stopping criteria is reached.
+ // In each iteration, GPT subgraph is called, and next token for each sequence is generated.
+ Status Execute(const FeedsFetchesManager& cached_ffm);
+
+ private:
+ // Validate inputs.
+ Status CheckInputs(const OpKernelContextInternal& context);
+
+ // Prepare the inputs for first inference of subgraph
+ void CreateInitialFeeds(gsl::span& next_positions, std::vector& feeds);
+
+ // Update the input for next iteration.
+ Status UpdateFeeds(
+ const std::vector& last_outputs,
+ std::vector& next_inputs,
+ int current_length,
+ gsl::span& next_positions,
+ gsl::span beam_next_tokens,
+ gsl::span beam_indices);
+
+ // Process logits and append next tokens to sequences.
+ Status GenerateNextToken(const OrtValue& logits,
+ gsl::span& beam_next_tokens,
+ gsl::span& beam_indices,
+ BeamSearchState& beam_state);
+
+ // Calculate scores from logits, then apply filtering and select next token for each beam.
+ Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
+ BeamSearchState& beam_state,
+ AllocatorPtr& allocator);
+
+ OpKernelContextInternal& context_;
+
+ const SessionState& session_state_;
+
+ GptSubgraph& gpt_subgraph_;
+
+ concurrency::ThreadPool* thread_pool_;
+
+ const std::vector& implicit_inputs_;
+
+ // Not used in CPU. Stream is for CUDA only.
+ void* stream_;
+
+ BeamSearchParameters* parameters_;
+
+ LogitsProcessorList logits_processors_;
+
+ std::unique_ptr> beam_scorer_;
+
+ AllocatorPtr allocator_;
+};
+
+template
+void BeamSearch::Init(const OpKernelInfo& info) {
+ // Make sure the body attribute was present even though we don't need it here.
+ ONNX_NAMESPACE::GraphProto proto;
+ ORT_ENFORCE(info.GetAttr("body", &proto).IsOK());
+ ORT_IGNORE_RETURN_VALUE(proto);
+
+ parameters_.ParseFromAttributes(info);
+
+ stream_ = nullptr;
+}
+
+template
+std::unique_ptr BeamSearch::Create(const OpKernelInfo& info,
+ void* stream) {
+ auto result = std::make_unique(info);
+ result->SetComputeStream(stream);
+ return result;
+}
+
+template
+common::Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
+ const std::string& attribute_name,
+ const SessionState& subgraph_session_state) {
+ ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+ const auto& node = Node();
+ gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
+ feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
+ parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
+ gpt_subgraph_->num_heads,
+ gpt_subgraph_->head_size,
+ gpt_subgraph_->num_layers);
+ return Status::OK();
+}
+
+template
+Status BeamSearch::Compute(OpKernelContext* ctx) const {
+ auto* ctx_internal = static_cast(ctx);
+ auto* session_state = ctx_internal->SubgraphSessionState("body");
+ ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");
+ ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
+
+ concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
+
+ BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later
+
+ BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, stream_, parameters};
+
+ auto status = impl.Initialize();
+ ORT_RETURN_IF_ERROR(status);
+
+ status = impl.Execute(*feeds_fetches_manager_);
+
+ return status;
+}
+
+template
+BeamSearchImpl::BeamSearchImpl(OpKernelContextInternal& context,
+ const SessionState& session_state,
+ GptSubgraph& gpt_subgraph,
+ concurrency::ThreadPool* thread_pool,
+ void* stream,
+ BeamSearchParameters& params)
+ : context_(context),
+ session_state_(session_state),
+ gpt_subgraph_(gpt_subgraph),
+ thread_pool_(thread_pool),
+ implicit_inputs_(context_.GetImplicitInputs()),
+ stream_(stream),
+ parameters_(¶ms),
+ allocator_(nullptr) {
+ parameters_->ParseFromInputs(&context);
+
+ allocator_ = session_state.GetExecutionProviders()
+ .Get(onnxruntime::kCpuExecutionProvider)
+ ->GetAllocator(0, OrtMemTypeDefault);
+}
+
+template
+Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) {
+ // Input shapes:
+ // input_ids : (batch_size, sequence_length)
+ // vocab_mask : (vocab_size) or nullptr
+
+ const Tensor* input_ids = context.Input(0);
+ const auto& dims = input_ids->Shape().GetDims();
+ if (dims.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ",
+ dims.size());
+ }
+
+ const Tensor* vocab_mask = context.Input(8);
+ if (vocab_mask != nullptr) { // vocab_mask is optional
+ const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
+ if (vocab_mask_dims.size() != 1) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ",
+ vocab_mask_dims.size());
+ }
+
+ // There is dependency on vocab_size parameter, which shall be set before calling this function.
+ if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ",
+ vocab_mask_dims[0]);
+ }
+
+ // store vocab mask in parameters.
+ parameters_->vocab_mask = vocab_mask->DataAsSpan();
+ }
+
+ return Status::OK();
+}
+
+template
+Status BeamSearchImpl::Initialize() {
+ auto status = Status::OK();
+
+#define CHECK_SCALAR_INPUT(name, index, required) \
+ auto* name##_tensor = context_.Input(index); \
+ if (name##_tensor) { \
+ if (!name##_tensor->Shape().IsScalar()) { \
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
+ name##_tensor->Shape()); \
+ } \
+ } else if (required) { \
+ return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
+ }
+
+ CHECK_SCALAR_INPUT(min_length, 1, false);
+
+ CHECK_SCALAR_INPUT(max_length, 2, true);
+
+ CHECK_SCALAR_INPUT(num_beams, 3, true);
+
+ CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
+
+ CHECK_SCALAR_INPUT(temperature, 5, true);
+
+ CHECK_SCALAR_INPUT(length_penalty, 6, true);
+
+ ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'.");
+
+ ORT_RETURN_IF_ERROR(CheckInputs(context_));
+
+ // This flag will be updated later when the scores output exists.
+ parameters_->output_scores = false;
+
+ // Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready.
+ logits_processors_.Init(*parameters_);
+
+ return status;
+}
+
+template
+void BeamSearchImpl::CreateInitialFeeds(gsl::span& next_positions, std::vector& feeds) {
+ const OrtValue* input_ids_value = context_.GetInputOrtValue(0);
+ const Tensor& input_ids = input_ids_value->Get();
+ gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, next_positions, feeds);
+}
+
+template
+Status BeamSearchImpl::ProcessLogits(
+ const OrtValue& logits,
+ BeamSearchState& beam_state,
+ AllocatorPtr& allocator) {
+ const int64_t batch_beam_size = static_cast(parameters_->BatchBeamSize());
+ const int& vocab_size = parameters_->vocab_size;
+
+ const T* logits_data = logits.Get().Data();
+
+ // Logits has shape (batch_size * num_beams, input_length, vocab_size),
+ // where input_length equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls.
+ const TensorShape& logits_shape = logits.Get().Shape();
+ ORT_ENFORCE(logits_shape.NumDimensions() == 3);
+ auto input_length = logits_shape[1];
+
+ // Get logits for the last token:
+ // next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
+ // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
+ gsl::span& next_token_logits = beam_state.next_token_logits;
+ if (input_length > 1) {
+ const T* current_logits = logits_data + (input_length - 1) * vocab_size;
+ for (int i = 0; i < batch_beam_size; i++) {
+ gsl::span source(current_logits, vocab_size);
+ gsl::span target = next_token_logits.subspan(i * vocab_size, vocab_size);
+ gsl::copy(source, target);
+ current_logits += input_length * vocab_size;
+ }
+ }
+
+#ifdef DEBUG_BEAM_SEARCH
+ //DumpOrtValue("logits", logits);
+ DumpTensor("next_token_logits", next_token_logits.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
+#endif
+
+ // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
+ gsl::span& next_token_scores = beam_state.next_token_scores;
+ Status status = SoftmaxCPU(batch_beam_size, // rows
+ vocab_size, // elements per row
+ input_length > 1 ? next_token_logits.data() : logits_data,
+ next_token_scores.data(),
+ true,
+ thread_pool_);
+ if (!status.IsOK()) {
+ return status;
+ }
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpTensor("next_token_scores after softmax", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
+#endif
+
+ // Apply all score processors that updates scores
+ logits_processors_.Process(&(beam_state.sequences), next_token_scores);
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
+#endif
+
+ // Add beam score to next token scores. Corresponding python code is like:
+ // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
+ // TODO: use thread pool to parrellel
+ int offset = 0;
+ int batch_beam_index = 0;
+ for (int i = 0; i < parameters_->batch_size; i++) {
+ for (int j = 0; j < parameters_->num_beams; j++, batch_beam_index++) {
+ for (int k = 0; k < parameters_->vocab_size; k++, offset++) {
+ next_token_scores[offset] += beam_state.beam_scores[batch_beam_index];
+ }
+ }
+ }
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpTensor("next_token_scores after adding beam_scores", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
+#endif
+
+ if (parameters_->output_scores) {
+ // Append next token scores to the scores output.
+ gsl::copy(next_token_scores, beam_state.remaining_scores);
+ beam_state.remaining_scores = beam_state.remaining_scores.subspan(next_token_scores.size());
+ }
+
+ // Apply top-k selection like the following:
+ // next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
+ // next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
+ int64_t next_token_scores_dims[] = {parameters_->batch_size, parameters_->num_beams * vocab_size};
+ TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
+ auto element_type = DataTypeImpl::GetType();
+ OrtValue next_token_scores_value;
+ Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
+ const Tensor& input = next_token_scores_value.Get();
+
+ const int axis = 1;
+ const unsigned top_k = static_cast(2 * parameters_->num_beams);
+ const bool largest = true;
+ const bool sorted = true; // results returned in sorted order.
+
+ std::unique_ptr topk_scores;
+ std::unique_ptr topk_indices;
+ status = GetTopK(&input, axis, top_k, largest, sorted, allocator, thread_pool_, topk_scores, topk_indices);
+ if (!status.IsOK()) {
+ return status;
+ }
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpTensor("topk_scores", *(topk_scores.get()));
+ DumpTensor("topk_indices", *(topk_indices.get()));
+#endif
+
+ // Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
+ // next_indices = (next_tokens / vocab_size).long()
+ // next_tokens = next_tokens % vocab_size
+ gsl::span next_token_indices = topk_indices->DataAsSpan();
+ offset = 0;
+ for (int i = 0; i < parameters_->batch_size; i++) {
+ for (unsigned int j = 0; j < top_k; j++, offset++) {
+ beam_state.next_indices[offset] = next_token_indices[offset] / vocab_size;
+ beam_state.next_tokens[offset] = next_token_indices[offset] % vocab_size;
+ }
+ }
+
+ gsl::span next_scores = topk_scores->DataAsSpan();
+ gsl::span next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size());
+ gsl::span next_indices(beam_state.next_indices.data(), beam_state.next_indices.size());
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpTensor("next_scores before scorer", next_scores.data(), parameters_->batch_size, top_k);
+ DumpTensor("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k);
+ DumpTensor("next_indices before scorer", next_indices.data(), parameters_->batch_size, top_k);
+#endif
+
+ beam_scorer_->Process(
+ &(beam_state.sequences),
+ next_scores,
+ next_tokens,
+ next_indices);
+
+ return Status::OK();
+}
+
+template
+Status BeamSearchImpl::GenerateNextToken(
+ const OrtValue& logits,
+ gsl::span& beam_next_tokens,
+ gsl::span& beam_indices,
+ BeamSearchState& beam_state) {
+ // Process logits to get next token scores
+ ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_));
+
+ gsl::span& beam_scores = beam_scorer_->GetNextScores();
+ // It is optional to clone beam_scores. Change it to use same buffer also works:
+ // beam_state.beam_scores = beam_scores
+ // Here we make a copy to reduce the coupling with little cost (the buffer size is small).
+ gsl::copy(beam_scores, beam_state.beam_scores);
+
+ beam_next_tokens = beam_scorer_->GetNextTokens();
+ beam_indices = beam_scorer_->GetNextIndices();
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpTensor("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
+ DumpTensor("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
+ DumpTensor("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
+#endif
+
+ beam_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
+
+#ifdef DEBUG_BEAM_SEARCH
+ beam_state.sequences.PrintSequences();
+#endif
+ return Status::OK();
+}
+
+template
+Status BeamSearchImpl::UpdateFeeds(
+ const std::vector& last_outputs,
+ std::vector& next_inputs,
+ int current_length,
+ gsl::span& next_positions,
+ gsl::span beam_next_tokens,
+ gsl::span beam_indices) {
+ return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, next_positions,
+ beam_next_tokens, beam_indices, parameters_->num_beams);
+}
+
+template
+Status BeamSearchImpl::Execute(const FeedsFetchesManager& ffm) {
+ auto status = Status::OK();
+
+ std::vector sequences_dims{parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length};
+ TensorShape sequences_shape(sequences_dims);
+ Tensor* output_sequences = context_.Output(0, sequences_shape);
+
+ std::vector sequences_scores_dims{parameters_->batch_size, parameters_->num_return_sequences};
+ TensorShape sequences_scores_shape(sequences_scores_dims);
+ Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape);
+
+ std::vector scores_dims{
+ parameters_->max_length - parameters_->sequence_length,
+ parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size};
+ TensorShape scores_shape(scores_dims);
+ Tensor* output_scores = context_.Output(2, scores_shape);
+
+ // Update the flag to indicate whether scores exists in output
+ parameters_->output_scores = (output_scores != nullptr);
+
+ std::vector feeds;
+ std::vector fetches;
+
+ // Initialize resources
+ AllocatorPtr temp_space_allocator;
+ ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator));
+
+ BeamSearchState beam_state;
+ beam_state.Init(temp_space_allocator,
+ parameters_->batch_size,
+ parameters_->num_beams,
+ parameters_->vocab_size,
+ parameters_->sequence_length,
+ parameters_->max_length,
+ parameters_->output_scores);
+
+ beam_scorer_ = std::make_unique>(parameters_->batch_size,
+ parameters_->num_beams,
+ parameters_->max_length,
+ parameters_->length_penalty,
+ parameters_->early_stopping,
+ parameters_->num_return_sequences,
+ parameters_->pad_token_id,
+ parameters_->eos_token_id);
+ beam_scorer_->Initialize(allocator_, parameters_->sequence_length); // TODO: use temp_space_allocator
+
+ CreateInitialFeeds(beam_state.next_positions, feeds);
+ const OrtValue& input_ids = feeds[0];
+ beam_state.sequences.Init(temp_space_allocator,
+ input_ids,
+ parameters_->BatchBeamSize(),
+ parameters_->sequence_length,
+ parameters_->max_length);
+
+#ifdef DEBUG_BEAM_SEARCH
+ DumpOrtValue("input_ids", input_ids);
+ DumpOrtValue("position_ids", feeds[1]);
+ DumpOrtValue("attention_mask", feeds[2]);
+#endif
+
+ int current_length = parameters_->sequence_length;
+ while (current_length < parameters_->max_length) {
+#ifdef DEBUG_BEAM_SEARCH
+ DumpString("***CurrentLength", std::to_string(current_length), true);
+#endif
+
+ status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, {},
+ ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger());
+
+ ORT_RETURN_IF_ERROR(status);
+
+ const OrtValue& logits = fetches[0];
+ gsl::span beam_next_tokens;
+ gsl::span beam_indices;
+ ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state));
+
+ // When all batches are finished, stop earlier to avoid wasting computation.
+ if (beam_scorer_->IsDone()) {
+ break;
+ }
+
+ // Increase sequence length after a new token is generated.
+ ++current_length;
+
+ // Prepare inputs for next round of subgraph call.
+ if (current_length < parameters_->max_length) {
+ ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
+ beam_state.next_positions,
+ beam_next_tokens.as_span(),
+ beam_indices.as_span()));
+ }
+ fetches.clear();
+
+#ifdef DEBUG_BEAM_SEARCH
+ if (current_length - parameters_->sequence_length == 3) { // only dump a few steps.
+ DisableTensorDump();
+ }
+#endif
+ }
+
+ gsl::span beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
+ beam_scorer_->Finalize(&(beam_state.sequences),
+ beam_scores,
+ output_sequences,
+ output_sequences_scores);
+
+ // Output per token scores
+ if (output_scores != nullptr) {
+ gsl::span target = output_scores->MutableDataAsSpan();
+ gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size());
+ assert(target.length() == source.length());
+ gsl::copy(source, target);
+ }
+
+ return status;
+}
+
+// Instantiation
+template class BeamSearchImpl;
+template class BeamSearch;
+
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
new file mode 100644
index 0000000000..9dc5cac408
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -0,0 +1,48 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include
+#include "gsl/gsl"
+#include "core/common/common.h"
+#include "core/framework/feeds_fetches_manager.h"
+#include "core/framework/op_kernel.h"
+#include "core/providers/cpu/controlflow/utils.h"
+#include "beam_search_parameters.h"
+#include "beam_search_scorer.h"
+#include "gpt_subgraph.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+template
+class BeamSearch : public controlflow::IControlFlowKernel {
+ public:
+ BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info) { Init(info); }
+ void Init(const OpKernelInfo& info);
+
+ Status Compute(OpKernelContext* ctx) const override;
+
+ Status SetupSubgraphExecutionInfo(const SessionState& session_state,
+ const std::string& attribute_name,
+ const SessionState& subgraph_session_state) override;
+
+ static std::unique_ptr Create(const OpKernelInfo& info, void* stream);
+
+ protected:
+ void SetComputeStream(void* stream) { stream_ = stream; }
+
+ private:
+ // Subgraph and FeedsFetchesManager re-used for each subgraph execution.
+ std::unique_ptr gpt_subgraph_;
+ FeedsFetchesManager* feeds_fetches_manager_;
+
+ void* stream_;
+
+ BeamSearchParameters parameters_;
+};
+
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
new file mode 100644
index 0000000000..fee3ec4753
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -0,0 +1,72 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+#include "beam_search_parameters.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+constexpr int kMaxSequenceLength = 4096;
+
+Status BeamSearchParameters::Validate() const {
+ ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid");
+ ORT_RETURN_IF(pad_token_id < 0, "pad_token_id is invalid");
+ ORT_RETURN_IF(min_length >= max_length, "min_length shall be smaller than max_length");
+ return Status::OK();
+}
+
+void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
+ early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1;
+ eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1));
+ pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1));
+ no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0));
+}
+
+void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
+ ORT_ENFORCE(context != nullptr);
+ const Tensor* input_ids = context->Input(0);
+ const auto& dims = input_ids->Shape().GetDims();
+ ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
+ batch_size = static_cast(dims[0]);
+ sequence_length = static_cast(dims[1]);
+
+ auto* max_length_tensor = context->Input(1);
+ max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : kMaxSequenceLength;
+ ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
+ ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
+
+ auto* min_length_tensor = context->Input(2);
+ min_length = min_length_tensor ? static_cast(*min_length_tensor->Data()) : 0;
+
+ auto* num_beams_tensor = context->Input(3);
+ num_beams = num_beams_tensor ? static_cast(*num_beams_tensor->Data()) : 1;
+ // TODO: limit num_beams > 1 when we can have another operator for greedy search.
+ ORT_ENFORCE(num_beams >= 1, "num_beams shall be a positive integer, got ", num_beams);
+
+ auto* num_return_sequences_tensor = context->Input(4);
+ num_return_sequences = num_return_sequences_tensor ? static_cast(*num_return_sequences_tensor->Data()) : 1;
+ ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences);
+ ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
+
+ auto* temperature_tensor = context->Input(5);
+ temperature = temperature_tensor ? static_cast(*temperature_tensor->Data()) : 1;
+ ORT_ENFORCE(temperature > 0.0f, "temperature shall be greater than 0, got ", temperature);
+
+ auto* length_penalty_tensor = context->Input(6);
+ length_penalty = length_penalty_tensor ? static_cast(*length_penalty_tensor->Data()) : 1;
+
+ auto* repetition_penalty_tensor = context->Input(7);
+ repetition_penalty = repetition_penalty_tensor ? static_cast(*repetition_penalty_tensor->Data()) : 1.0f;
+ ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
+}
+
+void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
+ vocab_size = vocabulary_size;
+ num_heads = heads;
+ head_size = hidden_size_per_head;
+ num_layers = layers;
+}
+
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
new file mode 100644
index 0000000000..26de2a9840
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
@@ -0,0 +1,55 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+#include "core/common/common.h"
+#include "core/framework/op_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+struct BeamSearchParameters {
+ // Parameters from node attributes
+ int eos_token_id;
+ int pad_token_id;
+ int no_repeat_ngram_size;
+ bool early_stopping;
+
+ // Parameters from inputs
+ int min_length;
+ int max_length;
+ int num_beams;
+ int num_return_sequences;
+ float temperature;
+ float length_penalty;
+ float repetition_penalty;
+ int batch_size; // deduce from first dimension of input_ids
+ int sequence_length; // deduce from second dimension of input_ids
+
+ gsl::span vocab_mask;
+
+ // Parameters from outputs.
+ bool output_scores; // whether scores existed in output
+
+ // Parameters from subgraph.
+ int vocab_size;
+ // Below are used in CPU, reserved for CUDA.
+ int num_heads;
+ int head_size;
+ int num_layers;
+
+ Status Validate() const;
+
+ int BatchBeamSize() const { return batch_size * num_beams; }
+
+ void ParseFromAttributes(const OpKernelInfo& info);
+
+ void ParseFromInputs(OpKernelContext* context);
+
+ void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers);
+};
+
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
new file mode 100644
index 0000000000..bb7aeb989e
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
@@ -0,0 +1,285 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
+#include "core/common/common.h"
+#include "core/framework/allocator.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/utils.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "core/providers/cpu/rnn/rnn_helpers.h"
+#include "beam_search_scorer.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+using ::onnxruntime::rnn::detail::Allocate;
+
+template
+BeamHypotheses::BeamHypotheses(int num_beams, T length_penalty, bool early_stopping)
+ : num_beams_(num_beams),
+ length_penalty_(length_penalty),
+ early_stopping_(early_stopping),
+ worst_score_(1e9) {}
+
+template
+void BeamHypotheses::Add(gsl::span& hypothesis, T sum_logprobs) {
+ auto length = hypothesis.size();
+ // TODO: when T is FP16, compute in FP32, then cast result back to FP16. length_penalty_ might also be float.
+ T score = sum_logprobs / pow(static_cast(length), length_penalty_);
+
+ if (this->Size() < num_beams_ || score > worst_score_) {
+ HypothesisScore item(hypothesis, score);
+ beams_.push(item);
+ if (this->Size() > num_beams_) {
+ beams_.pop();
+ }
+ worst_score_ = beams_.top().score;
+ }
+}
+
+template
+bool BeamHypotheses::IsDone(T best_sum_logprobs, int current_length) {
+ // If there are enough hypotheses and that none of the hypotheses being generated can become better
+ // than the worst one in the heap, then we are done with this sentence.
+
+ if (Size() < num_beams_)
+ return false;
+
+ if (early_stopping_)
+ return true;
+
+ T current_score = best_sum_logprobs / pow(static_cast(current_length), length_penalty_);
+ return worst_score_ >= current_score;
+}
+
+template
+void BeamHypotheses::Output(
+ int top_k,
+ int max_length,
+ gsl::span& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
+ gsl::span& sequences_scores) // buffer of shape (num_return_sequences) or empty
+{
+ ORT_ENFORCE(top_k <= Size());
+ int remove_count = Size() - top_k;
+ for (int i = 0; i < remove_count; i++) {
+ beams_.pop();
+ }
+
+ // Since pop get the worst sequence, so output it in the reverse order.
+ // The frist (worst) beam shall be put at the last position among top_k sequences.
+ int index = top_k - 1;
+ while (!beams_.empty()) {
+ auto item = beams_.top();
+ gsl::span& source = item.hypothesis;
+ gsl::span target = sequences.subspan(index * max_length, max_length);
+
+ // Note that word_ids might be less than max_length.
+ // Since the sequences has been filled with pad token ID, so padding is not needed here.
+ // Since data type need cast from int64_t to int32_t, we cannot use gsl::copy(word_ids, sequence) here.
+ for (size_t i = 0; i < source.length(); i++) {
+ target[i] = static_cast(source[i]);
+ }
+
+ if (!sequences_scores.empty())
+ sequences_scores[index] = item.score;
+
+ beams_.pop();
+ index--;
+ }
+}
+
+template
+BeamSearchScorer::BeamSearchScorer(int batch_size,
+ int num_beams,
+ int max_length,
+ T length_penalty,
+ bool early_stopping,
+ int num_return_sequences,
+ int pad_token_id,
+ int eos_token_id)
+ : batch_size_(batch_size),
+ num_beams_(num_beams),
+ max_length_(max_length),
+ num_beam_hyps_to_keep_(num_return_sequences),
+ pad_token_id_(pad_token_id),
+ eos_token_id_(eos_token_id),
+ hypothesis_buffer_length_(0),
+ hypothesis_buffer_offset_(0) {
+ for (int batch = 0; batch < batch_size; batch++) {
+ beam_hyps.push_back(BeamHypotheses(num_beams, length_penalty, early_stopping));
+ }
+}
+
+template
+bool BeamSearchScorer::IsDone() {
+ for (int batch = 0; batch < batch_size_; batch++) {
+ if (!done_[batch])
+ return false;
+ }
+ return true;
+}
+
+template
+void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length){
+ ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once.
+
+ size_t batch_beam_size = static_cast(batch_size_ * num_beams_);
+ const bool no_fill = false; // do not fill values after allocation
+ next_beam_scores_ = Allocate(allocator, batch_beam_size, next_beam_scores_ptr_, no_fill);
+ next_beam_tokens_ = Allocate(allocator, batch_beam_size, next_beam_tokens_ptr_, no_fill);
+ next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill);
+
+ // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
+ int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2;
+ hypothesis_buffer_length_ = batch_beam_size * static_cast(buffer_per_beam);
+ hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill);
+
+ done_ = Allocate(allocator, static_cast(batch_size_), done_ptr_, no_fill);
+ std::fill_n(done_.data(), done_.size(), false);
+}
+
+template
+void BeamSearchScorer::Process(ISequences* sequences,
+ gsl::span& next_scores,
+ gsl::span& next_tokens,
+ gsl::span& next_indices) {
+ // Sequences shape is (batch_size * num_beams, total_sequence_length)
+ // It contains word ID of whole sequence generated so far.
+ // It is different from subgraph input_ids, which only need one word when past state is not empty.
+
+ const int sequence_length = sequences->GetSequenceLength();
+
+ ORT_ENFORCE(next_scores.size() == next_tokens.size());
+ ORT_ENFORCE(next_scores.size() == next_indices.size());
+
+ for (int batch = 0; batch < batch_size_; batch++) {
+ BeamHypotheses& beam_hyp = beam_hyps[batch];
+ if (done_[batch]) {
+ ORT_ENFORCE(beam_hyp.Size() >= num_beams_, "Batch can only be done if all beams have been generated");
+
+ // Pad the batch.
+ for (int j = 0; j < num_beams_; j++) {
+ next_beam_scores_[batch * num_beams_ + j] = 0.0f;
+ next_beam_tokens_[batch * num_beams_ + j] = pad_token_id_;
+ next_beam_indices_[batch * num_beams_ + j] = 0;
+ }
+ continue;
+ }
+
+ // Next tokens for this sentence.
+ int beam_idx = 0;
+ int top_k = 2 * num_beams_;
+ for (int j = 0; j < top_k; j++) {
+ int64_t next_token = next_tokens[batch * top_k + j];
+ T next_score = next_scores[batch * top_k + j];
+ int64_t next_index = next_indices[batch * top_k + j];
+
+ int batch_beam_idx = batch * num_beams_ + static_cast(next_index);
+ // Add to generated hypotheses if end of sentence.
+ if ((eos_token_id_ >= 0) && (next_token == eos_token_id_)) {
+ bool is_beam_token_worse_than_top_num_beams = (j >= num_beams_);
+ if (is_beam_token_worse_than_top_num_beams) {
+ continue;
+ }
+
+ // Clone the sequence and append to buffer.
+ gsl::span src = sequences->GetSequence(batch_beam_idx);
+ auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length);
+ gsl::copy(src, clone);
+ hypothesis_buffer_offset_ += sequence_length;
+ auto sequence = clone.template as_span();
+ beam_hyp.Add(sequence, next_score);
+ } else {
+ // Add next predicted token since it is not eos_token.
+ next_beam_scores_[batch * num_beams_ + beam_idx] = next_score;
+ next_beam_tokens_[batch * num_beams_ + beam_idx] = next_token;
+ next_beam_indices_[batch * num_beams_ + beam_idx] = batch_beam_idx;
+ ++beam_idx;
+ }
+
+ // Once the beam for next step is full, don't add more tokens to it.
+ if (beam_idx == num_beams_)
+ break;
+ }
+
+ ORT_ENFORCE(beam_idx == num_beams_);
+ ORT_ENFORCE(hypothesis_buffer_offset_ <= batch_size_ * num_beams_ * max_length_);
+
+ // Check if we are done so that we can save a pad step if all(done)
+ if (!done_[batch]) {
+ gsl::span topk_scores = next_scores.subspan(batch * num_beams_, top_k);
+ const T* best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end());
+ if (beam_hyp.IsDone(*best_sum_logprobs, sequence_length)) {
+ done_[batch] = true;
+ }
+ }
+ }
+}
+
+template
+void BeamSearchScorer::Finalize(ISequences* sequences,
+ gsl::span& final_beam_scores,
+ Tensor* output_sequences,
+ Tensor* output_sequence_scores) {
+ ORT_ENFORCE(sequences != nullptr);
+ ORT_ENFORCE(output_sequences != nullptr);
+
+ // Finalize all open beam hypotheses and add to generated hypotheses.
+ for (int batch_index = 0; batch_index < batch_size_; batch_index++) {
+ BeamHypotheses& beam_hyp = beam_hyps[batch_index];
+ if (done_[batch_index]) {
+ continue;
+ }
+
+ for (int beam_index = 0; beam_index < num_beams_; beam_index++) {
+ int batch_beam_index = batch_index * num_beams_ + beam_index;
+ T final_score = final_beam_scores[batch_beam_index];
+ auto final_tokens = sequences->GetSequence(batch_beam_index);
+ beam_hyp.Add(final_tokens, final_score);
+ }
+ }
+
+ // Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length).
+ gsl::span output = output_sequences->MutableDataAsSpan();
+
+ // Fill output sequences with pad token ID so that we do not need append it later.
+ std::fill_n(output.data(), output.size(), pad_token_id_);
+
+ // Score of each sequence, with shape (batch_size * num_return_sequences).
+ gsl::span sequence_scores;
+ if (output_sequence_scores != nullptr) {
+ sequence_scores = output_sequence_scores->MutableDataAsSpan();
+ }
+
+ // Span is empty when output_sequence_scores is NULL.
+ gsl::span batch_sequence_score;
+
+ // Select the best hypotheses according to number of sequences to return.
+ for (int batch_index = 0; batch_index < batch_size_; batch_index++) {
+ BeamHypotheses& beam_hyp = beam_hyps[batch_index];
+
+ const int num_return_sequences = num_beam_hyps_to_keep_;
+ auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_);
+
+ if (output_sequence_scores != nullptr) {
+ batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences);
+ }
+
+ beam_hyp.Output(
+ num_return_sequences,
+ max_length_,
+ batch_output,
+ batch_sequence_score);
+ }
+}
+
+// Instantiation
+template class HypothesisScoreCompare;
+template class BeamHypotheses;
+template class BeamSearchScorer;
+
+} // namespace transformers
+} // namespace contrib
+} // namespace onnxruntime
\ No newline at end of file
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
new file mode 100644
index 0000000000..2a15080236
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
@@ -0,0 +1,144 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// The implementation is based on huggingface transformers generation_beam_search.py
+
+#pragma once
+#include
+#include
+#include "core/common/common.h"
+#include "core/framework/allocator.h"
+#include "core/framework/tensorprotoutils.h"
+#include "core/framework/utils.h"
+#include "core/providers/cpu/tensor/utils.h"
+#include "sequences.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace transformers {
+
+// Interface for all scorers for beam search or beam sample.
+template
+class IBeamScorer {
+ public:
+ virtual ~IBeamScorer() {}
+
+ virtual void Initialize(AllocatorPtr& allocator, int sequence_length) = 0;
+
+ virtual void Process(ISequences* sequences,
+ gsl::span& next_scores,
+ gsl::span& next_tokens,
+ gsl::span& next_indices) = 0;
+
+ virtual void Finalize(ISequences* sequences,
+ gsl::span& final_beam_scores,
+ Tensor* output_sequences,
+ Tensor* output_sequence_scores) = 0;
+};
+
+template
+struct HypothesisScore {
+ HypothesisScore(gsl::span& _hypothesis, T _score)
+ : hypothesis(_hypothesis), score(_score) {}
+
+ gsl::span hypothesis;
+ T score;
+};
+
+template
+class HypothesisScoreCompare {
+ public:
+ bool operator()(const HypothesisScore& a, const HypothesisScore& b) {
+ return a.score > b.score;
+ }
+};
+
+template
+class BeamHypotheses {
+ public:
+ BeamHypotheses(int num_beams, T length_penalty, bool early_stopping);
+
+ // Number of hypotheses
+ int Size() { return static_cast(beams_.size()); }
+
+ // Add a new hypothesis
+ void Add(gsl::span