Add BeamSearch operator for GPT-2 decoding (#9680)

* Add BeamSearch operator and CPU implementation
* Add ONNX conversion script
Tianlei Wu 2021-12-16 16:08:05 -08:00 committed by GitHub
parent fab39b4704
commit ef36488df0
25 changed files with 3262 additions and 3 deletions

@ -5,6 +5,7 @@ Do not modify directly.*
* com.microsoft
* <a href="#com.microsoft.Attention">com.microsoft.Attention</a>
* <a href="#com.microsoft.AttnLSTM">com.microsoft.AttnLSTM</a>
* <a href="#com.microsoft.BeamSearch">com.microsoft.BeamSearch</a>
* <a href="#com.microsoft.BiasDropout">com.microsoft.BiasDropout</a>
* <a href="#com.microsoft.BiasGelu">com.microsoft.BiasGelu</a>
* <a href="#com.microsoft.BiasSoftmax">com.microsoft.BiasSoftmax</a>
@ -337,6 +338,75 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.BeamSearch"></a><a name="com.microsoft.beamsearch">**com.microsoft.BeamSearch**</a>
Beam Search for text generation. Supports GPT-2 decoder.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>body</tt> : graph (required)</dt>
<dd>The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as outputs</dd>
<dt><tt>early_stopping</tt> : int</dt>
<dd>Whether to stop the search early (1) or not (0). Default is 0.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>Size of n-grams that must not be repeated. Default is 0.</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>
<dd>The id of the padding token</dd>
</dl>
#### Inputs (6 - 9)
<dl>
<dt><tt>input_ids</tt> : I</dt>
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
<dt><tt>max_length</tt> : I</dt>
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
<dt><tt>min_length</tt> (optional) : I</dt>
<dd>The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)</dd>
<dt><tt>num_beams</tt> : I</dt>
<dd>Number of beams for beam search. 1 means no beam search. Shape is (1)</dd>
<dt><tt>num_return_sequences</tt> : I</dt>
<dd>The number of returned sequences in the batch. Shape is (1)</dd>
<dt><tt>temperature</tt> : T</dt>
<dd>The value used to modulate the next token probabilities. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>length_penalty</tt> (optional) : T</dt>
<dd>Exponential penalty to the length. Default value 1.0 means no penalty. Values > 1.0 encourage longer sequences, while values < 1.0 produce shorter sequences. Shape is (1)</dd>
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : M</dt>
<dd>Mask of vocabulary. Words that are masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)</dd>
</dl>
#### Outputs (1 - 3)
<dl>
<dt><tt>sequences</tt> : I</dt>
<dd>Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)</dd>
<dt><tt>sequences_scores</tt> (optional) : T</dt>
<dd>Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)</dd>
<dt><tt>scores</tt> (optional) : T</dt>
<dd>Processed beam scores for each vocabulary token at each generation step. Each beam score is the log softmax score of a vocabulary token plus the sum of the log softmax scores of the previously generated tokens in this beam. Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>I</tt> : tensor(int32)</dt>
<dd>Constrain to integer types</dd>
<dt><tt>M</tt> : tensor(int32)</dt>
<dd>Constrain mask to integer types</dd>
</dl>
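The `min_length` and `vocab_mask` inputs described above both act by pushing selected token scores to -Inf before candidates are chosen. The following is a minimal sketch of that idea for a single row of scores; `MaskScores` is a hypothetical helper written for illustration and is not the logits processor added by this commit, which may differ in detail (for example in the sentinel value it uses).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical helper (not part of the commit): masks one row of next-token
// scores, where scores.size() == vocab_size.
void MaskScores(std::vector<float>& scores,
                const std::vector<int32_t>& vocab_mask,  // empty means "no mask"
                int current_length,
                int min_length,
                int eos_token_id) {
  const float kNegInf = -std::numeric_limits<float>::infinity();
  assert(vocab_mask.empty() || vocab_mask.size() == scores.size());

  // Tokens whose mask entry is 0 are not allowed to be generated.
  for (std::size_t i = 0; i < vocab_mask.size(); ++i) {
    if (vocab_mask[i] == 0) {
      scores[i] = kNegInf;
    }
  }

  // While the sequence is shorter than min_length, forbid end-of-sequence.
  if (current_length < min_length && eos_token_id >= 0 &&
      static_cast<std::size_t>(eos_token_id) < scores.size()) {
    scores[static_cast<std::size_t>(eos_token_id)] = kNegInf;
  }
}
```

Because the masked entries become -Inf, they can never win a later top-k selection, whatever the beam score added to them.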
### <a name="com.microsoft.BiasDropout"></a><a name="com.microsoft.biasdropout">**com.microsoft.BiasDropout**</a>
output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models.

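For the BeamSearch `length_penalty` input, the documentation only says that it is an exponential penalty on the length. A common convention, used by the Hugging Face beam scorer and assumed here rather than confirmed by the text above, is to divide the summed log-probabilities of a finished hypothesis by `length^length_penalty`. `HypothesisScore` is a hypothetical name for this sketch.

```cpp
#include <cassert>
#include <cmath>

// Hypothetical helper (assumed convention, not taken from the commit):
// score = sum_logprobs / length^length_penalty.
float HypothesisScore(float sum_logprobs, int length, float length_penalty) {
  assert(length > 0);
  return sum_logprobs / std::pow(static_cast<float>(length), length_penalty);
}
```

Since log-probabilities are negative, a larger exponent moves the score of a long hypothesis closer to zero, which matches the documented behavior that values above 1.0 encourage longer sequences.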
@ -377,6 +377,7 @@ Do not modify directly.*
|**Operator Domain:** *com.microsoft*||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* temperature:**T**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|

@ -22,6 +22,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
// add more kernels here
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,

@ -0,0 +1,650 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// there's no way to use a raw pointer as the copy destination with std::copy_n
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
#include <assert.h>
#include "core/providers/cpu/controlflow/utils.h"
#include "core/providers/cpu/math/top_k.h"
#include "core/framework/allocator.h"
#include "core/framework/framework_common.h"
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/framework/session_options.h"
#include "core/framework/TensorSeq.h"
#include "gsl/gsl"
#include "core/providers/cpu/math/softmax_shared.h"
#include "beam_search.h"
#include "logits_processor.h"
#include "sequences.h"
#include "dump_tensor.h"
#ifdef _MSC_VER
#pragma warning(pop)
#endif
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
BeamSearch, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
transformers::BeamSearch<T>);
REGISTER_KERNEL_TYPED(float)
namespace transformers {
template <typename T>
struct BeamSearchState {
gsl::span<T> beam_scores; // shape (batch_size, num_beams)
gsl::span<T> next_token_logits; // shape (batch_size * num_beams, vocab_size)
gsl::span<T> next_token_scores; // shape (batch_size, num_beams * vocab_size)
gsl::span<int64_t> next_tokens; // shape (batch_size, 2 * num_beams)
gsl::span<int64_t> next_indices; // shape (batch_size, 2 * num_beams)
gsl::span<int64_t> next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
gsl::span<T> scores; // shape (max_length - sequence_length, batch_size, num_beams * vocab_size)
gsl::span<T> remaining_scores; // subspan that is available for appending next token scores.
Sequences sequences;
void Init(AllocatorPtr allocator,
int batch_size,
int num_beams,
int vocab_size,
int sequence_length,
int max_length,
bool output_scores) {
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
beam_scores = AllocateBuffer<T>(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast<T>(0));
// Initialize score of first beam of each group with 0 and the rest with -1e9.
// This ensures that the beams in the same group don't produce same tokens every time.
for (int i = 0; i < batch_size; i++) {
for (int j = 1; j < num_beams; j++) {
beam_scores[i * num_beams + j] = -1e9;
}
}
size_t next_token_size = SafeInt<size_t>(batch_beam_size) * vocab_size;
next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size, true, static_cast<T>(0));
next_token_scores = AllocateBuffer<T>(allocator, next_token_scores_buffer_, next_token_size, true, static_cast<T>(0));
next_tokens = AllocateBuffer<int64_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size, true, static_cast<int64_t>(0));
next_indices = AllocateBuffer<int64_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size, true, static_cast<int64_t>(0));
next_positions = AllocateBuffer<int64_t>(allocator, next_positions_buffer_, batch_beam_size, true, static_cast<int64_t>(0));
if (output_scores) {
size_t elements = SafeInt<size_t>(max_length - sequence_length) * batch_size * num_beams * vocab_size;
scores = AllocateBuffer<T>(allocator, scores_buffer_, elements);
remaining_scores = scores;
}
// sequences will be initialized later since it has dependency on input_ids
}
private:
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;
BufferUniquePtr next_tokens_buffer_;
BufferUniquePtr next_indices_buffer_;
BufferUniquePtr next_positions_buffer_;
BufferUniquePtr scores_buffer_;
};
template <typename T>
class BeamSearchImpl {
public:
BeamSearchImpl(OpKernelContextInternal& context,
const SessionState& session_state,
GptSubgraph& gpt_subgraph,
concurrency::ThreadPool* thread_pool,
void* stream,
BeamSearchParameters& params);
// Initialize by validating all the inputs, and allocating the output tensors.
Status Initialize();
// Execute beam search in iterations until a stopping criterion is reached.
// In each iteration, GPT subgraph is called, and next token for each sequence is generated.
Status Execute(const FeedsFetchesManager& cached_ffm);
private:
// Validate inputs.
Status CheckInputs(const OpKernelContextInternal& context);
// Prepare the inputs for first inference of subgraph
void CreateInitialFeeds(gsl::span<int64_t>& next_positions, std::vector<OrtValue>& feeds);
// Update the input for next iteration.
Status UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
gsl::span<int64_t>& next_positions,
gsl::span<const int64_t> beam_next_tokens,
gsl::span<const int64_t> beam_indices);
// Process logits and append next tokens to sequences.
Status GenerateNextToken(const OrtValue& logits,
gsl::span<int64_t>& beam_next_tokens,
gsl::span<int64_t>& beam_indices,
BeamSearchState<T>& beam_state);
// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
BeamSearchState<T>& beam_state,
AllocatorPtr& allocator);
OpKernelContextInternal& context_;
const SessionState& session_state_;
GptSubgraph& gpt_subgraph_;
concurrency::ThreadPool* thread_pool_;
const std::vector<const OrtValue*>& implicit_inputs_;
// Not used in CPU. Stream is for CUDA only.
void* stream_;
BeamSearchParameters* parameters_;
LogitsProcessorList<T> logits_processors_;
std::unique_ptr<BeamSearchScorer<T>> beam_scorer_;
AllocatorPtr allocator_;
};
template <typename T>
void BeamSearch<T>::Init(const OpKernelInfo& info) {
// Make sure the body attribute was present even though we don't need it here.
ONNX_NAMESPACE::GraphProto proto;
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("body", &proto).IsOK());
ORT_IGNORE_RETURN_VALUE(proto);
parameters_.ParseFromAttributes(info);
stream_ = nullptr;
}
template <typename T>
std::unique_ptr<OpKernel> BeamSearch<T>::Create(const OpKernelInfo& info,
void* stream) {
auto result = std::make_unique<BeamSearch>(info);
result->SetComputeStream(stream);
return result;
}
template <typename T>
common::Status BeamSearch<T>::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
const auto& node = Node();
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
gpt_subgraph_->num_heads,
gpt_subgraph_->head_size,
gpt_subgraph_->num_layers);
return Status::OK();
}
template <typename T>
Status BeamSearch<T>::Compute(OpKernelContext* ctx) const {
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto* session_state = ctx_internal->SubgraphSessionState("body");
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later
BeamSearchImpl<T> impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, stream_, parameters};
auto status = impl.Initialize();
ORT_RETURN_IF_ERROR(status);
status = impl.Execute(*feeds_fetches_manager_);
return status;
}
template <typename T>
BeamSearchImpl<T>::BeamSearchImpl(OpKernelContextInternal& context,
const SessionState& session_state,
GptSubgraph& gpt_subgraph,
concurrency::ThreadPool* thread_pool,
void* stream,
BeamSearchParameters& params)
: context_(context),
session_state_(session_state),
gpt_subgraph_(gpt_subgraph),
thread_pool_(thread_pool),
implicit_inputs_(context_.GetImplicitInputs()),
stream_(stream),
parameters_(&params),
allocator_(nullptr) {
parameters_->ParseFromInputs(&context);
allocator_ = session_state.GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
}
template <typename T>
Status BeamSearchImpl<T>::CheckInputs(const OpKernelContextInternal& context) {
// Input shapes:
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
const Tensor* input_ids = context.Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ",
dims.size());
}
const Tensor* vocab_mask = context.Input<Tensor>(8);
if (vocab_mask != nullptr) { // vocab_mask is optional
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
if (vocab_mask_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ",
vocab_mask_dims.size());
}
// There is dependency on vocab_size parameter, which shall be set before calling this function.
if (static_cast<int>(vocab_mask_dims[0]) != parameters_->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ",
vocab_mask_dims[0]);
}
// store vocab mask in parameters.
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
}
return Status::OK();
}
template <typename T>
Status BeamSearchImpl<T>::Initialize() {
auto status = Status::OK();
#define CHECK_SCALAR_INPUT(name, index, required) \
auto* name##_tensor = context_.Input<Tensor>(index); \
if (name##_tensor) { \
if (!name##_tensor->Shape().IsScalar()) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
name##_tensor->Shape()); \
} \
} else if (required) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
}
// Indices follow the operator's input order: max_length is input 1 and min_length is input 2.
CHECK_SCALAR_INPUT(max_length, 1, true);
CHECK_SCALAR_INPUT(min_length, 2, false);
CHECK_SCALAR_INPUT(num_beams, 3, true);
CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
CHECK_SCALAR_INPUT(temperature, 5, true);
CHECK_SCALAR_INPUT(length_penalty, 6, false);  // optional input
ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller than or equal to 'num_beams'.");
ORT_RETURN_IF_ERROR(CheckInputs(context_));
// This flag will be updated later when the scores output exists.
parameters_->output_scores = false;
// Initialize processors after CheckInputs so that parameters_->vocab_mask is ready.
logits_processors_.Init(*parameters_);
return status;
}
template <typename T>
void BeamSearchImpl<T>::CreateInitialFeeds(gsl::span<int64_t>& next_positions, std::vector<OrtValue>& feeds) {
const OrtValue* input_ids_value = context_.GetInputOrtValue(0);
const Tensor& input_ids = input_ids_value->Get<Tensor>();
gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, next_positions, feeds);
}
template <typename T>
Status BeamSearchImpl<T>::ProcessLogits(
const OrtValue& logits,
BeamSearchState<T>& beam_state,
AllocatorPtr& allocator) {
const int64_t batch_beam_size = static_cast<int64_t>(parameters_->BatchBeamSize());
const int& vocab_size = parameters_->vocab_size;
const T* logits_data = logits.Get<Tensor>().Data<T>();
// Logits has shape (batch_size * num_beams, input_length, vocab_size),
// where input_length equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls.
const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
ORT_ENFORCE(logits_shape.NumDimensions() == 3);
auto input_length = logits_shape[1];
// Get logits for the last token:
// next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
// When input_length == 1, the logits are passed directly to SoftmaxCPU below, so this copy is only needed when input_length > 1.
gsl::span<T>& next_token_logits = beam_state.next_token_logits;
if (input_length > 1) {
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const T> source(current_logits, vocab_size);
gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
gsl::copy(source, target);
current_logits += input_length * vocab_size;
}
}
#ifdef DEBUG_BEAM_SEARCH
//DumpOrtValue("logits", logits);
DumpTensor("next_token_logits", next_token_logits.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
#endif
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
gsl::span<T>& next_token_scores = beam_state.next_token_scores;
Status status = SoftmaxCPU<T>(batch_beam_size, // rows
vocab_size, // elements per row
input_length > 1 ? next_token_logits.data() : logits_data,
next_token_scores.data(),
true,
thread_pool_);
if (!status.IsOK()) {
return status;
}
#ifdef DEBUG_BEAM_SEARCH
DumpTensor("next_token_scores after softmax", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
#endif
// Apply all score processors that update scores
logits_processors_.Process(&(beam_state.sequences), next_token_scores);
#ifdef DEBUG_BEAM_SEARCH
DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
#endif
// Add beam score to next token scores. Corresponding python code is like:
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
// TODO: use thread pool to parallelize
int offset = 0;
int batch_beam_index = 0;
for (int i = 0; i < parameters_->batch_size; i++) {
for (int j = 0; j < parameters_->num_beams; j++, batch_beam_index++) {
for (int k = 0; k < parameters_->vocab_size; k++, offset++) {
next_token_scores[offset] += beam_state.beam_scores[batch_beam_index];
}
}
}
#ifdef DEBUG_BEAM_SEARCH
DumpTensor("next_token_scores after adding beam_scores", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
#endif
if (parameters_->output_scores) {
// Append next token scores to the scores output.
gsl::copy(next_token_scores, beam_state.remaining_scores);
beam_state.remaining_scores = beam_state.remaining_scores.subspan(next_token_scores.size());
}
// Apply top-k selection like the following:
// next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
// next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
int64_t next_token_scores_dims[] = {parameters_->batch_size, parameters_->num_beams * vocab_size};
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType<T>();
OrtValue next_token_scores_value;
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();
const int axis = 1;
const unsigned top_k = static_cast<unsigned>(2 * parameters_->num_beams);
const bool largest = true;
const bool sorted = true; // results returned in sorted order.
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
status = GetTopK<T>(&input, axis, top_k, largest, sorted, allocator, thread_pool_, topk_scores, topk_indices);
if (!status.IsOK()) {
return status;
}
#ifdef DEBUG_BEAM_SEARCH
DumpTensor<T>("topk_scores", *(topk_scores.get()));
DumpTensor<int64_t>("topk_indices", *(topk_indices.get()));
#endif
// Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
// next_indices = (next_tokens / vocab_size).long()
// next_tokens = next_tokens % vocab_size
gsl::span<const int64_t> next_token_indices = topk_indices->DataAsSpan<int64_t>();
offset = 0;
for (int i = 0; i < parameters_->batch_size; i++) {
for (unsigned int j = 0; j < top_k; j++, offset++) {
beam_state.next_indices[offset] = next_token_indices[offset] / vocab_size;
beam_state.next_tokens[offset] = next_token_indices[offset] % vocab_size;
}
}
gsl::span<const T> next_scores = topk_scores->DataAsSpan<T>();
gsl::span<const int64_t> next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size());
gsl::span<const int64_t> next_indices(beam_state.next_indices.data(), beam_state.next_indices.size());
#ifdef DEBUG_BEAM_SEARCH
DumpTensor<T>("next_scores before scorer", next_scores.data(), parameters_->batch_size, top_k);
DumpTensor<int64_t>("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k);
DumpTensor<int64_t>("next_indices before scorer", next_indices.data(), parameters_->batch_size, top_k);
#endif
beam_scorer_->Process(
&(beam_state.sequences),
next_scores,
next_tokens,
next_indices);
return Status::OK();
}
template <typename T>
Status BeamSearchImpl<T>::GenerateNextToken(
const OrtValue& logits,
gsl::span<int64_t>& beam_next_tokens,
gsl::span<int64_t>& beam_indices,
BeamSearchState<T>& beam_state) {
// Process logits to get next token scores
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_));
gsl::span<T>& beam_scores = beam_scorer_->GetNextScores();
// Cloning beam_scores is optional. Using the same buffer would also work:
// beam_state.beam_scores = beam_scores
// Here we make a copy to reduce the coupling with little cost (the buffer size is small).
gsl::copy(beam_scores, beam_state.beam_scores);
beam_next_tokens = beam_scorer_->GetNextTokens();
beam_indices = beam_scorer_->GetNextIndices();
#ifdef DEBUG_BEAM_SEARCH
DumpTensor<T>("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
DumpTensor<int64_t>("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
DumpTensor<int64_t>("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
#endif
beam_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
#ifdef DEBUG_BEAM_SEARCH
beam_state.sequences.PrintSequences();
#endif
return Status::OK();
}
template <typename T>
Status BeamSearchImpl<T>::UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
gsl::span<int64_t>& next_positions,
gsl::span<const int64_t> beam_next_tokens,
gsl::span<const int64_t> beam_indices) {
return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, next_positions,
beam_next_tokens, beam_indices, parameters_->num_beams);
}
template <typename T>
Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& ffm) {
auto status = Status::OK();
std::vector<int64_t> sequences_dims{parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length};
TensorShape sequences_shape(sequences_dims);
Tensor* output_sequences = context_.Output(0, sequences_shape);
std::vector<int64_t> sequences_scores_dims{parameters_->batch_size, parameters_->num_return_sequences};
TensorShape sequences_scores_shape(sequences_scores_dims);
Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape);
std::vector<int64_t> scores_dims{
parameters_->max_length - parameters_->sequence_length,
parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size};
TensorShape scores_shape(scores_dims);
Tensor* output_scores = context_.Output(2, scores_shape);
// Update the flag to indicate whether the scores output exists
parameters_->output_scores = (output_scores != nullptr);
std::vector<OrtValue> feeds;
std::vector<OrtValue> fetches;
// Initialize resources
AllocatorPtr temp_space_allocator;
ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator));
BeamSearchState<T> beam_state;
beam_state.Init(temp_space_allocator,
parameters_->batch_size,
parameters_->num_beams,
parameters_->vocab_size,
parameters_->sequence_length,
parameters_->max_length,
parameters_->output_scores);
beam_scorer_ = std::make_unique<BeamSearchScorer<T>>(parameters_->batch_size,
parameters_->num_beams,
parameters_->max_length,
parameters_->length_penalty,
parameters_->early_stopping,
parameters_->num_return_sequences,
parameters_->pad_token_id,
parameters_->eos_token_id);
beam_scorer_->Initialize(allocator_, parameters_->sequence_length); // TODO: use temp_space_allocator
CreateInitialFeeds(beam_state.next_positions, feeds);
const OrtValue& input_ids = feeds[0];
beam_state.sequences.Init(temp_space_allocator,
input_ids,
parameters_->BatchBeamSize(),
parameters_->sequence_length,
parameters_->max_length);
#ifdef DEBUG_BEAM_SEARCH
DumpOrtValue("input_ids", input_ids);
DumpOrtValue("position_ids", feeds[1]);
DumpOrtValue("attention_mask", feeds[2]);
#endif
int current_length = parameters_->sequence_length;
while (current_length < parameters_->max_length) {
#ifdef DEBUG_BEAM_SEARCH
DumpString("***CurrentLength", std::to_string(current_length), true);
#endif
status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, {},
ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger());
ORT_RETURN_IF_ERROR(status);
const OrtValue& logits = fetches[0];
gsl::span<int64_t> beam_next_tokens;
gsl::span<int64_t> beam_indices;
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state));
// When all batches are finished, stop earlier to avoid wasting computation.
if (beam_scorer_->IsDone()) {
break;
}
// Increase sequence length after a new token is generated.
++current_length;
// Prepare inputs for next round of subgraph call.
if (current_length < parameters_->max_length) {
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
beam_state.next_positions,
beam_next_tokens.as_span<const int64_t>(),
beam_indices.as_span<const int64_t>()));
}
fetches.clear();
#ifdef DEBUG_BEAM_SEARCH
if (current_length - parameters_->sequence_length == 3) { // only dump a few steps.
DisableTensorDump();
}
#endif
}
gsl::span<const T> beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
beam_scorer_->Finalize(&(beam_state.sequences),
beam_scores,
output_sequences,
output_sequences_scores);
// Output per token scores
if (output_scores != nullptr) {
gsl::span<T> target = output_scores->MutableDataAsSpan<T>();
gsl::span<const T> source = gsl::span<const T>(beam_state.scores.data(), beam_state.scores.size());
assert(target.length() == source.length());
gsl::copy(source, target);
}
return status;
}
// Instantiation
template class BeamSearchImpl<float>;
template class BeamSearch<float>;
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
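`BeamSearchState::Init` sets the score of the first beam in each batch entry to 0 and the rest to -1e9, and `ProcessLogits` later adds each beam's score to every token score in its row. The sketch below reproduces those two steps with plain `std::vector` buffers; `InitBeamScores` and `AddBeamScores` are hypothetical names chosen for illustration, not functions from this file.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper mirroring the initialization in BeamSearchState::Init:
// beam 0 of each batch entry starts at 0, every other beam at -1e9.
std::vector<float> InitBeamScores(int batch_size, int num_beams) {
  std::vector<float> beam_scores(
      static_cast<std::size_t>(batch_size) * static_cast<std::size_t>(num_beams), 0.0f);
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 1; j < num_beams; ++j) {
      beam_scores[static_cast<std::size_t>(i) * num_beams + j] = -1e9f;
    }
  }
  return beam_scores;
}

// next_token_scores has shape (batch_size * num_beams, vocab_size), row-major.
// Adds the running score of each beam to every token score in its row.
void AddBeamScores(std::vector<float>& next_token_scores,
                   const std::vector<float>& beam_scores,
                   int vocab_size) {
  assert(next_token_scores.size() ==
         beam_scores.size() * static_cast<std::size_t>(vocab_size));
  std::size_t offset = 0;
  for (std::size_t row = 0; row < beam_scores.size(); ++row) {
    for (int k = 0; k < vocab_size; ++k, ++offset) {
      next_token_scores[offset] += beam_scores[row];
    }
  }
}
```

On the first step all beams of a batch entry hold the same prompt and therefore the same token scores. With the -1e9 offset, only candidates from beam 0 can reach the top of the ranking, so the beams diverge instead of all picking the same token.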

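The top-k step in `ProcessLogits` views the scores of one batch entry as a flat array of `num_beams * vocab_size` values, keeps the best `2 * num_beams`, and splits each flat index into a beam index (`index / vocab_size`) and a token ID (`index % vocab_size`). The sketch below shows that decomposition for a single batch entry using `std::partial_sort` in place of the runtime's `GetTopK`; `Candidate` and `TopCandidates` are hypothetical names. Keeping twice as many candidates as beams is often explained (for example in the Hugging Face implementation) as leaving room for candidates that end in EOS, which is an assumption here since the source does not state the reason.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

struct Candidate {
  int64_t beam_index;  // which beam the candidate extends
  int64_t token_id;    // token appended to that beam
  float score;         // combined beam + token score
};

// Hypothetical helper: `scores` holds num_beams * vocab_size values for one
// batch entry, laid out beam by beam. Returns the best 2 * num_beams
// candidates in descending score order.
std::vector<Candidate> TopCandidates(const std::vector<float>& scores,
                                     int num_beams,
                                     int vocab_size) {
  assert(num_beams > 0 && vocab_size > 0);
  assert(scores.size() ==
         static_cast<std::size_t>(num_beams) * static_cast<std::size_t>(vocab_size));
  const std::size_t k =
      std::min(scores.size(), static_cast<std::size_t>(2) * num_beams);

  // Sort flat indices so that the k largest scores come first.
  std::vector<int64_t> order(scores.size());
  std::iota(order.begin(), order.end(), int64_t{0});
  std::partial_sort(order.begin(),
                    order.begin() + static_cast<std::ptrdiff_t>(k),
                    order.end(),
                    [&scores](int64_t a, int64_t b) {
                      return scores[static_cast<std::size_t>(a)] >
                             scores[static_cast<std::size_t>(b)];
                    });

  std::vector<Candidate> result;
  result.reserve(k);
  for (std::size_t i = 0; i < k; ++i) {
    const int64_t flat = order[i];
    // Flat index -> (beam index, token ID), as in ProcessLogits.
    result.push_back({flat / vocab_size, flat % vocab_size,
                      scores[static_cast<std::size_t>(flat)]});
  }
  return result;
}
```

With `num_beams = 2` and `vocab_size = 3`, a flat index of 5 maps to beam 1 and token 2, because 5 / 3 = 1 and 5 % 3 = 2.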
@ -0,0 +1,48 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <functional>
#include "gsl/gsl"
#include "core/common/common.h"
#include "core/framework/feeds_fetches_manager.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "beam_search_parameters.h"
#include "beam_search_scorer.h"
#include "gpt_subgraph.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
template <typename T>
class BeamSearch : public controlflow::IControlFlowKernel {
public:
BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info) { Init(info); }
void Init(const OpKernelInfo& info);
Status Compute(OpKernelContext* ctx) const override;
Status SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) override;
static std::unique_ptr<OpKernel> Create(const OpKernelInfo& info, void* stream);
protected:
void SetComputeStream(void* stream) { stream_ = stream; }
private:
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
std::unique_ptr<GptSubgraph> gpt_subgraph_;
FeedsFetchesManager* feeds_fetches_manager_;
void* stream_;
BeamSearchParameters parameters_;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,72 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "beam_search_parameters.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
constexpr int kMaxSequenceLength = 4096;
Status BeamSearchParameters::Validate() const {
ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid");
ORT_RETURN_IF(pad_token_id < 0, "pad_token_id is invalid");
ORT_RETURN_IF(min_length >= max_length, "min_length shall be smaller than max_length");
return Status::OK();
}
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
}
void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
ORT_ENFORCE(context != nullptr);
const Tensor* input_ids = context->Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
batch_size = static_cast<int>(dims[0]);
sequence_length = static_cast<int>(dims[1]);
auto* max_length_tensor = context->Input<Tensor>(1);
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
auto* min_length_tensor = context->Input<Tensor>(2);
min_length = min_length_tensor ? static_cast<int>(*min_length_tensor->Data<int32_t>()) : 0;
auto* num_beams_tensor = context->Input<Tensor>(3);
num_beams = num_beams_tensor ? static_cast<int>(*num_beams_tensor->Data<int32_t>()) : 1;
// TODO: limit num_beams > 1 when we can have another operator for greedy search.
ORT_ENFORCE(num_beams >= 1, "num_beams shall be a positive integer, got ", num_beams);
auto* num_return_sequences_tensor = context->Input<Tensor>(4);
num_return_sequences = num_return_sequences_tensor ? static_cast<int>(*num_return_sequences_tensor->Data<int32_t>()) : 1;
ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences);
ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be no more than num_beams (", num_beams, ")");
auto* temperature_tensor = context->Input<Tensor>(5);
temperature = temperature_tensor ? static_cast<float>(*temperature_tensor->Data<float>()) : 1;
ORT_ENFORCE(temperature > 0.0f, "temperature shall be greater than 0, got ", temperature);
auto* length_penalty_tensor = context->Input<Tensor>(6);
length_penalty = length_penalty_tensor ? static_cast<float>(*length_penalty_tensor->Data<float>()) : 1;
auto* repetition_penalty_tensor = context->Input<Tensor>(7);
repetition_penalty = repetition_penalty_tensor ? static_cast<float>(*repetition_penalty_tensor->Data<float>()) : 1.0f;
ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
}
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
vocab_size = vocabulary_size;
num_heads = heads;
head_size = hidden_size_per_head;
num_layers = layers;
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
struct BeamSearchParameters {
// Parameters from node attributes
int eos_token_id;
int pad_token_id;
int no_repeat_ngram_size;
bool early_stopping;
// Parameters from inputs
int min_length;
int max_length;
int num_beams;
int num_return_sequences;
float temperature;
float length_penalty;
float repetition_penalty;
int batch_size; // deduced from the first dimension of input_ids
int sequence_length; // deduced from the second dimension of input_ids
gsl::span<const int32_t> vocab_mask;
// Parameters from outputs.
bool output_scores; // whether the optional scores output is requested
// Parameters from subgraph.
int vocab_size;
// Below are used in CPU, reserved for CUDA.
int num_heads;
int head_size;
int num_layers;
Status Validate() const;
int BatchBeamSize() const { return batch_size * num_beams; }
void ParseFromAttributes(const OpKernelInfo& info);
void ParseFromInputs(OpKernelContext* context);
void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers);
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,285 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <algorithm>  // std::max_element, std::fill_n
#include <queue>
#include <math.h>
#include "core/common/common.h"
#include "core/framework/allocator.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/cpu/rnn/rnn_helpers.h"
#include "beam_search_scorer.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
using ::onnxruntime::rnn::detail::Allocate;
template <typename T>
BeamHypotheses<T>::BeamHypotheses(int num_beams, T length_penalty, bool early_stopping)
: num_beams_(num_beams),
length_penalty_(length_penalty),
early_stopping_(early_stopping),
worst_score_(1e9) {}
template <typename T>
void BeamHypotheses<T>::Add(gsl::span<const int64_t>& hypothesis, T sum_logprobs) {
auto length = hypothesis.size();
// TODO: when T is FP16, compute in FP32, then cast result back to FP16. length_penalty_ might also be float.
T score = sum_logprobs / pow(static_cast<T>(length), length_penalty_);
if (this->Size() < num_beams_ || score > worst_score_) {
HypothesisScore<T> item(hypothesis, score);
beams_.push(item);
if (this->Size() > num_beams_) {
beams_.pop();
}
worst_score_ = beams_.top().score;
}
}
template <typename T>
bool BeamHypotheses<T>::IsDone(T best_sum_logprobs, int current_length) {
// If there are enough hypotheses and none of the hypotheses still being generated can score better
// than the worst one in the heap, then we are done with this sentence.
if (Size() < num_beams_)
return false;
if (early_stopping_)
return true;
T current_score = best_sum_logprobs / pow(static_cast<T>(current_length), length_penalty_);
return worst_score_ >= current_score;
}
template <typename T>
void BeamHypotheses<T>::Output(
int top_k,
int max_length,
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
gsl::span<T>& sequences_scores) // buffer of shape (num_return_sequences) or empty
{
ORT_ENFORCE(top_k <= Size());
int remove_count = Size() - top_k;
for (int i = 0; i < remove_count; i++) {
beams_.pop();
}
// Since pop() removes the worst sequence first, write the results in reverse order.
// The first (worst) beam is placed at the last position among the top_k sequences.
int index = top_k - 1;
while (!beams_.empty()) {
auto item = beams_.top();
gsl::span<const int64_t>& source = item.hypothesis;
gsl::span<int32_t> target = sequences.subspan(index * max_length, max_length);
// Note that the hypothesis might be shorter than max_length.
// The output has already been filled with the pad token ID, so no padding is needed here.
// The data type is cast from int64_t to int32_t, so gsl::copy(source, target) cannot be used here.
for (size_t i = 0; i < source.length(); i++) {
target[i] = static_cast<int32_t>(source[i]);
}
if (!sequences_scores.empty())
sequences_scores[index] = item.score;
beams_.pop();
index--;
}
}
template <typename T>
BeamSearchScorer<T>::BeamSearchScorer(int batch_size,
int num_beams,
int max_length,
T length_penalty,
bool early_stopping,
int num_return_sequences,
int pad_token_id,
int eos_token_id)
: batch_size_(batch_size),
num_beams_(num_beams),
max_length_(max_length),
num_beam_hyps_to_keep_(num_return_sequences),
pad_token_id_(pad_token_id),
eos_token_id_(eos_token_id),
hypothesis_buffer_length_(0),
hypothesis_buffer_offset_(0) {
for (int batch = 0; batch < batch_size; batch++) {
beam_hyps.push_back(BeamHypotheses<T>(num_beams, length_penalty, early_stopping));
}
}
template <typename T>
bool BeamSearchScorer<T>::IsDone() {
for (int batch = 0; batch < batch_size_; batch++) {
if (!done_[batch])
return false;
}
return true;
}
template <typename T>
void BeamSearchScorer<T>::Initialize(AllocatorPtr& allocator, int sequence_length){
ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once.
size_t batch_beam_size = static_cast<size_t>(batch_size_ * num_beams_);
const bool no_fill = false; // do not fill values after allocation
next_beam_scores_ = Allocate<T>(allocator, batch_beam_size, next_beam_scores_ptr_, no_fill);
next_beam_tokens_ = Allocate<int64_t>(allocator, batch_beam_size, next_beam_tokens_ptr_, no_fill);
next_beam_indices_ = Allocate<int64_t>(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill);
// Space to store intermediate sequences of length sequence_length, sequence_length + 1, ..., max_length.
int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2;
hypothesis_buffer_length_ = batch_beam_size * static_cast<size_t>(buffer_per_beam);
hypothesis_buffer_ = Allocate<int64_t>(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill);
done_ = Allocate<bool>(allocator, static_cast<size_t>(batch_size_), done_ptr_, no_fill);
std::fill_n(done_.data(), done_.size(), false);
}
template <typename T>
void BeamSearchScorer<T>::Process(ISequences* sequences,
gsl::span<const T>& next_scores,
gsl::span<const int64_t>& next_tokens,
gsl::span<const int64_t>& next_indices) {
// Sequences shape is (batch_size * num_beams, total_sequence_length).
// It contains the word IDs of the whole sequence generated so far.
// This differs from the subgraph input_ids, which only needs one word when the past state is not empty.
const int sequence_length = sequences->GetSequenceLength();
ORT_ENFORCE(next_scores.size() == next_tokens.size());
ORT_ENFORCE(next_scores.size() == next_indices.size());
for (int batch = 0; batch < batch_size_; batch++) {
BeamHypotheses<T>& beam_hyp = beam_hyps[batch];
if (done_[batch]) {
ORT_ENFORCE(beam_hyp.Size() >= num_beams_, "Batch can only be done if all beams have been generated");
// Pad the batch.
for (int j = 0; j < num_beams_; j++) {
next_beam_scores_[batch * num_beams_ + j] = 0.0f;
next_beam_tokens_[batch * num_beams_ + j] = pad_token_id_;
next_beam_indices_[batch * num_beams_ + j] = 0;
}
continue;
}
// Next tokens for this sentence.
int beam_idx = 0;
int top_k = 2 * num_beams_;
for (int j = 0; j < top_k; j++) {
int64_t next_token = next_tokens[batch * top_k + j];
T next_score = next_scores[batch * top_k + j];
int64_t next_index = next_indices[batch * top_k + j];
int batch_beam_idx = batch * num_beams_ + static_cast<int>(next_index);
// Add to generated hypotheses if end of sentence.
if ((eos_token_id_ >= 0) && (next_token == eos_token_id_)) {
bool is_beam_token_worse_than_top_num_beams = (j >= num_beams_);
if (is_beam_token_worse_than_top_num_beams) {
continue;
}
// Clone the sequence and append to buffer.
gsl::span<const int64_t> src = sequences->GetSequence(batch_beam_idx);
auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length);
gsl::copy(src, clone);
hypothesis_buffer_offset_ += sequence_length;
auto sequence = clone.template as_span<const int64_t>();
beam_hyp.Add(sequence, next_score);
} else {
// Add next predicted token since it is not eos_token.
next_beam_scores_[batch * num_beams_ + beam_idx] = next_score;
next_beam_tokens_[batch * num_beams_ + beam_idx] = next_token;
next_beam_indices_[batch * num_beams_ + beam_idx] = batch_beam_idx;
++beam_idx;
}
// Once the beam for next step is full, don't add more tokens to it.
if (beam_idx == num_beams_)
break;
}
ORT_ENFORCE(beam_idx == num_beams_);
// The used length must stay within the buffer allocated in Initialize().
ORT_ENFORCE(static_cast<size_t>(hypothesis_buffer_offset_) <= hypothesis_buffer_length_);
// Check if we are done so that we can save a pad step if all(done)
if (!done_[batch]) {
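// next_scores has top_k (2 * num_beams) entries per batch, so this batch's row starts at batch * top_k.
gsl::span<const T> topk_scores = next_scores.subspan(batch * top_k, top_k);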
const T* best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end());
if (beam_hyp.IsDone(*best_sum_logprobs, sequence_length)) {
done_[batch] = true;
}
}
}
}
template <typename T>
void BeamSearchScorer<T>::Finalize(ISequences* sequences,
gsl::span<const T>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) {
ORT_ENFORCE(sequences != nullptr);
ORT_ENFORCE(output_sequences != nullptr);
// Finalize all open beam hypotheses and add to generated hypotheses.
for (int batch_index = 0; batch_index < batch_size_; batch_index++) {
BeamHypotheses<T>& beam_hyp = beam_hyps[batch_index];
if (done_[batch_index]) {
continue;
}
for (int beam_index = 0; beam_index < num_beams_; beam_index++) {
int batch_beam_index = batch_index * num_beams_ + beam_index;
T final_score = final_beam_scores[batch_beam_index];
auto final_tokens = sequences->GetSequence(batch_beam_index);
beam_hyp.Add(final_tokens, final_score);
}
}
// Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length).
gsl::span<int32_t> output = output_sequences->MutableDataAsSpan<int32_t>();
// Fill the output sequences with the pad token ID so that we do not need to append it later.
std::fill_n(output.data(), output.size(), pad_token_id_);
// Score of each sequence, with shape (batch_size * num_return_sequences).
gsl::span<T> sequence_scores;
if (output_sequence_scores != nullptr) {
sequence_scores = output_sequence_scores->MutableDataAsSpan<T>();
}
// Span is empty when output_sequence_scores is NULL.
gsl::span<T> batch_sequence_score;
// Select the best hypotheses according to number of sequences to return.
for (int batch_index = 0; batch_index < batch_size_; batch_index++) {
BeamHypotheses<T>& beam_hyp = beam_hyps[batch_index];
const int num_return_sequences = num_beam_hyps_to_keep_;
auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_);
if (output_sequence_scores != nullptr) {
batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences);
}
beam_hyp.Output(
num_return_sequences,
max_length_,
batch_output,
batch_sequence_score);
}
}
// Instantiation
template class HypothesisScoreCompare<float>;
template class BeamHypotheses<float>;
template class BeamSearchScorer<float>;
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
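
Reviewer note on `BeamSearchScorer::Initialize`: the hypothesis buffer reserves, per beam, room for one cloned sequence of every length from `sequence_length` up to `max_length`. The sketch below only illustrates that arithmetic by comparing the closed form used above with a plain summation. The function names `BufferPerBeamClosedForm` and `BufferPerBeamBySum` are hypothetical and are not part of this commit.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper: closed form, (max * (max + 1) - (seq - 1) * seq) / 2,
// which equals seq + (seq + 1) + ... + max.
std::size_t BufferPerBeamClosedForm(int sequence_length, int max_length) {
  assert(sequence_length >= 1 && max_length >= sequence_length);
  const long long m = max_length;
  const long long s = sequence_length;
  return static_cast<std::size_t>((m * (m + 1) - (s - 1) * s) / 2);
}

// Hypothetical helper: the same quantity computed by direct summation.
std::size_t BufferPerBeamBySum(int sequence_length, int max_length) {
  std::size_t total = 0;
  for (int len = sequence_length; len <= max_length; ++len) {
    total += static_cast<std::size_t>(len);
  }
  return total;
}
```

For a prompt of length 3 and `max_length` 5, both forms give 3 + 4 + 5 = 12 elements per beam. The total buffer is that value multiplied by `batch_size * num_beams`.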

View file

@ -0,0 +1,144 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// The implementation is based on huggingface transformers generation_beam_search.py
#pragma once
#include <queue>
#include <math.h>
#include "core/common/common.h"
#include "core/framework/allocator.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "sequences.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// Interface for all scorers for beam search or beam sample.
template <typename T>
class IBeamScorer {
public:
virtual ~IBeamScorer() {}
virtual void Initialize(AllocatorPtr& allocator, int sequence_length) = 0;
virtual void Process(ISequences* sequences,
gsl::span<const T>& next_scores,
gsl::span<const int64_t>& next_tokens,
gsl::span<const int64_t>& next_indices) = 0;
virtual void Finalize(ISequences* sequences,
gsl::span<const T>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) = 0;
};
template <typename T>
struct HypothesisScore {
HypothesisScore(gsl::span<const int64_t>& _hypothesis, T _score)
: hypothesis(_hypothesis), score(_score) {}
gsl::span<const int64_t> hypothesis;
T score;
};
template <typename T>
class HypothesisScoreCompare {
public:
bool operator()(const HypothesisScore<T>& a, const HypothesisScore<T>& b) const {
return a.score > b.score;
}
};
template <typename T>
class BeamHypotheses {
public:
BeamHypotheses(int num_beams, T length_penalty, bool early_stopping);
// Number of hypotheses
int Size() { return static_cast<int>(beams_.size()); }
// Add a new hypothesis
void Add(gsl::span<const int64_t>& hypothesis, T sum_logprobs);
bool IsDone(T best_sum_logprobs, int current_length);
// Output results. Note that it will clear all beams.
void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length)
gsl::span<T>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
private:
int num_beams_;
T length_penalty_;
bool early_stopping_;
T worst_score_;
std::priority_queue<HypothesisScore<T>, std::vector<HypothesisScore<T>>, HypothesisScoreCompare<T>> beams_; // min-heap for top k
};
template <typename T>
class BeamSearchScorer : public IBeamScorer<T> {
public:
BeamSearchScorer(int batch_size,
int num_beams,
int max_length,
T length_penalty,
bool early_stopping,
int num_return_sequences,
int pad_token_id,
int eos_token_id);
void Initialize(AllocatorPtr& allocator, int sequence_length) override;
void Process(ISequences* sequences,
gsl::span<const T>& next_scores,
gsl::span<const int64_t>& next_tokens,
gsl::span<const int64_t>& next_indices) override;
void Finalize(ISequences* sequences,
gsl::span<const T>& final_beam_scores,
Tensor* output_sequences,
Tensor* output_sequence_scores) override;
bool IsDone();
gsl::span<T>& GetNextScores() { return next_beam_scores_; }
gsl::span<int64_t>& GetNextTokens() { return next_beam_tokens_; }
gsl::span<int64_t>& GetNextIndices() { return next_beam_indices_; }
private:
int batch_size_;
int num_beams_;
int max_length_;
int num_beam_hyps_to_keep_;
int pad_token_id_;
int eos_token_id_;
// TODO: use ORT allocator to avoid allocating from heap directly
std::vector<BeamHypotheses<T>> beam_hyps; // Beam search results, one entry per batch. Its shape is (batch_size).
IAllocatorUniquePtr<bool> done_ptr_; // Flags indicating whether each batch is finished. Its shape is (batch_size).
gsl::span<bool> done_;
IAllocatorUniquePtr<T> next_beam_scores_ptr_;
gsl::span<T> next_beam_scores_;
IAllocatorUniquePtr<int64_t> next_beam_tokens_ptr_;
gsl::span<int64_t> next_beam_tokens_;
IAllocatorUniquePtr<int64_t> next_beam_indices_ptr_;
gsl::span<int64_t> next_beam_indices_;
IAllocatorUniquePtr<int64_t> hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses
gsl::span<int64_t> hypothesis_buffer_; // Span of the allocated buffer
size_t hypothesis_buffer_length_; // Total number of elements
int hypothesis_buffer_offset_; // Offset of the available buffer, or length of the used buffer.
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
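
Reviewer note on `BeamHypotheses`: the class keeps the best `num_beams` hypotheses in a min-heap, so the worst retained score is always at the top and can be evicted in one `pop()`. The following is a minimal standalone sketch of that idea, assuming a float score and dropping the token span. `ScoredHypothesis`, `WorseOnTop`, and `TopKHypotheses` are hypothetical names used only for illustration, not types from this commit.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <queue>
#include <vector>

// Hypothetical stand-in for HypothesisScore: keeps only the length and the score.
struct ScoredHypothesis {
  std::size_t length;
  float score;
};

// Comparator that makes std::priority_queue a min-heap on score.
struct WorseOnTop {
  bool operator()(const ScoredHypothesis& a, const ScoredHypothesis& b) const {
    return a.score > b.score;
  }
};

class TopKHypotheses {
 public:
  TopKHypotheses(int k, float length_penalty) : k_(k), length_penalty_(length_penalty) {
    assert(k > 0);
  }

  // Length-normalized score: sum_logprobs / length^length_penalty.
  float Score(float sum_logprobs, std::size_t length) const {
    return sum_logprobs / std::pow(static_cast<float>(length), length_penalty_);
  }

  // Keep the hypothesis if the heap is not full or it beats the current worst.
  void Add(std::size_t length, float sum_logprobs) {
    const float score = Score(sum_logprobs, length);
    if (Size() < k_ || score > heap_.top().score) {
      heap_.push({length, score});
      if (Size() > k_) heap_.pop();  // evict the worst hypothesis
    }
  }

  int Size() const { return static_cast<int>(heap_.size()); }

  // Only valid when Size() > 0.
  float WorstScore() const { return heap_.top().score; }

 private:
  int k_;
  float length_penalty_;
  std::priority_queue<ScoredHypothesis, std::vector<ScoredHypothesis>, WorseOnTop> heap_;
};
```

With `k = 2` and `length_penalty = 1`, adding hypotheses with normalized scores -1.0 and -0.5 fills the heap. A later hypothesis scoring -3.0 is rejected, while one scoring -0.25 evicts the -1.0 entry.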

View file

@ -0,0 +1,73 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "dump_tensor.h"
#include "core/platform/env.h"
#include "core/platform/env_var_utils.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
namespace dump_tensor_env_vars {
constexpr const char* kDumpBeamSearch = "ORT_DUMP_BEAM_SEARCH";
}
#ifdef NDEBUG
bool g_enable_tensor_dump = false;
#else
bool g_enable_tensor_dump = true;
#endif
void DumpOrtValue(const char* name, const OrtValue& value) {
if (!g_enable_tensor_dump)
return;
std::cout << std::string(name) << "\n";
const Tensor& tensor = value.Get<Tensor>();
MLDataType dataType = tensor.DataType();
if (dataType == DataTypeImpl::GetType<float>()) {
DumpTensor<float>(nullptr, tensor);
} else if (dataType == DataTypeImpl::GetType<int32_t>()) {
DumpTensor<int32_t>(nullptr, tensor);
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
DumpTensor<int64_t>(nullptr, tensor);
} else {
std::cout << "not float/int32/int64";
}
}
void ConfigureTensorDump() {
const auto parsed = ParseEnvironmentVariable<bool>(dump_tensor_env_vars::kDumpBeamSearch);
if (parsed.has_value()) {
g_enable_tensor_dump = *parsed;
}
}
void DisableTensorDump() {
g_enable_tensor_dump = false;
}
void DumpString(const char* name, int index, bool end_line) {
if (!g_enable_tensor_dump)
return;
std::cout << std::string(name) << "[" << index << "]";
if (end_line) {
std::cout << std::endl;
}
}
void DumpString(const char* name, std::string value, bool end_line) {
if (!g_enable_tensor_dump)
return;
std::cout << std::string(name) << "=" << value;
if (end_line) {
std::cout << std::endl;
}
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,148 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <iomanip>
#include <iostream>  // std::cout, std::endl
#include <string>
#include "core/framework/tensorprotoutils.h"
#ifndef NDEBUG
//#define DEBUG_BEAM_SEARCH 1 // uncomment it for debugging beam search
#endif
namespace onnxruntime {
namespace contrib {
namespace transformers {
#define MAX_ROW_OR_COLUMN 8
#define SKIP_IF_MORE_THAN(row_or_column_size, i, max_n, new_line) \
if (row_or_column_size > max_n && i >= max_n / 2 && i + max_n / 2 < row_or_column_size) { \
if (i == max_n / 2) { \
std::cout << ", ..."; \
if (new_line) \
std::cout << std::endl; \
} \
continue; \
}
#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line)
extern bool g_enable_tensor_dump; // global variable to turn dumping on or off
template <typename T>
void PrintValue(const T& value) {
if (std::is_floating_point<T>::value)
std::cout << std::setprecision(8) << value;
else
std::cout << value;
}
template <typename T>
void DumpTensor(const char* name, const Tensor& tensor) {
if (!g_enable_tensor_dump)
return;
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
}
const auto& shape = tensor.Shape();
auto num_items = shape.Size();
if (num_items == 0) {
std::cout << "no data";
return;
}
size_t num_dims = shape.NumDimensions();
size_t num_rows = 1;
if (num_dims > 1) {
num_rows = static_cast<size_t>(shape[0]);
}
size_t row_size = num_items / num_rows;
auto data = tensor.DataAsSpan<T>();
for (size_t row = 0; row < num_rows; ++row) {
SKIP_IF_TOO_MANY(num_rows, row, true);
std::cout << "[" << row << "]:";
for (size_t i = 0; i < row_size; ++i) {
SKIP_IF_TOO_MANY(row_size, i, false);
if (i > 0)
std::cout << ", ";
PrintValue(data[row * row_size + i]);
}
std::cout << "\n";
}
std::cout << std::endl;
}
void DumpOrtValue(const char* name, const OrtValue& value);
template <typename T>
void DumpTensor(const char* name, const T* tensor, int dim0, int dim1) {
if (!g_enable_tensor_dump)
return;
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
}
for (int i = 0; i < dim0; i++) {
SKIP_IF_TOO_MANY(dim0, i, true);
std::cout << "[" << i << "]:";
for (int j = 0; j < dim1; j++) {
SKIP_IF_TOO_MANY(dim1, j, false);
if (j > 0)
std::cout << ", ";
T value = tensor[i * dim1 + j];
PrintValue<T>(value);
}
std::cout << std::endl;
}
}
void DumpString(const char* name, int index, bool end_line);
void DumpString(const char* name, std::string value, bool end_line);
template <typename T>
void DumpTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) {
if (!g_enable_tensor_dump)
return;
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
}
for (int i = 0; i < dim0; i++) {
SKIP_IF_TOO_MANY(dim0, i, true);
for (int j = 0; j < dim1; j++) {
SKIP_IF_TOO_MANY(dim1, j, true);
std::cout << "[" << i << "][" << j << "]:";
for (int k = 0; k < dim2; k++) {
SKIP_IF_TOO_MANY(dim2, k, false);
if (k > 0)
std::cout << ", ";
T value = tensor[i * dim1 * dim2 + j * dim2 + k];
PrintValue<T>(value);
}
std::cout << std::endl;
}
std::cout << std::endl;
}
std::cout << std::endl;
}
void ConfigureTensorDump();
void DisableTensorDump();
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
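
Reviewer note on `SKIP_IF_MORE_THAN`: the macro shortens long rows and columns by printing only the first `max_n / 2` and the last `max_n / 2` items, with a single `...` marker in between. The sketch below restates the skip condition as a plain function so the behavior is easier to check. `KeptIndices` is a hypothetical helper and is not part of this commit.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: returns the indices that would still be printed for a
// row or column of `size` items, using the same skip condition as the macro.
std::vector<std::size_t> KeptIndices(std::size_t size, std::size_t max_n) {
  std::vector<std::size_t> kept;
  for (std::size_t i = 0; i < size; ++i) {
    const bool skipped = size > max_n && i >= max_n / 2 && i + max_n / 2 < size;
    if (!skipped) {
      kept.push_back(i);
    }
  }
  return kept;
}
```

With `MAX_ROW_OR_COLUMN` set to 8, a row of 20 items prints indices 0 to 3 and 16 to 19. A row of 8 or fewer items is printed in full.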

View file

@ -0,0 +1,451 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// there's no way to use a raw pointer as the copy destination with std::copy_n
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
#include "core/framework/framework_common.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "gsl/gsl"
#include "gpt_subgraph.h"
#include "dump_tensor.h"
#ifdef _MSC_VER
#pragma warning(pop)
#endif
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace transformers {
GptSubgraph::GptSubgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in)
: node(node_in), attribute(attribute_name), subgraph(subgraph_in), allocator_(nullptr) {
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
auto& subgraph_inputs = subgraph.GetInputs();
auto& subgraph_outputs = subgraph.GetOutputs();
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
// outputs: logits, present_0, present_1, ...
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
// CheckSubgraph will verify inputs and outputs later.
subgraph_input_names.reserve(num_subgraph_inputs);
for (int i = 0; i < num_subgraph_inputs; ++i) {
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
}
subgraph_output_names.reserve(num_subgraph_outputs);
for (int i = 0; i < num_subgraph_outputs; ++i) {
subgraph_output_names.push_back(subgraph_outputs[i]->Name());
}
}
Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_outputs <= 1,
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs).");
ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2,
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2");
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ",
subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ",
subgraph_inputs[1]->Name());
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ",
subgraph_inputs[2]->Name());
ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ",
subgraph_inputs[3]->Name());
// Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads.
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape();
ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimensions, got ",
past_shape->dim_size());
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
"subgraph past state dimension 0 shall have length of 2");
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for number of heads");
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
// check subgraph outputs
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ",
subgraph_outputs[0]->Name());
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph output 1 shall be named as present_0, got: ",
subgraph_outputs[1]->Name());
// Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size.
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimensions, got ",
logits_shape->dim_size());
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
"subgraph logits output dimension 2 shall have a positive value for vocabulary size");
// Save parameters related to the subgraph.
num_heads = static_cast<int>(past_shape->dim(2).dim_value());
head_size = static_cast<int>(past_shape->dim(4).dim_value());
vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
num_layers = static_cast<int>(subgraph_outputs.size()) - 1;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
"subgraph input 0 (input_ids) shall have int64 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
"subgraph input 1 (position_ids) shall have int64 type");
// TODO: support float16
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
"subgraph input 2 (attention_mask) shall have float type");
ORT_RETURN_IF(subgraph_inputs[3]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
"subgraph input 3 (past_0) shall have float type");
ORT_RETURN_IF(subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
"subgraph output 0 (logits) shall have float type");
ORT_RETURN_IF(subgraph_outputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
"subgraph output 1 (present_0) shall have float type");
return Status::OK();
}
Status GptSubgraph::Setup(const SessionState& session_state,
const SessionState& subgraph_session_state) {
session_state_ = &session_state;
subgraph_session_state_ = &subgraph_session_state;
std::vector<std::string> feed_names;
feed_names.reserve(num_subgraph_inputs + num_implicit_inputs);
// First, get the location of input_ids of current operator.
const auto& node_inputs = node.InputDefs();
const OrtMemoryInfo& input_ids_location = utils::FindMemoryInfoForValue(session_state, node_inputs[0]->Name());
// position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
// as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
for (auto& entry : node.ImplicitInputDefs()) {
feed_names.push_back(entry->Name());
}
std::vector<OrtDevice> feed_locations;
feed_locations.resize(feed_names.size());
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
if (i >= subgraph_input_names.size()) { // implicit inputs
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
feed_locations[i] = location.device;
} else {
feed_locations[i] = input_ids_location.device;
}
}
std::unique_ptr<FeedsFetchesManager> ffm;
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
// Set up the locations where we want the subgraph outputs to end up.
std::vector<const OrtMemoryInfo*> fetch_locations;
fetch_locations.reserve(num_subgraph_outputs);
// Past state needs to be where we can feed it into the next iteration, so set the fetch location to match the feed location.
for (int i = 0; i < num_subgraph_outputs; ++i) {
fetch_locations.push_back(&input_ids_location);
}
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
feeds_fetches_manager_ = std::move(ffm);
// The subgraph only needs to be validated once, so do it in Setup.
auto& inputs = subgraph.GetInputs();
auto& outputs = subgraph.GetOutputs();
ORT_RETURN_IF_ERROR(Validate(inputs, outputs));
return Status::OK();
}
void GptSubgraph::CreateInitialFeeds(
const Tensor& input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
gsl::span<int64_t>& next_positions,
std::vector<OrtValue>& feeds) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
// Subgraph inputs:
// input_ids: shape (B, S) where B is batch size, and S is sequence length
// position_ids: shape (B, S)
// attention_mask: shape (B, P+S), where past_sequence_length (P) is 0
// After expansion, their shapes will become (B*M, S), where M is num_beams.
// Allocate subgraph inputs on the same device as input_ids.
// Store the allocator in allocator_, since it is also needed in ExpandInputs.
allocator_ = session_state_->GetAllocator(input_ids.Location());
const TensorShape& input_ids_shape = input_ids.Shape();
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
const int64_t& batch_size = input_ids_shape[0];
const int64_t& sequence_length = input_ids_shape[1];
// Allocate position_ids and attention_mask based on shape of input_ids
auto element_type = DataTypeImpl::GetType<int64_t>();
// input_ids for the subgraph is int64, so we need to cast input_ids from int32 to int64.
OrtValue subgraph_input_ids;
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, subgraph_input_ids);
int64_t* subgraph_input_data = subgraph_input_ids.GetMutable<Tensor>()->MutableData<int64_t>();
const int32_t* source = input_ids.Data<int32_t>();
int64_t* target = subgraph_input_data;
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < sequence_length; j++, source++, target++) {
*target = static_cast<int64_t>(*source);
}
}
OrtValue position_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids);
OrtValue attention_mask;
auto mask_type = DataTypeImpl::GetType<float>();
Tensor::InitOrtValue(mask_type, input_ids_shape, allocator_, attention_mask);
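// Set attention mask to 0 for pad tokens, and 1 for all other tokens.
// Set position id to 0 for pad tokens, and to the count of preceding non-pad tokens in the same batch entry for other tokens.
// For example, with pad_token_id = 0, input_ids [0, 0, 5, 7] yields mask [0, 0, 1, 1] and position ids [0, 0, 0, 1].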
float* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<float>();
int64_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int64_t>();
source = input_ids.Data<int32_t>();
float* mask = mask_data;
int64_t* position = position_data;
for (int i = 0; i < batch_size; i++) {
int64_t abs_position = 0;
for (int j = 0; j < sequence_length; j++, source++, mask++, position++) {
if (*source == pad_token_id) {
*mask = 0.0f;
*position = 0;
} else {
*mask = 1.0f;
*position = abs_position;
abs_position++;
}
}
for (int k = 0; k < num_beams; k++) {
next_positions[i * num_beams + k] = abs_position;
}
}
// Initialize empty past state
auto past_type = DataTypeImpl::GetType<float>();
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size};
TensorShape past_shape(&past_state_dims[0], 5);
OrtValue empty_past;
Tensor::InitOrtValue(past_type, past_shape, allocator_, empty_past);
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask
// TODO: Try expanding inputs/outputs after the first subgraph call instead. That may give better performance, but is more complex to implement.
OrtValue expanded_input_ids = ExpandInputs(subgraph_input_ids, num_beams);
OrtValue expanded_position_ids = ExpandInputs(position_ids, num_beams);
OrtValue expanded_attention_mask = ExpandInputs(attention_mask, num_beams);
// The ordering is the same as used in Setup
feeds.reserve(num_subgraph_inputs + num_implicit_inputs);
feeds.push_back(expanded_input_ids);
feeds.push_back(expanded_position_ids);
feeds.push_back(expanded_attention_mask);
// The remaining inputs are past state.
for (int i = 3; i < num_subgraph_inputs; ++i) {
feeds.push_back(empty_past);
}
// pass in implicit inputs
for (const auto* entry : implicit_inputs) {
feeds.push_back(*entry);
}
}
OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const {
// Input shape (batch_size, sequence_length)
// Output shape (batch_size * num_beams, sequence_length)
if (num_beams == 1)
return input;
const TensorShape& input_shape = input.Get<Tensor>().Shape();
const int64_t& batch_size = input_shape[0];
const int64_t& sequence_length = input_shape[1];
int64_t dims[] = {batch_size * num_beams, sequence_length};
TensorShape expanded_shape(&dims[0], 2);
OrtValue expanded;
MLDataType element_type = input.Get<Tensor>().DataType();
Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded);
if (element_type == DataTypeImpl::GetType<int64_t>()) {
const int64_t* input_data = input.Get<Tensor>().Data<int64_t>();
int64_t* expanded_data = expanded.GetMutable<Tensor>()->MutableData<int64_t>();
int64_t* target = expanded_data;
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
memcpy(target, input_data + i * sequence_length, sizeof(int64_t) * sequence_length);
target += sequence_length;
}
}
} else if (element_type == DataTypeImpl::GetType<float>()) {
const float* input_data = input.Get<Tensor>().Data<float>();
float* expanded_data = expanded.GetMutable<Tensor>()->MutableData<float>();
float* target = expanded_data;
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
memcpy(target, input_data + i * sequence_length, sizeof(float) * sequence_length);
target += sequence_length;
}
}
}
return expanded;
}
// TODO: support float16
void GptSubgraph::PickPastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
gsl::span<const int64_t>& beam_indices) {
for (int i = 3; i < num_subgraph_inputs; ++i) {
const OrtValue& present = last_outputs[i - 2]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
const TensorShape& past_shape = present.Get<Tensor>().Shape();
// Create a tensor with same shape.
OrtValue past;
auto past_type = DataTypeImpl::GetType<float>();
Tensor::InitOrtValue(past_type, past_shape, allocator_, past);
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
gsl::span<float> past_span = past.GetMutable<Tensor>()->MutableDataAsSpan<float>();
gsl::span<const float> present_span = present.Get<Tensor>().DataAsSpan<float>();
for (gsl::index j = 0; j < beam_indices.length(); j++) {
int64_t beam_index = beam_indices[j];
gsl::span<const float> present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<const float> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<float> past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
gsl::span<float> past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
gsl::copy(present_key, past_key);
gsl::copy(present_value, past_value);
#ifdef DEBUG_BEAM_SEARCH
if (i == 3) // only dump past_0
{
DumpString("past_key of beam", static_cast<int>(j), true);
DumpTensor<float>(nullptr, past_key.data(), 1, static_cast<int>(block_size_per_beam));
DumpString("past_value of beam", static_cast<int>(j), true);
DumpTensor<float>(nullptr, past_value.data(), 1, static_cast<int>(block_size_per_beam));
}
#endif
}
next_inputs[i] = past;
}
}
Status GptSubgraph::UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
gsl::span<int64_t>& next_positions,
gsl::span<const int64_t> beam_next_tokens,
gsl::span<const int64_t> beam_indices,
int num_beams) {
// last_outputs: logits, present_0, present_1, ...
// next_inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
// The following updates inputs for subgraph
// TODO: Reuse buffer for input_ids and position_ids to reduce memory allocation.
// Update input_ids with next tokens.
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
int64_t dims[] = {batch_beam_size, 1};
TensorShape input_ids_shape(&dims[0], 2);
auto element_type = DataTypeImpl::GetType<int64_t>();
OrtValue input_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, input_ids);
int64_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int64_t>();
for (int i = 0; i < batch_beam_size; i++) {
input_ids_data[i] = beam_next_tokens[i];
}
next_inputs[0] = input_ids;
// Update position IDs
OrtValue position_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids);
int64_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int64_t>();
for (int i = 0; i < batch_beam_size; i++) {
position_data[i] = next_positions[i];
next_positions[i]++;
}
next_inputs[1] = position_ids;
// Update attention mask
const OrtValue& old_mask = next_inputs[2];
const float* old_mask_data = old_mask.Get<Tensor>().Data<float>();
int64_t mask_dims[] = {batch_beam_size, current_length};
TensorShape mask_shape(&mask_dims[0], 2);
OrtValue attention_mask;
auto mask_type = DataTypeImpl::GetType<float>();
Tensor::InitOrtValue(mask_type, mask_shape, allocator_, attention_mask);
float* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<float>();
for (int i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < current_length - 1; j++) {
mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j];
}
mask_data[i * current_length + current_length - 1] = 1.0f;
}
next_inputs[2] = attention_mask;
#ifdef DEBUG_BEAM_SEARCH
DumpOrtValue("input_ids", input_ids);
DumpOrtValue("position_ids", position_ids);
DumpOrtValue("attention_mask", attention_mask);
#endif
// Update past state
if (num_beams == 1) {
// feed present_* output to past_* inputs one by one
for (int i = 3; i < num_subgraph_inputs; ++i) {
next_inputs[i] = last_outputs[i - 2];
}
} else {
PickPastState(last_outputs, next_inputs, beam_indices);
}
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,82 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
//#include <functional>
#include "gsl/gsl"
#include "core/framework/allocator.h"
#include "core/framework/session_state.h"
#include "core/framework/feeds_fetches_manager.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// A class for GPT-2 subgraph inputs and outputs preparation.
struct GptSubgraph {
GptSubgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in);
const onnxruntime::Node& node; // node that contains the subgraph
const std::string& attribute; // attribute of the node that contains the subgraph. Not used yet.
const GraphViewer& subgraph; // the subgraph
int num_implicit_inputs;
int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience.
int num_subgraph_outputs; // same as subgraph_output_names.size()
std::vector<std::string> subgraph_input_names;
std::vector<std::string> subgraph_output_names;
// Parameters deduced from the subgraph
int num_heads;
int head_size;
int vocab_size;
int num_layers;
// Set up execution
Status Setup(const SessionState& session_state,
const SessionState& subgraph_session_state);
// Create inputs for first inference of subgraph.
void CreateInitialFeeds(
const Tensor& input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
gsl::span<int64_t>& next_positions,
std::vector<OrtValue>& feeds);
Status UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
gsl::span<int64_t>& next_positions,
gsl::span<const int64_t> beam_next_tokens,
gsl::span<const int64_t> beam_indices,
int num_beams);
FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); }
protected:
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs);
OrtValue ExpandInputs(const OrtValue& input, int num_beams) const;
void PickPastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
gsl::span<const int64_t>& beam_indices);
AllocatorPtr allocator_;
const SessionState* session_state_;
const SessionState* subgraph_session_state_;
std::unique_ptr<FeedsFetchesManager> feeds_fetches_manager_;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,192 @@
#include <assert.h>
#include <limits>         // std::numeric_limits
#include <unordered_set>  // std::unordered_set
#include "logits_processor.h"
#include "dump_tensor.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
template <typename T>
gsl::span<T> NextTokenScores<T>::GetScores(int batch_beam_index) {
assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size);
return scores.subspan(batch_beam_index * vocab_size, vocab_size);
}
template <typename T>
void NextTokenScores<T>::SetScore(int token_id, T score) {
assert(token_id >= 0 && token_id < vocab_size);
for (int i = 0; i < batch_beam_size; i++) {
scores[i * vocab_size + token_id] = score;
}
}
#ifdef DEBUG_BEAM_SEARCH
template <typename T>
void DumpScores(const char* name, gsl::span<T>& scores) {
DumpString(name, 0, true);
ORT_UNUSED_PARAMETER(scores);
}
#endif
// Interface for all scorers for beam search or beam sample.
template <typename T>
MinLengthLogitsProcessor<T>::MinLengthLogitsProcessor(int min_length, int eos_token_id)
: min_length_(min_length), eos_token_id_(eos_token_id) {}
template <typename T>
void MinLengthLogitsProcessor<T>::Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) {
if (sequences->GetSequenceLength() < min_length_) {
next_token_scores.SetScore(eos_token_id_, std::numeric_limits<T>::lowest());
}
#ifdef DEBUG_BEAM_SEARCH
DumpScores("MinLengthLogitsProcessor", next_token_scores.scores);
#endif
}
template <typename T>
RepetitionPenaltyLogitsProcessor<T>::RepetitionPenaltyLogitsProcessor(float penalty) : penalty_(penalty) {
}
template <typename T>
void RepetitionPenaltyLogitsProcessor<T>::Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) {
const int batch_beam_size = next_token_scores.batch_beam_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
gsl::span<const int64_t> sequence = sequences->GetSequence(i);
// Find unique word IDs in sequence.
std::unordered_set<int64_t> unique_word_ids;
for (const auto& word_id : sequence) {
unique_word_ids.insert(word_id);
}
for (const int64_t word_id : unique_word_ids) {
T score = beam_token_scores[word_id];
// If score < 0, then a repetition penalty > 1.0 has to be multiplied to reduce the previous token probability.
// This assumes that scores are either positive (like ctrl) or negative (like GPT-2), but not a mixture.
beam_token_scores[word_id] = (score < 0 ? score * penalty_ : score / penalty_);
}
}
#ifdef DEBUG_BEAM_SEARCH
DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores.scores);
#endif
}
template <typename T>
NoRepeatNGramLogitsProcessor<T>::NoRepeatNGramLogitsProcessor(int ngram_size) : ngram_size_(ngram_size) {
}
template <typename T>
void NoRepeatNGramLogitsProcessor<T>::Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) {
if (ngram_size_ == 0 || ngram_size_ > sequences->GetSequenceLength()) {
return;
}
const gsl::index prefix_length = static_cast<gsl::index>(ngram_size_ - 1);
int batch_beam_size = next_token_scores.batch_beam_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
gsl::span<const int64_t> sequence = sequences->GetSequence(i);
gsl::span<const int64_t> prefix = sequence.subspan(sequence.length() - prefix_length);
ORT_ENFORCE(prefix.length() == prefix_length);
std::unordered_set<int64_t> blocked_word_ids;
for (int j = 0; j <= static_cast<int>(sequence.length()) - ngram_size_; j++) {
// Here we use a naive matching algorithm. The complexity is O(batch_beam_size * ngram_size * sequence_length).
// TODO: build an N-gram index (hash table with the prefix of length ngram_size - 1 as key, and the list of last words of matching N-grams as value) for fast matching.
if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) {
blocked_word_ids.insert(sequence[j + prefix_length]);
}
}
for (const int64_t word_id : blocked_word_ids) {
beam_token_scores[word_id] = std::numeric_limits<T>::lowest();
}
}
#ifdef DEBUG_BEAM_SEARCH
DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores.scores);
#endif
}
template <typename T>
VocabMaskLogitsProcessor<T>::VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask) : vocab_mask_(vocab_mask) {
}
template <typename T>
void VocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
NextTokenScores<T>& next_token_scores) {
assert(!vocab_mask_.empty());
// Process vocabulary mask and set tokens with mask value 0 to the lowest finite score (effectively -inf).
T* p = next_token_scores.scores.data();
// next_token_scores shape (batch_size * num_beams, vocab_size)
// vocab_mask shape (vocab_size). TODO: support shape (batch_size, vocab_size)
for (int i = 0; i < next_token_scores.batch_beam_size; i++) {
for (int j = 0; j < next_token_scores.vocab_size; j++, p++) {
if (vocab_mask_[j] == 0) {
*p = std::numeric_limits<T>::lowest();
}
}
}
#ifdef DEBUG_BEAM_SEARCH
DumpScores("VocabMaskLogitsProcessor", next_token_scores.scores);
#endif
}
template <typename T>
void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
processor_list_.clear();
if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty
repetition_penalty_processor_ = std::make_unique<RepetitionPenaltyLogitsProcessor<T>>(parameters.repetition_penalty);
processor_list_.push_back(repetition_penalty_processor_.get());
}
if (parameters.no_repeat_ngram_size > 0) {
no_repeat_ngram_processor_ = std::make_unique<NoRepeatNGramLogitsProcessor<T>>(parameters.no_repeat_ngram_size);
processor_list_.push_back(no_repeat_ngram_processor_.get());
}
if (!parameters.vocab_mask.empty()) {
vocab_mask_processor_ = std::make_unique<VocabMaskLogitsProcessor<T>>(parameters.vocab_mask);
processor_list_.push_back(vocab_mask_processor_.get());
}
if (parameters.min_length > 0) {
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<T>>(parameters.min_length, parameters.eos_token_id);
processor_list_.push_back(min_length_processor_.get());
}
batch_beam_size_ = parameters.BatchBeamSize();
vocab_size_ = parameters.vocab_size;
}
template <typename T>
void LogitsProcessorList<T>::Process(const ISequences* sequences,
gsl::span<T>& next_token_scores) {
NextTokenScores<T> input_scores = {next_token_scores, batch_beam_size_, vocab_size_};
for (size_t i = 0; i < processor_list_.size(); i++) {
processor_list_[i]->Process(sequences, input_scores);
}
}
// Instantiation
template class MinLengthLogitsProcessor<float>;
template class RepetitionPenaltyLogitsProcessor<float>;
template class NoRepeatNGramLogitsProcessor<float>;
template class VocabMaskLogitsProcessor<float>;
template class LogitsProcessorList<float>;
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,99 @@
#pragma once
#include "sequences.h"
#include "beam_search_parameters.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
template <typename T>
struct NextTokenScores {
gsl::span<T>& scores;
int batch_beam_size;
int vocab_size;
gsl::span<T> GetScores(int batch_beam_index);
void SetScore(int token_id, T score);
};
// Interface for all scorers for beam search or beam sample.
template <typename T>
class ILogitsProcessor {
public:
virtual ~ILogitsProcessor() {}
virtual void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) = 0;
};
template <typename T>
class MinLengthLogitsProcessor : public ILogitsProcessor<T> {
public:
MinLengthLogitsProcessor(int min_length, int eos_token_id);
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;
private:
int min_length_;
int eos_token_id_;
};
template <typename T>
class RepetitionPenaltyLogitsProcessor : public ILogitsProcessor<T> {
public:
RepetitionPenaltyLogitsProcessor(float penalty);
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;
private:
float penalty_;
};
template <typename T>
class NoRepeatNGramLogitsProcessor : public ILogitsProcessor<T> {
public:
NoRepeatNGramLogitsProcessor(int ngram_size);
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;
private:
int ngram_size_;
};
template <typename T>
class VocabMaskLogitsProcessor : public ILogitsProcessor<T> {
public:
VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask);
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;
private:
gsl::span<const int32_t> vocab_mask_;
};
template <typename T>
class LogitsProcessorList {
public:
LogitsProcessorList() = default;
void Init(const BeamSearchParameters& parameters);
void Process(const ISequences* sequences, gsl::span<T>& next_token_scores);
private:
int batch_beam_size_;
int vocab_size_;
std::vector<ILogitsProcessor<T>*> processor_list_;
std::unique_ptr<RepetitionPenaltyLogitsProcessor<T>> repetition_penalty_processor_;
std::unique_ptr<NoRepeatNGramLogitsProcessor<T>> no_repeat_ngram_processor_;
std::unique_ptr<VocabMaskLogitsProcessor<T>> vocab_mask_processor_;
std::unique_ptr<MinLengthLogitsProcessor<T>> min_length_processor_;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,76 @@
#include "sequences.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
void Sequences::Init(AllocatorPtr allocator, const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) {
size_t sequences_size = SafeInt<size_t>(batch_beam_size) * max_length;
size_t buffer_size = sequences_size + sequences_size;
gsl::span<int64_t> buffer = AllocateBuffer<int64_t>(allocator, sequences_space_buffer_, buffer_size, true, static_cast<int64_t>(0));
sequences[0] = buffer.subspan(0, sequences_size);
sequences[1] = buffer.subspan(sequences_size);
// Copy input_ids to sequences[0].
gsl::span<const int64_t> input = input_ids.Get<Tensor>().DataAsSpan<int64_t>();
gsl::span<int64_t> output = sequences[0];
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const int64_t> source = input.subspan(i * sequence_length, sequence_length);
gsl::span<int64_t> target = output.subspan(i * max_length, sequence_length);
gsl::copy(source, target);
}
current_sequences_buffer = 0;
batch_beam_size_ = batch_beam_size;
max_length_ = max_length;
current_length_ = sequence_length;
}
gsl::span<const int64_t> Sequences::GetSequence(int beam_index) const {
gsl::span<const int64_t> buffer(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
gsl::span<const int64_t> sequence = buffer.subspan(beam_index * max_length_, current_length_);
return sequence;
}
int Sequences::GetSequenceLength() const {
return current_length_;
}
void Sequences::PrintSequences() {
#ifdef DEBUG_BEAM_SEARCH
for (int i = 0; i < batch_beam_size_; i++) {
gsl::span<const int64_t> sequence = GetSequence(i);
DumpString("sequences", i, false);
DumpTensor<int64_t>(nullptr, sequence.data(), 1, current_length_);
}
#endif
}
void Sequences::AppendNextTokenToSequences(
gsl::span<int64_t>& beam_indices,
gsl::span<int64_t>& beam_next_tokens) {
gsl::span<const int64_t> input(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
gsl::span<int64_t> output = sequences[1 - current_sequences_buffer];
for (int i = 0; i < batch_beam_size_; i++) {
int beam_index = static_cast<int>(beam_indices[i]);
gsl::span<const int64_t> source = input.subspan(beam_index * max_length_, current_length_);
gsl::span<int64_t> target = output.subspan(i * max_length_, current_length_);
gsl::copy(source, target);
}
// Append next token to each beam.
for (int i = 0; i < batch_beam_size_; i++) {
output[i * max_length_ + current_length_] = beam_next_tokens[i];
}
++current_length_;
// Rotate buffer for next round.
current_sequences_buffer = 1 - current_sequences_buffer;
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,80 @@
#pragma once
#include <algorithm>  // std::fill_n
#include "gsl/gsl"
#include "core/common/safeint.h"
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
class ISequences {
public:
virtual ~ISequences() {}
virtual gsl::span<const int64_t> GetSequence(int beam_index) const = 0;
virtual int GetSequenceLength() const = 0;
};
// This class keeps track of sequences generated.
class Sequences : public ISequences {
public:
Sequences() {}
// Initialize the sequence with initial input_ids and related parameters.
void Init(AllocatorPtr allocator, const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length);
// Returns a sequence of word IDs for a given beam index (beam_index < batch_beam_size).
gsl::span<const int64_t> GetSequence(int beam_index) const override;
// Returns current sequence length.
int GetSequenceLength() const override;
// Print the sequences to StdOut in debug mode
void PrintSequences();
// Select sequences based on beam indices, then append next token to selected sequences.
void AppendNextTokenToSequences(
gsl::span<int64_t>& beam_indices,
gsl::span<int64_t>& beam_next_tokens);
private:
gsl::span<int64_t> sequences_space; // shape (2, batch_size, num_beams, max_seq_length)
BufferUniquePtr sequences_space_buffer_;
// Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences.
// At any time, only one buffer is active. The other one becomes active for the next token.
// Each AppendNextTokenToSequences call will trigger a rotation of active buffer.
gsl::span<int64_t> sequences[2];
// Index (either 0 or 1) of the buffer that is currently active.
int current_sequences_buffer;
int batch_beam_size_;
int max_length_;
int current_length_;
};
template <typename T>
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
BufferUniquePtr& buffer,
size_t elements,
bool fill = false,
T fill_value = T{}) {
size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
void* data = allocator->Alloc(bytes);
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
buffer = std::move(temp_buffer);
T* first = reinterpret_cast<T*>(buffer.get());
auto span = gsl::make_span(first, elements);
if (fill) {
std::fill_n(first, elements, fill_value);
}
return span;
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -568,6 +568,139 @@ void DecoderAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx
}
}
bool ParseScalar(const TensorProto* initializer, int& value) {
std::vector<int32_t> parsed_data;
if (initializer->data_type() == TensorProto::INT32) {
const auto& data = ParseData<int32_t>(initializer);
parsed_data.insert(parsed_data.end(), data.begin(), data.end());
if (parsed_data.size() == 1) {
value = parsed_data[0];
return true;
}
}
return false;
}
void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
if (ctx.getNumOutputs() > 1) {
// Here we assume that the third output exists only if the second output exists.
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 1);
if (ctx.getNumOutputs() > 2) {
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 2);
}
}
// Shape inference
// input 0 (input_ids) shape: (batch_size, sequence_length)
// output 0 (sequences) shape: (batch_size, num_return_sequences, max_length)
// output 1 (sequences_scores) shape: (batch_size, num_return_sequences)
// output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size)
if (!hasInputShape(ctx, 0)) {
return;
}
auto& input_ids_shape = getInputShape(ctx, 0);
auto& input_ids_dims = input_ids_shape.dim();
if (input_ids_dims.size() != 2) {
fail_shape_inference("Input 0 shall have 2 dimensions");
}
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) {
return;
}
int64_t batch_size = input_ids_dims[0].dim_value();
int64_t sequence_length = input_ids_dims[1].dim_value();
const auto max_length = ctx.getInputData(1);
const auto num_beams = ctx.getInputData(3);
const auto num_return_sequences = ctx.getInputData(4);
if (num_beams == nullptr || max_length == nullptr || num_return_sequences == nullptr) { // not initializer
return;
}
int max_length_value = 0;
if (!ParseScalar(max_length, max_length_value) || max_length_value <= 0) {
fail_shape_inference("Failed to parse max_length or it is not positive integer scalar");
}
int num_beams_value = 0;
if (!ParseScalar(num_beams, num_beams_value) || num_beams_value <= 0) {
fail_shape_inference("Failed to parse num_beams or it is not positive integer scalar");
}
int num_return_sequences_value = 0;
if (!ParseScalar(num_return_sequences, num_return_sequences_value) || num_return_sequences_value <= 0) {
fail_shape_inference("Failed to parse num_return_sequences or it is not positive integer scalar");
}
ONNX_NAMESPACE::TensorShapeProto sequences_shape;
sequences_shape.add_dim()->set_dim_value(batch_size);
sequences_shape.add_dim()->set_dim_value(num_return_sequences_value);
sequences_shape.add_dim()->set_dim_value(max_length_value);
updateOutputShape(ctx, 0, sequences_shape);
if (ctx.getNumOutputs() > 1) {
ONNX_NAMESPACE::TensorShapeProto sequences_scores_shape;
// Fill the shape of output 1 (sequences_scores), not the shape of output 0.
sequences_scores_shape.add_dim()->set_dim_value(batch_size);
sequences_scores_shape.add_dim()->set_dim_value(num_return_sequences_value);
updateOutputShape(ctx, 1, sequences_scores_shape);
if (ctx.getNumOutputs() > 2) {
ONNX_NAMESPACE::TensorShapeProto scores_shape;
scores_shape.add_dim()->set_dim_value(max_length_value - sequence_length);
scores_shape.add_dim()->set_dim_value(batch_size);
scores_shape.add_dim()->set_dim_value(num_beams_value);
scores_shape.add_dim(); // vocab_size is unknown
updateOutputShape(ctx, 2, scores_shape);
}
}
}
void RegisterTextGenerationSchemas() {
ONNX_CONTRIB_OPERATOR_SCHEMA(BeamSearch)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetDoc("Beam Search for text generation. Supports GPT-2 decoder.")
.Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
.Attr(
"body",
"The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output",
AttributeProto::GRAPH)
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
.Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I")
.Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I")
.Input(5, "temperature", "The value used to modulate the next token probabilities. Accepts value > 0.0. Shape is (1)", "T")
.Input(6, "length_penalty",
"Exponential penalty to the length. Default value 1.0 means no penalty. "
"Values > 1.0 encourage longer sequences, while values < 1.0 produce shorter sequences. "
"Shape is (1)",
"T", OpSchema::Optional)
.Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
.Input(8, "vocab_mask", "Mask of vocabulary. Words masked with 0 are not allowed to be generated, while words masked with 1 are allowed. Shape is (vocab_size)", "M", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_length)", "I")
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
.Output(2, "scores",
"Processed beam scores for each vocabulary token at each generation step. "
"Beam scores consist of log softmax scores for each vocabulary token and the sum of log softmax of previously generated tokens in this beam. "
"Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)",
"T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
BeamSearchShapeInference(ctx);
});
}
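The `length_penalty` and `min_length` descriptions in the schema are easier to check against a small numeric example. The sketch below assumes the Hugging Face convention, where the final beam score is the sum of log-probabilities divided by `length ** length_penalty`; whether the CPU kernel in this PR matches that formula exactly is not visible in this hunk. Both function names are hypothetical.

```python
import math
from typing import List


def length_normalized_score(sum_logprobs: float, length: int, length_penalty: float) -> float:
    """Assumed (Hugging Face style) beam score: sum_logprobs / length ** length_penalty."""
    return sum_logprobs / (length ** length_penalty)


def mask_eos_below_min_length(logits: List[float], cur_length: int,
                              min_length: int, eos_token_id: int) -> List[float]:
    """Set the EOS score to -inf while the sequence is shorter than min_length."""
    if cur_length < min_length:
        logits = list(logits)  # copy so the caller's list is untouched
        logits[eos_token_id] = -math.inf
    return logits


# Two made-up hypotheses with the same average log-probability per token (-0.5).
short = (-2.0, 4)  # (sum of log-probs, length)
long_ = (-4.0, 8)

for penalty in (0.5, 1.0, 2.0):
    s = length_normalized_score(short[0], short[1], penalty)
    l = length_normalized_score(long_[0], long_[1], penalty)
    print(f"length_penalty={penalty}: short={s:.4f} long={l:.4f}")

print(mask_eos_below_min_length([0.1, 0.2, 0.3], cur_length=2, min_length=5, eos_token_id=2))
```

Under this assumption, a penalty below 1.0 ranks the shorter hypothesis higher, 1.0 ranks them equally, and a penalty above 1.0 favors the longer one, which agrees with the schema text.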
void RegisterBertSchemas() {
static const char* Attention_ver1_doc = R"DOC(
Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
@ -750,8 +883,7 @@ Some boolean parameters are passed by runtime input for generic purpose
.TypeConstraint("B", {"tensor(bool)"}, "Constrain key_padding_mask to bool tensors.")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
DecoderAttentionTypeAndShapeInference(ctx);
});
});
static const char* EmbedLayerNormalization_ver1_doc = R"DOC(
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
@ -3146,6 +3278,7 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
RegisterNhwcSchemas();
RegisterBertSchemas();
RegisterTextGenerationSchemas();
#ifdef BUILD_MS_EXPERIMENTAL_OPS
onnxruntime::signal::RegisterSignalSchemas();

@ -2587,6 +2587,15 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
}
}
// verify subgraphs
for (auto node_index : nodes_in_topological_order_) {
auto& node = *GetNode(node_index);
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
Graph* subgraph = entry.second;
ORT_RETURN_IF_ERROR(subgraph->VerifyNodeAndOpMatch(options));
}
}
return Status::OK();
}
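The added loop makes verification recurse into every subgraph attribute, which is what allows the `body` graph of a BeamSearch node to be checked against registered schemas. A toy Python sketch of the same traversal follows; `ToyGraph`, `ToyNode`, and `verify` are hypothetical stand-ins, not onnxruntime types.

```python
from dataclasses import dataclass, field
from typing import Dict, List, Set


@dataclass
class ToyNode:
    op_type: str
    # attribute name -> nested graph (e.g. "body" for BeamSearch)
    subgraphs: Dict[str, "ToyGraph"] = field(default_factory=dict)


@dataclass
class ToyGraph:
    nodes: List[ToyNode] = field(default_factory=list)


def verify(graph: ToyGraph, known_ops: Set[str]) -> List[str]:
    """Return op types without a registered schema, searching nested graphs too."""
    errors = [n.op_type for n in graph.nodes if n.op_type not in known_ops]
    # Analogous to the added loop: recurse into each node's subgraph attributes.
    for node in graph.nodes:
        for subgraph in node.subgraphs.values():
            errors.extend(verify(subgraph, known_ops))
    return errors


body = ToyGraph(nodes=[ToyNode("Attention"), ToyNode("NotARealOp")])
main_graph = ToyGraph(nodes=[ToyNode("BeamSearch", subgraphs={"body": body})])
print(verify(main_graph, known_ops={"BeamSearch", "Attention"}))
```

Unlike the C++ change, which returns on the first failing status, this sketch collects every unknown operator so the result is easy to inspect.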

@ -370,6 +370,56 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input
return Status::OK();
}
// Wrapper over core TopK implementation
template <typename T>
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices) {
const TensorShape& input_shape = input->Shape();
// Returns the axis unchanged if it is non-negative; otherwise converts a negative axis to its positive equivalent
const auto axis_parsed = HandleNegativeAxis(axis, static_cast<int64_t>(input_shape.NumDimensions()));
// Check to ensure k is within the bounds of what is available in that specific axis
if (input_shape[axis_parsed] < k) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k argument [", k,
"] should not be greater than specified axis dim value [", input_shape[axis_parsed], "]");
}
// Resize output tensors to be the same shape as the input except
// for the specified dimension (i.e., axis_parsed), which will be of size k. E.g., for an input tensor
// of shape [3, 4, 5] and k=2 with axis_parsed=1, both of the outputs will be shape [3, 2, 5]
TensorShape output_shape = input_shape;
output_shape[axis_parsed] = k;
output_values = Tensor::Create(input->DataType(), output_shape, allocator);
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
// no-op - no output buffers to fill - return silently
if (k == 0) {
return Status::OK();
}
if (largest) {
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
} else {
FindTopKElements<LesserValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
}
return Status::OK();
}
// explicit instantiation
template Status GetTopK<float>(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
// Opset ver - 1 to 9
static void TopkOpset9ConstructorCommon(const OpKernelInfo& op_kernel_info, int& axis, unsigned int& k) {

@ -19,4 +19,11 @@ class TopK final : public OpKernel {
bool largest_; // opset-11 only
bool sorted_; // opset-11 only
};
template <typename T>
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
} // namespace onnxruntime
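The shape rule and bounds check in `GetTopK` can be illustrated with a small pure-Python sketch. All three functions below are hypothetical and operate on plain lists; they do not reproduce `FindTopKElements` or its threading, only the behavior described in the comments of the wrapper.

```python
import heapq
from typing import List, Tuple


def normalize_axis(axis: int, rank: int) -> int:
    """Map a possibly negative axis into [0, rank)."""
    if not -rank <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    return axis + rank if axis < 0 else axis


def topk_output_shape(input_shape: List[int], axis: int, k: int) -> List[int]:
    """Same shape as the input, except that dimension `axis` becomes k."""
    axis = normalize_axis(axis, len(input_shape))
    if input_shape[axis] < k:
        raise ValueError(f"k argument [{k}] should not be greater than "
                         f"specified axis dim value [{input_shape[axis]}]")
    output_shape = list(input_shape)
    output_shape[axis] = k
    return output_shape


def topk_1d(values: List[float], k: int, largest: bool = True) -> Tuple[List[float], List[int]]:
    """Sorted top-k of a 1-D list; returns (values, indices)."""
    pick = heapq.nlargest if largest else heapq.nsmallest
    pairs = pick(k, enumerate(values), key=lambda pair: pair[1])
    return [v for _, v in pairs], [i for i, _ in pairs]


print(topk_output_shape([3, 4, 5], axis=1, k=2))
print(topk_output_shape([3, 4, 5], axis=-1, k=0))
print(topk_1d([0.1, 0.7, 0.3, 0.9], k=2))
```

With `k=0` the output dimension is zero and there is nothing to fill, which corresponds to the early return in the wrapper.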

@ -0,0 +1,459 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import os
import time
import onnx
import logging
import argparse
from pathlib import Path
from onnx import helper
import numpy as np
from typing import List
import torch
from transformers import GPT2Config
from gpt2_helper import PRETRAINED_GPT2_MODELS
from convert_to_onnx import main as convert_gpt2_to_onnx
from benchmark_helper import Precision
r"""
This converts a GPT-2 model to ONNX with the BeamSearch operator.
Examples:
python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
"""
config: GPT2Config = None
logger = logging.getLogger('')
def parse_arguments(argv=None):
parser = argparse.ArgumentParser()
parser.add_argument('-m',
'--model_name_or_path',
required=True,
type=str,
help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS))
parser.add_argument('--cache_dir',
required=False,
type=str,
default=os.path.join('.', 'cache_models'),
help='Directory to cache pre-trained models')
parser.add_argument('--gpt2_onnx',
required=True,
type=str,
help='Output directory for GPT-2 onnx model, or model path ends with .onnx')
parser.add_argument('--output',
required=True,
type=str,
help='Output path for the beam search model; it should end with .onnx')
parser.add_argument("-p",
"--precision",
required=False,
type=Precision,
default=Precision.FLOAT32,
choices=[Precision.FLOAT32, Precision.FLOAT16],
help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision")
parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference")
parser.set_defaults(use_gpu=False)
parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true')
parser.set_defaults(use_external_data_format=False)
parser.add_argument('--disable_parity', required=False, action='store_true', help="do not run parity test")
parser.set_defaults(disable_parity=False)
parser.add_argument('--total_runs',
required=False,
type=int,
default=1,
help='Number of times of inference for latency measurement')
beam_search_group = parser.add_argument_group("beam search options")
beam_search_group.add_argument('--output_sequences_scores',
required=False,
action='store_true',
help="output sequences scores")
beam_search_group.set_defaults(output_sequences_scores=False)
beam_search_group.add_argument('--output_token_scores',
required=False,
action='store_true',
help="output token scores")
beam_search_group.set_defaults(output_token_scores=False)
beam_search_group.add_argument('--early_stopping', required=False, action='store_true')
beam_search_group.set_defaults(early_stopping=False)
beam_search_group.add_argument('--min_length', type=int, required=False, default=1, help='Min sequence length')
beam_search_group.add_argument('--max_length', type=int, required=False, default=50, help='Max sequence length')
beam_search_group.add_argument('--no_repeat_ngram_size',
type=int,
required=False,
default=0,
help='No repeat ngram size')
beam_search_group.add_argument('--num_beams', type=int, required=False, default=4, help='Beam size')
beam_search_group.add_argument('--num_return_sequences',
type=int,
required=False,
default=1,
help='Number of returned sequences; must be <= num_beams')
beam_search_group.add_argument('--temperature',
type=float,
required=False,
default=1,
help='Softmax temperature for output logits.')
beam_search_group.add_argument('--length_penalty',
type=float,
required=False,
default=1,
help='Positive. >1 to penalize and <1 to encourage short sentences.')
beam_search_group.add_argument('--repetition_penalty',
type=float,
required=False,
default=1,
help='Positive. >1 to penalize and <1 to encourage repetition.')
mixed_precision_option_group = parser.add_argument_group(
"mixed precision conversion parameters that work when \"--precision fp16\" is specified")
mixed_precision_option_group.add_argument('--io_block_list',
nargs='+',
required=False,
default=[],
help='List of inputs or outputs in float32')
mixed_precision_option_group.add_argument(
'--op_block_list',
nargs='+',
required=False,
default=[],
help='List of operators (like Add LayerNormalization FastGelu) to compute in float32.')
mixed_precision_option_group.add_argument('--node_block_list',
nargs='+',
required=False,
default=[],
help='List of node names to compute in float32.')
mixed_precision_option_group.add_argument('--force_fp16_initializers',
required=False,
action='store_true',
help='Convert all float initializers to float16.')
mixed_precision_option_group.set_defaults(force_fp16_initializers=False)
args = parser.parse_args(argv)
return args
def gpt2_to_onnx(args):
model_name = args.model_name_or_path
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...")
arguments = [
'--model_name_or_path', model_name, '--output', args.gpt2_onnx, '--optimize_onnx', '--precision',
'fp32' if args.precision == Precision.FLOAT32 else 'fp16', '--test_runs', '1', '--test_cases', '10'
]
if args.use_gpu:
arguments.append('--use_gpu')
if args.use_external_data_format:
arguments.append('--use_external_data_format')
# mixed precision conversion options
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 or mixed precision model cannot run on CPU. Please add --use_gpu"
if args.io_block_list:
arguments.append('--io_block_list')
arguments.extend(args.io_block_list)
if args.op_block_list:
arguments.append('--op_block_list')
arguments.extend(args.op_block_list)
if args.node_block_list:
arguments.append('--node_block_list')
arguments.extend(args.node_block_list)
if args.force_fp16_initializers:
arguments.append('--force_fp16_initializers')
convert_gpt2_to_onnx(arguments)
def shape_inference(gpt2_onnx_path):
# Run symbolic shape inference to work around an ORT shape inference issue for subgraphs.
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False)
if out:
# TODO: Use external format if input has extra data.
onnx.save(out, gpt2_onnx_path)
else:
print("Failed to run symbolic shape inference on the model.")
def create_ort_session(model_path, use_gpu):
from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel
sess_options = SessionOptions()
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
ort_session = InferenceSession(model_path, sess_options, providers=execution_providers)
return ort_session
def convert_model(args):
if os.path.exists(args.gpt2_onnx):
print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}")
else:
gpt2_to_onnx(args)
print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.")
shape_inference(args.gpt2_onnx)
global config
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
print(config)
eos_token_id = config.eos_token_id
pad_token_id = config.eos_token_id
vocab_size = config.vocab_size
model = onnx.load(args.gpt2_onnx)
model.graph.name = "gpt2 subgraph"
inputs = [
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
"repetition_penalty", "vocab_mask"
]
outputs = ["sequences"]
if args.output_sequences_scores:
outputs.append("sequences_scores")
if args.output_token_scores:
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
outputs.append("scores")
node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2')
node.domain = "com.microsoft"
node.attribute.extend([
helper.make_attribute("eos_token_id", eos_token_id),
helper.make_attribute("pad_token_id", pad_token_id),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
helper.make_attribute("body", model.graph),
])
from onnx import TensorProto
# graph inputs
input_ids = helper.make_tensor_value_info('input_ids', TensorProto.INT32, ['batch_size', 'sequence_length'])
max_length = helper.make_tensor_value_info('max_length', TensorProto.INT32, [1])
min_length = helper.make_tensor_value_info('min_length', TensorProto.INT32, [1])
num_beams = helper.make_tensor_value_info('num_beams', TensorProto.INT32, [1])
num_return_sequences = helper.make_tensor_value_info('num_return_sequences', TensorProto.INT32, [1])
temperature = helper.make_tensor_value_info('temperature', TensorProto.FLOAT, [1])
length_penalty = helper.make_tensor_value_info('length_penalty', TensorProto.FLOAT, [1])
repetition_penalty = helper.make_tensor_value_info('repetition_penalty', TensorProto.FLOAT, [1])
vocab_mask = helper.make_tensor_value_info('vocab_mask', TensorProto.INT32, [vocab_size])
graph_inputs = [
input_ids, max_length, min_length, num_beams, num_return_sequences, temperature, length_penalty,
repetition_penalty, vocab_mask
]
# graph outputs
sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32,
['batch_size', 'num_return_sequences', 'max_length'])
sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT,
['batch_size', 'num_return_sequences'])
scores = helper.make_tensor_value_info('scores', TensorProto.FLOAT,
['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size])
initializers = []
graph_outputs = [sequences]
if args.output_sequences_scores:
graph_outputs.append(sequences_scores)
if args.output_token_scores:
graph_outputs.append(scores)
new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers)
# Create the model
new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import)
onnx.save(new_model, args.output)
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
from transformers import GPT2Tokenizer, GPT2LMHeadModel
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer.padding_side = "left"
tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path,
cache_dir=args.cache_dir,
pad_token_id=tokenizer.eos_token_id)
# Use different length sentences to test batching
if sentences is None:
sentences = ["The product is released", "I enjoy walking in the park", "Test best way to invest"]
inputs = tokenizer(sentences, return_tensors='pt', padding=True)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
bad_words = "walk in park"
bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True)
bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
if use_vocab_mask:
print("bad_words_ids", bad_words_ids)
else:
bad_words_ids = None
global config
config = model.config
eos_token_id = config.eos_token_id
pad_token_id = config.eos_token_id
vocab_size = config.vocab_size
torch_decoded_sequences = []
if not args.disable_parity:
print('-' * 50)
print("Test PyTorch model and beam search with huggingface transformers...")
beam_outputs = model.generate(input_ids=input_ids,
attention_mask=attention_mask,
max_length=args.max_length,
min_length=args.min_length,
num_beams=args.num_beams,
early_stopping=args.early_stopping,
no_repeat_ngram_size=args.no_repeat_ngram_size,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
num_return_sequences=args.num_return_sequences,
temperature=args.temperature,
length_penalty=args.length_penalty,
repetition_penalty=args.repetition_penalty,
bad_words_ids=bad_words_ids,
return_dict_in_generate=True,
output_scores=True)
print("input_ids", input_ids)
print("huggingface transformers outputs:")
print("sequences", beam_outputs.sequences)
if args.output_sequences_scores:
print("sequences_scores", beam_outputs.sequences_scores)
if args.output_token_scores:
print("scores", beam_outputs.scores)
for i, sequence in enumerate(beam_outputs.sequences):
decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
torch_decoded_sequences.append(decoded_sequence)
print("{}: {}".format(i, decoded_sequence))
print('-' * 50)
print("Test ONNX model and beam search with onnxruntime...")
ort_session = create_ort_session(args.output, args.use_gpu)
vocab_mask = np.ones((vocab_size), dtype=np.int32)
if use_vocab_mask:
for bad_word_id in bad_words_ids:
vocab_mask[bad_word_id] = 0
inputs = {
"input_ids": input_ids.cpu().numpy().astype(np.int32),
"max_length": np.array([args.max_length], dtype=np.int32),
"min_length": np.array([args.min_length], dtype=np.int32),
"num_beams": np.array([args.num_beams], dtype=np.int32),
"num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
"temperature": np.array([args.temperature], dtype=np.float32),
"length_penalty": np.array([args.length_penalty], dtype=np.float32),
"repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
"vocab_mask": vocab_mask
}
test_data_dir = Path(args.output).parent.as_posix()
print("test_data_dir", test_data_dir)
from bert_test_data import output_test_data
all_inputs = [inputs]
for i, inputs in enumerate(all_inputs):
dir = os.path.join(test_data_dir, 'test_data_set_' + str(i))
output_test_data(dir, inputs)
print("inputs", inputs)
# Test performance
latency = []
for _ in range(args.total_runs):
start = time.time()
result = ort_session.run(None, inputs)
latency.append(time.time() - start)
batch_size = input_ids.shape[0]
from benchmark_helper import get_latency_result
output = get_latency_result(latency, batch_size)
print("ORT outputs:")
sequences = result[0]
print("sequences", sequences)
if args.output_sequences_scores:
print("sequences_scores", result[1])
if args.output_token_scores:
print("scores", result[2])
(batch_size, num_sequences, max_length) = sequences.shape
ort_decoded_sequences = []
for i in range(batch_size):
for j in range(num_sequences):
decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
ort_decoded_sequences.append(decoded_sequence)
print(f"batch {i} sequence {j}: {decoded_sequence}")
if not args.disable_parity:
torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
ort_sequences = torch.LongTensor(sequences)
print("-" * 50)
print("Torch Sequences:")
print(torch_sequences)
print(torch_decoded_sequences)
print("-" * 50)
print("ORT Sequences:")
print(ort_sequences)
print(ort_decoded_sequences)
print("-" * 50)
# Compare the generated text instead of word IDs, since ORT pads to max sequence length but Torch does not.
is_same = (torch_decoded_sequences == ort_decoded_sequences)
print("Torch and ORT results are", "the same" if is_same else "different")
output["parity"] = is_same
print(output)
return output
def main(argv=None, sentences=None):
args = parse_arguments(argv)
if os.path.exists(args.output):
print(f"skip conversion since path existed: {args.output}")
else:
convert_model(args)
return test_model(args, use_vocab_mask=True, sentences=sentences)
if __name__ == '__main__':
main()
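The way `test_model` turns `bad_words_ids` into the `vocab_mask` input can be shown in isolation. The sketch below uses plain Python lists to stay self-contained, whereas the script builds an `np.int32` array; `build_vocab_mask` is a hypothetical helper, and the token IDs are made up.

```python
from typing import List


def build_vocab_mask(vocab_size: int, bad_words_ids: List[List[int]]) -> List[int]:
    """1 = token may be generated, 0 = token is blocked."""
    mask = [1] * vocab_size
    for word_ids in bad_words_ids:
        # A per-token mask can only express single-token bad words.
        if len(word_ids) != 1:
            raise ValueError("vocab_mask can only block single-token bad words")
        mask[word_ids[0]] = 0
    return mask


mask = build_vocab_mask(vocab_size=8, bad_words_ids=[[2], [5]])
print(mask)
```

This is why the script converts the encoded bad words into a list of one-element lists: each token is blocked on its own, so that the Hugging Face `bad_words_ids` argument and the ORT `vocab_mask` input describe the same restriction.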

@ -17,7 +17,6 @@ This converts GPT2 model to onnx. Examples:
import os
import argparse
import coloredlogs
import logging
import torch
import numpy

@ -0,0 +1,70 @@
#!/usr/bin/env python
# coding: utf-8
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
import unittest
import os
import pytest
from parity_utilities import find_transformers_source
if find_transformers_source():
from convert_beam_search import main as run
else:
from onnxruntime.transformers.convert_beam_search import main as run
class TestBeamSearch(unittest.TestCase):
def setUp(self):
#TODO: use a smaller model and enable tests in CI pipeline
self.model_name = "gpt2"
self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx')
self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx')
self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_scores --repetition_penalty 2.0'
def run_beam_search(self, arguments: str, sentences=None):
return run(arguments.split(), sentences=sentences)
@pytest.mark.slow
def test_cpu(self):
result = self.run_beam_search(self.cpu_params + " --num_return_sequences 2",
sentences=["The product is released"])
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
@pytest.mark.slow
def test_early_stopping(self):
result = self.run_beam_search(self.cpu_params + " --early_stopping")
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
@pytest.mark.slow
def test_temperature(self):
result = self.run_beam_search(self.cpu_params + " --temperature 0.5")
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
@pytest.mark.slow
def test_length_penalty(self):
result = self.run_beam_search(self.cpu_params + " --length_penalty 0.5")
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
@pytest.mark.slow
def test_no_repeat_ngram(self):
for ngram_size in [1, 2]:
result = self.run_beam_search(self.cpu_params + f' --no_repeat_ngram_size {ngram_size}')
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
if __name__ == '__main__':
unittest.main()
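The parity flag asserted in these tests comes from comparing decoded text, because ORT pads each sequence to the maximum length while the PyTorch output is not padded to that length. A possible alternative, sketched here under the assumption that padding appears only at the end of a sequence, is to trim trailing pad tokens and compare IDs directly. `strip_trailing_pad` is hypothetical and the token IDs other than the pad ID are made up.

```python
from typing import List


def strip_trailing_pad(token_ids: List[int], pad_token_id: int) -> List[int]:
    """Drop pad tokens from the end of a sequence, leaving the rest untouched."""
    end = len(token_ids)
    while end > 0 and token_ids[end - 1] == pad_token_id:
        end -= 1
    return token_ids[:end]


PAD = 50256  # GPT-2's eos token id, which the script also uses as the pad id
ort_sequence = [464, 1720, 318, 2716, PAD, PAD, PAD]
torch_sequence = [464, 1720, 318, 2716, PAD]
same = strip_trailing_pad(ort_sequence, PAD) == strip_trailing_pad(torch_sequence, PAD)
print(same)
```

Because the pad and end-of-sequence IDs coincide for GPT-2, trimming also removes the final EOS token on both sides, so the comparison stays symmetric.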

@ -3,6 +3,10 @@
"Affine ai.onnx CPUExecutionProvider",
7811918192248490408
],
[
"BeamSearch com.microsoft CPUExecutionProvider",
6968087233460196528
],
[
"Crop ai.onnx CPUExecutionProvider",
6914973556202621376