mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Add BeamSearch operator for GPT-2 decoding (#9680)
* Add BeamSearch operator and CPU implementation * Add ONNX conversion script
This commit is contained in:
parent
fab39b4704
commit
ef36488df0
25 changed files with 3262 additions and 3 deletions
|
|
@ -5,6 +5,7 @@ Do not modify directly.*
|
|||
* com.microsoft
|
||||
* <a href="#com.microsoft.Attention">com.microsoft.Attention</a>
|
||||
* <a href="#com.microsoft.AttnLSTM">com.microsoft.AttnLSTM</a>
|
||||
* <a href="#com.microsoft.BeamSearch">com.microsoft.BeamSearch</a>
|
||||
* <a href="#com.microsoft.BiasDropout">com.microsoft.BiasDropout</a>
|
||||
* <a href="#com.microsoft.BiasGelu">com.microsoft.BiasGelu</a>
|
||||
* <a href="#com.microsoft.BiasSoftmax">com.microsoft.BiasSoftmax</a>
|
||||
|
|
@ -337,6 +338,75 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.BeamSearch"></a><a name="com.microsoft.beamsearch">**com.microsoft.BeamSearch**</a>
|
||||
|
||||
Beam Search for text generation. Supports GPT-2 decoder.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>body</tt> : graph (required)</dt>
|
||||
<dd>The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output</dd>
|
||||
<dt><tt>early_stopping</tt> : int</dt>
|
||||
<dd>early stop or not</dd>
|
||||
<dt><tt>eos_token_id</tt> : int (required)</dt>
|
||||
<dd>The id of the end-of-sequence token</dd>
|
||||
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
|
||||
<dd>no repeat ngrams size</dd>
|
||||
<dt><tt>pad_token_id</tt> : int (required)</dt>
|
||||
<dd>The id of the padding token</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (6 - 9)
|
||||
|
||||
<dl>
|
||||
<dt><tt>input_ids</tt> : I</dt>
|
||||
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
|
||||
<dt><tt>max_length</tt> : I</dt>
|
||||
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
|
||||
<dt><tt>min_length</tt> (optional) : I</dt>
|
||||
<dd>The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)</dd>
|
||||
<dt><tt>num_beams</tt> : I</dt>
|
||||
<dd>Number of beams for beam search. 1 means no beam search. Shape is (1)</dd>
|
||||
<dt><tt>num_return_sequences</tt> : I</dt>
|
||||
<dd>The number of returned sequences in the batch. Shape is (1)</dd>
|
||||
<dt><tt>temperature</tt> : T</dt>
|
||||
<dd>The value used to module the next token probabilities. Accepts value > 0.0. Shape is (1)</dd>
|
||||
<dt><tt>length_penalty</tt> (optional) : T</dt>
|
||||
<dd>Exponential penalty to the length. Default value 1.0 means no penalty.Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences.Shape is (1,)</dd>
|
||||
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
|
||||
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
|
||||
<dt><tt>vocab_mask</tt> (optional) : M</dt>
|
||||
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 3)
|
||||
|
||||
<dl>
|
||||
<dt><tt>sequences</tt> : I</dt>
|
||||
<dd>Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)</dd>
|
||||
<dt><tt>sequences_scores</tt> (optional) : T</dt>
|
||||
<dd>Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)</dd>
|
||||
<dt><tt>scores</tt> (optional) : T</dt>
|
||||
<dd>Processed beam scores for each vocabulary token at each generation step.Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam.Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float), tensor(float16)</dt>
|
||||
<dd>Constrain input and output types to float tensors.</dd>
|
||||
<dt><tt>I</tt> : tensor(int32)</dt>
|
||||
<dd>Constrain to integer types</dd>
|
||||
<dt><tt>M</tt> : tensor(int32)</dt>
|
||||
<dd>Constrain mask to integer types</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.BiasDropout"></a><a name="com.microsoft.biasdropout">**com.microsoft.BiasDropout**</a>
|
||||
|
||||
output, dropout_mask = Dropout(data + bias, ratio) + residual, Intended to specialize the dropout pattern commonly found in transformer models.
|
||||
|
|
|
|||
|
|
@ -377,6 +377,7 @@ Do not modify directly.*
|
|||
|**Operator Domain:** *com.microsoft*||||
|
||||
|Attention|*in* input:**T**<br> *in* weight:**T**<br> *in* bias:**T**<br> *in* mask_index:**M**<br> *in* past:**T**<br> *in* extra_add:**T**<br> *out* output:**T**<br> *out* present:**T**|1+|**T** = tensor(float)|
|
||||
|AttnLSTM|*in* X:**T**<br> *in* W:**T**<br> *in* R:**T**<br> *in* B:**T**<br> *in* sequence_lens:**T1**<br> *in* initial_h:**T**<br> *in* initial_c:**T**<br> *in* P:**T**<br> *in* QW:**T**<br> *in* MW:**T**<br> *in* V:**T**<br> *in* M:**T**<br> *in* memory_seq_lens:**T1**<br> *in* AW:**T**<br> *out* Y:**T**<br> *out* Y_h:**T**<br> *out* Y_c:**T**|1+|**T** = tensor(double), tensor(float)<br/> **T1** = tensor(int32)|
|
||||
|BeamSearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* num_beams:**I**<br> *in* num_return_sequences:**I**<br> *in* temperature:**T**<br> *in* length_penalty:**T**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**M**<br> *out* sequences:**I**<br> *out* sequences_scores:**T**<br> *out* scores:**T**|1+|**T** = tensor(float)|
|
||||
|BiasGelu|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(float)|
|
||||
|BifurcationDetector|*in* src_tokens:**T**<br> *in* cur_tokens:**T**<br> *in* prev_suffix_match_idx:**T**<br> *in* pred_tokens:**T**<br> *out* tokens:**T**<br> *out* suffix_match_idx:**T**|1+|**T** = tensor(int64)|
|
||||
|CDist|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T**|1+|**T** = tensor(double), tensor(float)|
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
|
|
@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
// add more kernels here
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GridSample)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Attention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, BeamSearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, EmbedLayerNormalization)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, ExpandDims)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
|
||||
|
|
|
|||
650
onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Normal file
650
onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
Normal file
|
|
@ -0,0 +1,650 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// there's no way to use a raw pointer as the copy destination with std::copy_n
|
||||
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
|
||||
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
#include <assert.h>
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
#include "core/providers/cpu/math/top_k.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/framework/session_options.h"
|
||||
#include "core/framework/TensorSeq.h"
|
||||
#include "gsl/gsl"
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
#include "beam_search.h"
|
||||
#include "logits_processor.h"
|
||||
#include "sequences.h"
|
||||
#include "dump_tensor.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
BeamSearch, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
transformers::BeamSearch<T>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
struct BeamSearchState {
|
||||
gsl::span<T> beam_scores; // shape (batch_size, num_beams)
|
||||
gsl::span<T> next_token_logits; // shape (batch_size * num_beams, vocab_size)
|
||||
gsl::span<T> next_token_scores; // shape (batch_size, num_beams * vocab_size)
|
||||
gsl::span<int64_t> next_tokens; // shape (batch_size, 2 * num_beams)
|
||||
gsl::span<int64_t> next_indices; // shape (batch_size, 2 * num_beams)
|
||||
gsl::span<int64_t> next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
|
||||
|
||||
gsl::span<T> scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
|
||||
gsl::span<T> remaining_scores; // subspan that is avaiable for appending next token scores.
|
||||
|
||||
Sequences sequences;
|
||||
|
||||
void Init(AllocatorPtr allocator,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
bool output_scores) {
|
||||
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
|
||||
beam_scores = AllocateBuffer<T>(allocator, beam_scores_buffer_, batch_beam_size, true, static_cast<T>(0));
|
||||
|
||||
// Initialize score of first beam of each group with 0 and the rest with -1e9.
|
||||
// This ensures that the beams in the same group don't produce same tokens every time.
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = 1; j < num_beams; j++) {
|
||||
beam_scores[i * num_beams + j] = -1e9;
|
||||
}
|
||||
}
|
||||
|
||||
size_t next_token_size = SafeInt<size_t>(batch_beam_size) * vocab_size;
|
||||
next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size, true, static_cast<T>(0));
|
||||
next_token_scores = AllocateBuffer<T>(allocator, next_token_scores_buffer_, next_token_size, true, static_cast<T>(0));
|
||||
|
||||
next_tokens = AllocateBuffer<int64_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size, true, static_cast<int64_t>(0));
|
||||
|
||||
next_indices = AllocateBuffer<int64_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size, true, static_cast<int64_t>(0));
|
||||
|
||||
next_positions = AllocateBuffer<int64_t>(allocator, next_positions_buffer_, batch_beam_size, true, static_cast<int64_t>(0));
|
||||
|
||||
if (output_scores) {
|
||||
size_t elements = SafeInt<size_t>(max_length - sequence_length) * batch_size * num_beams * vocab_size;
|
||||
scores = AllocateBuffer<T>(allocator, scores_buffer_, elements);
|
||||
remaining_scores = scores;
|
||||
}
|
||||
|
||||
// sequences will be initialized later since it has dependency on input_ids
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr beam_scores_buffer_;
|
||||
BufferUniquePtr next_token_logits_buffer_;
|
||||
BufferUniquePtr next_token_scores_buffer_;
|
||||
BufferUniquePtr next_tokens_buffer_;
|
||||
BufferUniquePtr next_indices_buffer_;
|
||||
BufferUniquePtr next_positions_buffer_;
|
||||
BufferUniquePtr scores_buffer_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BeamSearchImpl {
|
||||
public:
|
||||
BeamSearchImpl(OpKernelContextInternal& context,
|
||||
const SessionState& session_state,
|
||||
GptSubgraph& gpt_subgraph,
|
||||
concurrency::ThreadPool* thread_pool,
|
||||
void* stream,
|
||||
BeamSearchParameters& params);
|
||||
|
||||
// Initialize by validating all the inputs, and allocating the output tensors.
|
||||
Status Initialize();
|
||||
|
||||
// Execute beam search in iterations util stopping criteria is reached.
|
||||
// In each iteration, GPT subgraph is called, and next token for each sequence is generated.
|
||||
Status Execute(const FeedsFetchesManager& cached_ffm);
|
||||
|
||||
private:
|
||||
// Validate inputs.
|
||||
Status CheckInputs(const OpKernelContextInternal& context);
|
||||
|
||||
// Prepare the inputs for first inference of subgraph
|
||||
void CreateInitialFeeds(gsl::span<int64_t>& next_positions, std::vector<OrtValue>& feeds);
|
||||
|
||||
// Update the input for next iteration.
|
||||
Status UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
gsl::span<int64_t>& next_positions,
|
||||
gsl::span<const int64_t> beam_next_tokens,
|
||||
gsl::span<const int64_t> beam_indices);
|
||||
|
||||
// Process logits and append next tokens to sequences.
|
||||
Status GenerateNextToken(const OrtValue& logits,
|
||||
gsl::span<int64_t>& beam_next_tokens,
|
||||
gsl::span<int64_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state);
|
||||
|
||||
// Calculate scores from logits, then apply filtering and select next token for each beam.
|
||||
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
BeamSearchState<T>& beam_state,
|
||||
AllocatorPtr& allocator);
|
||||
|
||||
OpKernelContextInternal& context_;
|
||||
|
||||
const SessionState& session_state_;
|
||||
|
||||
GptSubgraph& gpt_subgraph_;
|
||||
|
||||
concurrency::ThreadPool* thread_pool_;
|
||||
|
||||
const std::vector<const OrtValue*>& implicit_inputs_;
|
||||
|
||||
// Not used in CPU. Stream is for CUDA only.
|
||||
void* stream_;
|
||||
|
||||
BeamSearchParameters* parameters_;
|
||||
|
||||
LogitsProcessorList<T> logits_processors_;
|
||||
|
||||
std::unique_ptr<BeamSearchScorer<T>> beam_scorer_;
|
||||
|
||||
AllocatorPtr allocator_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
void BeamSearch<T>::Init(const OpKernelInfo& info) {
|
||||
// Make sure the body attribute was present even though we don't need it here.
|
||||
ONNX_NAMESPACE::GraphProto proto;
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("body", &proto).IsOK());
|
||||
ORT_IGNORE_RETURN_VALUE(proto);
|
||||
|
||||
parameters_.ParseFromAttributes(info);
|
||||
|
||||
stream_ = nullptr;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
std::unique_ptr<OpKernel> BeamSearch<T>::Create(const OpKernelInfo& info,
|
||||
void* stream) {
|
||||
auto result = std::make_unique<BeamSearch>(info);
|
||||
result->SetComputeStream(stream);
|
||||
return result;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
common::Status BeamSearch<T>::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
||||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) {
|
||||
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
const auto& node = Node();
|
||||
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
|
||||
gpt_subgraph_->num_heads,
|
||||
gpt_subgraph_->head_size,
|
||||
gpt_subgraph_->num_layers);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearch<T>::Compute(OpKernelContext* ctx) const {
|
||||
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto* session_state = ctx_internal->SubgraphSessionState("body");
|
||||
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'body' attribute.");
|
||||
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
|
||||
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later
|
||||
|
||||
BeamSearchImpl<T> impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, stream_, parameters};
|
||||
|
||||
auto status = impl.Initialize();
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
status = impl.Execute(*feeds_fetches_manager_);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BeamSearchImpl<T>::BeamSearchImpl(OpKernelContextInternal& context,
|
||||
const SessionState& session_state,
|
||||
GptSubgraph& gpt_subgraph,
|
||||
concurrency::ThreadPool* thread_pool,
|
||||
void* stream,
|
||||
BeamSearchParameters& params)
|
||||
: context_(context),
|
||||
session_state_(session_state),
|
||||
gpt_subgraph_(gpt_subgraph),
|
||||
thread_pool_(thread_pool),
|
||||
implicit_inputs_(context_.GetImplicitInputs()),
|
||||
stream_(stream),
|
||||
parameters_(¶ms),
|
||||
allocator_(nullptr) {
|
||||
parameters_->ParseFromInputs(&context);
|
||||
|
||||
allocator_ = session_state.GetExecutionProviders()
|
||||
.Get(onnxruntime::kCpuExecutionProvider)
|
||||
->GetAllocator(0, OrtMemTypeDefault);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::CheckInputs(const OpKernelContextInternal& context) {
|
||||
// Input shapes:
|
||||
// input_ids : (batch_size, sequence_length)
|
||||
// vocab_mask : (vocab_size) or nullptr
|
||||
|
||||
const Tensor* input_ids = context.Input<Tensor>(0);
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
if (dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ",
|
||||
dims.size());
|
||||
}
|
||||
|
||||
const Tensor* vocab_mask = context.Input<Tensor>(8);
|
||||
if (vocab_mask != nullptr) { // vocab_mask is optional
|
||||
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
|
||||
if (vocab_mask_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ",
|
||||
vocab_mask_dims.size());
|
||||
}
|
||||
|
||||
// There is dependency on vocab_size parameter, which shall be set before calling this function.
|
||||
if (static_cast<int>(vocab_mask_dims[0]) != parameters_->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ",
|
||||
vocab_mask_dims[0]);
|
||||
}
|
||||
|
||||
// store vocab mask in parameters.
|
||||
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::Initialize() {
|
||||
auto status = Status::OK();
|
||||
|
||||
#define CHECK_SCALAR_INPUT(name, index, required) \
|
||||
auto* name##_tensor = context_.Input<Tensor>(index); \
|
||||
if (name##_tensor) { \
|
||||
if (!name##_tensor->Shape().IsScalar()) { \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
|
||||
name##_tensor->Shape()); \
|
||||
} \
|
||||
} else if (required) { \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
|
||||
}
|
||||
|
||||
CHECK_SCALAR_INPUT(min_length, 1, false);
|
||||
|
||||
CHECK_SCALAR_INPUT(max_length, 2, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(num_beams, 3, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(temperature, 5, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(length_penalty, 6, true);
|
||||
|
||||
ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'.");
|
||||
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(context_));
|
||||
|
||||
// This flag will be updated later when the scores output exists.
|
||||
parameters_->output_scores = false;
|
||||
|
||||
// Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready.
|
||||
logits_processors_.Init(*parameters_);
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BeamSearchImpl<T>::CreateInitialFeeds(gsl::span<int64_t>& next_positions, std::vector<OrtValue>& feeds) {
|
||||
const OrtValue* input_ids_value = context_.GetInputOrtValue(0);
|
||||
const Tensor& input_ids = input_ids_value->Get<Tensor>();
|
||||
gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, next_positions, feeds);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::ProcessLogits(
|
||||
const OrtValue& logits,
|
||||
BeamSearchState<T>& beam_state,
|
||||
AllocatorPtr& allocator) {
|
||||
const int64_t batch_beam_size = static_cast<int64_t>(parameters_->BatchBeamSize());
|
||||
const int& vocab_size = parameters_->vocab_size;
|
||||
|
||||
const T* logits_data = logits.Get<Tensor>().Data<T>();
|
||||
|
||||
// Logits has shape (batch_size * num_beams, input_length, vocab_size),
|
||||
// where input_length equals to parameters_->sequence_length for first subgraph call, and 1 for the remaining calls.
|
||||
const TensorShape& logits_shape = logits.Get<Tensor>().Shape();
|
||||
ORT_ENFORCE(logits_shape.NumDimensions() == 3);
|
||||
auto input_length = logits_shape[1];
|
||||
|
||||
// Get logits for the last token:
|
||||
// next_token_logits = logits[:, -1, :], and the result shape is (batch_size * num_beams, vocab_size)
|
||||
// When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
|
||||
gsl::span<T>& next_token_logits = beam_state.next_token_logits;
|
||||
if (input_length > 1) {
|
||||
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
gsl::span<const T> source(current_logits, vocab_size);
|
||||
gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
|
||||
gsl::copy(source, target);
|
||||
current_logits += input_length * vocab_size;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
//DumpOrtValue("logits", logits);
|
||||
DumpTensor("next_token_logits", next_token_logits.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
|
||||
gsl::span<T>& next_token_scores = beam_state.next_token_scores;
|
||||
Status status = SoftmaxCPU<T>(batch_beam_size, // rows
|
||||
vocab_size, // elements per row
|
||||
input_length > 1 ? next_token_logits.data() : logits_data,
|
||||
next_token_scores.data(),
|
||||
true,
|
||||
thread_pool_);
|
||||
if (!status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor("next_token_scores after softmax", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
// Apply all score processors that updates scores
|
||||
logits_processors_.Process(&(beam_state.sequences), next_token_scores);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor("next_token_scores after logits processor", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
// Add beam score to next token scores. Corresponding python code is like:
|
||||
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
||||
// TODO: use thread pool to parrellel
|
||||
int offset = 0;
|
||||
int batch_beam_index = 0;
|
||||
for (int i = 0; i < parameters_->batch_size; i++) {
|
||||
for (int j = 0; j < parameters_->num_beams; j++, batch_beam_index++) {
|
||||
for (int k = 0; k < parameters_->vocab_size; k++, offset++) {
|
||||
next_token_scores[offset] += beam_state.beam_scores[batch_beam_index];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor("next_token_scores after adding beam_scores", next_token_scores.data(), parameters_->batch_size, parameters_->num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
if (parameters_->output_scores) {
|
||||
// Append next token scores to the scores output.
|
||||
gsl::copy(next_token_scores, beam_state.remaining_scores);
|
||||
beam_state.remaining_scores = beam_state.remaining_scores.subspan(next_token_scores.size());
|
||||
}
|
||||
|
||||
// Apply top-k selection like the following:
|
||||
// next_token_scores = next_token_scores.view(batch_size, num_beams * vocab_size)
|
||||
// next_token_scores, next_tokens = torch.topk(next_token_scores, 2 * num_beams, dim=1, largest=True, sorted=True)
|
||||
int64_t next_token_scores_dims[] = {parameters_->batch_size, parameters_->num_beams * vocab_size};
|
||||
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
OrtValue next_token_scores_value;
|
||||
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
|
||||
const Tensor& input = next_token_scores_value.Get<Tensor>();
|
||||
|
||||
const int axis = 1;
|
||||
const unsigned top_k = static_cast<unsigned>(2 * parameters_->num_beams);
|
||||
const bool largest = true;
|
||||
const bool sorted = true; // results returned in sorted order.
|
||||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
status = GetTopK<T>(&input, axis, top_k, largest, sorted, allocator, thread_pool_, topk_scores, topk_indices);
|
||||
if (!status.IsOK()) {
|
||||
return status;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor<T>("topk_scores", *(topk_scores.get()));
|
||||
DumpTensor<int64_t>("topk_indices", *(topk_indices.get()));
|
||||
#endif
|
||||
|
||||
// Convert indices in range [0, num_beams * vocab_size) to token ID of range [0, vocab_size) like the following:
|
||||
// next_indices = (next_tokens / vocab_size).long()
|
||||
// next_tokens = next_tokens % vocab_size
|
||||
gsl::span<const int64_t> next_token_indices = topk_indices->DataAsSpan<int64_t>();
|
||||
offset = 0;
|
||||
for (int i = 0; i < parameters_->batch_size; i++) {
|
||||
for (unsigned int j = 0; j < top_k; j++, offset++) {
|
||||
beam_state.next_indices[offset] = next_token_indices[offset] / vocab_size;
|
||||
beam_state.next_tokens[offset] = next_token_indices[offset] % vocab_size;
|
||||
}
|
||||
}
|
||||
|
||||
gsl::span<const T> next_scores = topk_scores->DataAsSpan<T>();
|
||||
gsl::span<const int64_t> next_tokens(beam_state.next_tokens.data(), beam_state.next_tokens.size());
|
||||
gsl::span<const int64_t> next_indices(beam_state.next_indices.data(), beam_state.next_indices.size());
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor<T>("next_scores before scorer", next_scores.data(), parameters_->batch_size, top_k);
|
||||
DumpTensor<int64_t>("next_tokens before scorer", next_tokens.data(), parameters_->batch_size, top_k);
|
||||
DumpTensor<int64_t>("next_indices before scorer", next_indices.data(), parameters_->batch_size, top_k);
|
||||
#endif
|
||||
|
||||
beam_scorer_->Process(
|
||||
&(beam_state.sequences),
|
||||
next_scores,
|
||||
next_tokens,
|
||||
next_indices);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::GenerateNextToken(
|
||||
const OrtValue& logits,
|
||||
gsl::span<int64_t>& beam_next_tokens,
|
||||
gsl::span<int64_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state) {
|
||||
// Process logits to get next token scores
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, allocator_));
|
||||
|
||||
gsl::span<T>& beam_scores = beam_scorer_->GetNextScores();
|
||||
// It is optional to clone beam_scores. Change it to use same buffer also works:
|
||||
// beam_state.beam_scores = beam_scores
|
||||
// Here we make a copy to reduce the coupling with little cost (the buffer size is small).
|
||||
gsl::copy(beam_scores, beam_state.beam_scores);
|
||||
|
||||
beam_next_tokens = beam_scorer_->GetNextTokens();
|
||||
beam_indices = beam_scorer_->GetNextIndices();
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpTensor<T>("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
DumpTensor<int64_t>("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
DumpTensor<int64_t>("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
#endif
|
||||
|
||||
beam_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
beam_state.sequences.PrintSequences();
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
gsl::span<int64_t>& next_positions,
|
||||
gsl::span<const int64_t> beam_next_tokens,
|
||||
gsl::span<const int64_t> beam_indices) {
|
||||
return gpt_subgraph_.UpdateFeeds(last_outputs, next_inputs, current_length, next_positions,
|
||||
beam_next_tokens, beam_indices, parameters_->num_beams);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& ffm) {
|
||||
auto status = Status::OK();
|
||||
|
||||
std::vector<int64_t> sequences_dims{parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length};
|
||||
TensorShape sequences_shape(sequences_dims);
|
||||
Tensor* output_sequences = context_.Output(0, sequences_shape);
|
||||
|
||||
std::vector<int64_t> sequences_scores_dims{parameters_->batch_size, parameters_->num_return_sequences};
|
||||
TensorShape sequences_scores_shape(sequences_scores_dims);
|
||||
Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape);
|
||||
|
||||
std::vector<int64_t> scores_dims{
|
||||
parameters_->max_length - parameters_->sequence_length,
|
||||
parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size};
|
||||
TensorShape scores_shape(scores_dims);
|
||||
Tensor* output_scores = context_.Output(2, scores_shape);
|
||||
|
||||
// Update the flag to indicate whether scores exists in output
|
||||
parameters_->output_scores = (output_scores != nullptr);
|
||||
|
||||
std::vector<OrtValue> feeds;
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// Initialize resources
|
||||
AllocatorPtr temp_space_allocator;
|
||||
ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator));
|
||||
|
||||
BeamSearchState<T> beam_state;
|
||||
beam_state.Init(temp_space_allocator,
|
||||
parameters_->batch_size,
|
||||
parameters_->num_beams,
|
||||
parameters_->vocab_size,
|
||||
parameters_->sequence_length,
|
||||
parameters_->max_length,
|
||||
parameters_->output_scores);
|
||||
|
||||
beam_scorer_ = std::make_unique<BeamSearchScorer<T>>(parameters_->batch_size,
|
||||
parameters_->num_beams,
|
||||
parameters_->max_length,
|
||||
parameters_->length_penalty,
|
||||
parameters_->early_stopping,
|
||||
parameters_->num_return_sequences,
|
||||
parameters_->pad_token_id,
|
||||
parameters_->eos_token_id);
|
||||
beam_scorer_->Initialize(allocator_, parameters_->sequence_length); // TODO: use temp_space_allocator
|
||||
|
||||
CreateInitialFeeds(beam_state.next_positions, feeds);
|
||||
const OrtValue& input_ids = feeds[0];
|
||||
beam_state.sequences.Init(temp_space_allocator,
|
||||
input_ids,
|
||||
parameters_->BatchBeamSize(),
|
||||
parameters_->sequence_length,
|
||||
parameters_->max_length);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpOrtValue("input_ids", input_ids);
|
||||
DumpOrtValue("position_ids", feeds[1]);
|
||||
DumpOrtValue("attention_mask", feeds[2]);
|
||||
#endif
|
||||
|
||||
int current_length = parameters_->sequence_length;
|
||||
while (current_length < parameters_->max_length) {
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpString("***CurrentLength", std::to_string(current_length), true);
|
||||
#endif
|
||||
|
||||
status = utils::ExecuteSubgraph(session_state_, ffm, feeds, fetches, {},
|
||||
ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger());
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
const OrtValue& logits = fetches[0];
|
||||
gsl::span<int64_t> beam_next_tokens;
|
||||
gsl::span<int64_t> beam_indices;
|
||||
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state));
|
||||
|
||||
// When all batches are finished, stop earlier to avoid wasting computation.
|
||||
if (beam_scorer_->IsDone()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Increase sequence length after a new token is generated.
|
||||
++current_length;
|
||||
|
||||
// Prepare inputs for next round of subgraph call.
|
||||
if (current_length < parameters_->max_length) {
|
||||
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
|
||||
beam_state.next_positions,
|
||||
beam_next_tokens.as_span<const int64_t>(),
|
||||
beam_indices.as_span<const int64_t>()));
|
||||
}
|
||||
fetches.clear();
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
if (current_length - parameters_->sequence_length == 3) { // only dump a few steps.
|
||||
DisableTensorDump();
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
gsl::span<const T> beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
|
||||
beam_scorer_->Finalize(&(beam_state.sequences),
|
||||
beam_scores,
|
||||
output_sequences,
|
||||
output_sequences_scores);
|
||||
|
||||
// Output per token scores
|
||||
if (output_scores != nullptr) {
|
||||
gsl::span<T> target = output_scores->MutableDataAsSpan<T>();
|
||||
gsl::span<const T> source = gsl::span<const T>(beam_state.scores.data(), beam_state.scores.size());
|
||||
assert(target.length() == source.length());
|
||||
gsl::copy(source, target);
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
template class BeamSearchImpl<float>;
|
||||
template class BeamSearch<float>;
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
48
onnxruntime/contrib_ops/cpu/transformers/beam_search.h
Normal file
48
onnxruntime/contrib_ops/cpu/transformers/beam_search.h
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <functional>
|
||||
#include "gsl/gsl"
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/feeds_fetches_manager.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
#include "beam_search_parameters.h"
|
||||
#include "beam_search_scorer.h"
|
||||
#include "gpt_subgraph.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
class BeamSearch : public controlflow::IControlFlowKernel {
|
||||
public:
|
||||
BeamSearch(const OpKernelInfo& info) : IControlFlowKernel(info) { Init(info); }
|
||||
void Init(const OpKernelInfo& info);
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
|
||||
Status SetupSubgraphExecutionInfo(const SessionState& session_state,
|
||||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) override;
|
||||
|
||||
static std::unique_ptr<OpKernel> Create(const OpKernelInfo& info, void* stream);
|
||||
|
||||
protected:
|
||||
void SetComputeStream(void* stream) { stream_ = stream; }
|
||||
|
||||
private:
|
||||
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
|
||||
std::unique_ptr<GptSubgraph> gpt_subgraph_;
|
||||
FeedsFetchesManager* feeds_fetches_manager_;
|
||||
|
||||
void* stream_;
|
||||
|
||||
BeamSearchParameters parameters_;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "beam_search_parameters.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
constexpr int kMaxSequenceLength = 4096;
|
||||
|
||||
Status BeamSearchParameters::Validate() const {
|
||||
ORT_RETURN_IF(eos_token_id < 0, "eos_token_id is invalid");
|
||||
ORT_RETURN_IF(pad_token_id < 0, "pad_token_id is invalid");
|
||||
ORT_RETURN_IF(min_length >= max_length, "min_length shall be smaller than max_length");
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
|
||||
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
|
||||
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
|
||||
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
|
||||
no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
|
||||
}
|
||||
|
||||
void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
||||
ORT_ENFORCE(context != nullptr);
|
||||
const Tensor* input_ids = context->Input<Tensor>(0);
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
|
||||
batch_size = static_cast<int>(dims[0]);
|
||||
sequence_length = static_cast<int>(dims[1]);
|
||||
|
||||
auto* max_length_tensor = context->Input<Tensor>(1);
|
||||
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
|
||||
ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
|
||||
ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
|
||||
|
||||
auto* min_length_tensor = context->Input<Tensor>(2);
|
||||
min_length = min_length_tensor ? static_cast<int>(*min_length_tensor->Data<int32_t>()) : 0;
|
||||
|
||||
auto* num_beams_tensor = context->Input<Tensor>(3);
|
||||
num_beams = num_beams_tensor ? static_cast<int>(*num_beams_tensor->Data<int32_t>()) : 1;
|
||||
// TODO: limit num_beams > 1 when we can have another operator for greedy search.
|
||||
ORT_ENFORCE(num_beams >= 1, "num_beams shall be a positive integer, got ", num_beams);
|
||||
|
||||
auto* num_return_sequences_tensor = context->Input<Tensor>(4);
|
||||
num_return_sequences = num_return_sequences_tensor ? static_cast<int>(*num_return_sequences_tensor->Data<int32_t>()) : 1;
|
||||
ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences);
|
||||
ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
|
||||
|
||||
auto* temperature_tensor = context->Input<Tensor>(5);
|
||||
temperature = temperature_tensor ? static_cast<float>(*temperature_tensor->Data<float>()) : 1;
|
||||
ORT_ENFORCE(temperature > 0.0f, "temperature shall be greater than 0, got ", temperature);
|
||||
|
||||
auto* length_penalty_tensor = context->Input<Tensor>(6);
|
||||
length_penalty = length_penalty_tensor ? static_cast<float>(*length_penalty_tensor->Data<float>()) : 1;
|
||||
|
||||
auto* repetition_penalty_tensor = context->Input<Tensor>(7);
|
||||
repetition_penalty = repetition_penalty_tensor ? static_cast<float>(*repetition_penalty_tensor->Data<float>()) : 1.0f;
|
||||
ORT_ENFORCE(repetition_penalty > 0.0f, "repetition_penalty shall be greater than 0, got ", repetition_penalty);
|
||||
}
|
||||
|
||||
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
|
||||
vocab_size = vocabulary_size;
|
||||
num_heads = heads;
|
||||
head_size = hidden_size_per_head;
|
||||
num_layers = layers;
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
struct BeamSearchParameters {
|
||||
// Parameters from node attributes
|
||||
int eos_token_id;
|
||||
int pad_token_id;
|
||||
int no_repeat_ngram_size;
|
||||
bool early_stopping;
|
||||
|
||||
// Parameters from inputs
|
||||
int min_length;
|
||||
int max_length;
|
||||
int num_beams;
|
||||
int num_return_sequences;
|
||||
float temperature;
|
||||
float length_penalty;
|
||||
float repetition_penalty;
|
||||
int batch_size; // deduce from first dimension of input_ids
|
||||
int sequence_length; // deduce from second dimension of input_ids
|
||||
|
||||
gsl::span<const int32_t> vocab_mask;
|
||||
|
||||
// Parameters from outputs.
|
||||
bool output_scores; // whether scores existed in output
|
||||
|
||||
// Parameters from subgraph.
|
||||
int vocab_size;
|
||||
// Below are used in CPU, reserved for CUDA.
|
||||
int num_heads;
|
||||
int head_size;
|
||||
int num_layers;
|
||||
|
||||
Status Validate() const;
|
||||
|
||||
int BatchBeamSize() const { return batch_size * num_beams; }
|
||||
|
||||
void ParseFromAttributes(const OpKernelInfo& info);
|
||||
|
||||
void ParseFromInputs(OpKernelContext* context);
|
||||
|
||||
void SetSubgraphParameters(int vocab_size, int num_heads, int head_size, int num_layers);
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
285
onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
Normal file
285
onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc
Normal file
|
|
@ -0,0 +1,285 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <queue>
|
||||
#include <math.h>
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/cpu/rnn/rnn_helpers.h"
|
||||
#include "beam_search_scorer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
using ::onnxruntime::rnn::detail::Allocate;
|
||||
|
||||
template <typename T>
|
||||
BeamHypotheses<T>::BeamHypotheses(int num_beams, T length_penalty, bool early_stopping)
|
||||
: num_beams_(num_beams),
|
||||
length_penalty_(length_penalty),
|
||||
early_stopping_(early_stopping),
|
||||
worst_score_(1e9) {}
|
||||
|
||||
template <typename T>
|
||||
void BeamHypotheses<T>::Add(gsl::span<const int64_t>& hypothesis, T sum_logprobs) {
|
||||
auto length = hypothesis.size();
|
||||
// TODO: when T is FP16, compute in FP32, then cast result back to FP16. length_penalty_ might also be float.
|
||||
T score = sum_logprobs / pow(static_cast<T>(length), length_penalty_);
|
||||
|
||||
if (this->Size() < num_beams_ || score > worst_score_) {
|
||||
HypothesisScore<T> item(hypothesis, score);
|
||||
beams_.push(item);
|
||||
if (this->Size() > num_beams_) {
|
||||
beams_.pop();
|
||||
}
|
||||
worst_score_ = beams_.top().score;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BeamHypotheses<T>::IsDone(T best_sum_logprobs, int current_length) {
|
||||
// If there are enough hypotheses and that none of the hypotheses being generated can become better
|
||||
// than the worst one in the heap, then we are done with this sentence.
|
||||
|
||||
if (Size() < num_beams_)
|
||||
return false;
|
||||
|
||||
if (early_stopping_)
|
||||
return true;
|
||||
|
||||
T current_score = best_sum_logprobs / pow(static_cast<T>(current_length), length_penalty_);
|
||||
return worst_score_ >= current_score;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BeamHypotheses<T>::Output(
|
||||
int top_k,
|
||||
int max_length,
|
||||
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, shape (num_return_sequences, max_length)
|
||||
gsl::span<T>& sequences_scores) // buffer of shape (num_return_sequences) or empty
|
||||
{
|
||||
ORT_ENFORCE(top_k <= Size());
|
||||
int remove_count = Size() - top_k;
|
||||
for (int i = 0; i < remove_count; i++) {
|
||||
beams_.pop();
|
||||
}
|
||||
|
||||
// Since pop get the worst sequence, so output it in the reverse order.
|
||||
// The frist (worst) beam shall be put at the last position among top_k sequences.
|
||||
int index = top_k - 1;
|
||||
while (!beams_.empty()) {
|
||||
auto item = beams_.top();
|
||||
gsl::span<const int64_t>& source = item.hypothesis;
|
||||
gsl::span<int32_t> target = sequences.subspan(index * max_length, max_length);
|
||||
|
||||
// Note that word_ids might be less than max_length.
|
||||
// Since the sequences has been filled with pad token ID, so padding is not needed here.
|
||||
// Since data type need cast from int64_t to int32_t, we cannot use gsl::copy(word_ids, sequence) here.
|
||||
for (size_t i = 0; i < source.length(); i++) {
|
||||
target[i] = static_cast<int32_t>(source[i]);
|
||||
}
|
||||
|
||||
if (!sequences_scores.empty())
|
||||
sequences_scores[index] = item.score;
|
||||
|
||||
beams_.pop();
|
||||
index--;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
BeamSearchScorer<T>::BeamSearchScorer(int batch_size,
|
||||
int num_beams,
|
||||
int max_length,
|
||||
T length_penalty,
|
||||
bool early_stopping,
|
||||
int num_return_sequences,
|
||||
int pad_token_id,
|
||||
int eos_token_id)
|
||||
: batch_size_(batch_size),
|
||||
num_beams_(num_beams),
|
||||
max_length_(max_length),
|
||||
num_beam_hyps_to_keep_(num_return_sequences),
|
||||
pad_token_id_(pad_token_id),
|
||||
eos_token_id_(eos_token_id),
|
||||
hypothesis_buffer_length_(0),
|
||||
hypothesis_buffer_offset_(0) {
|
||||
for (int batch = 0; batch < batch_size; batch++) {
|
||||
beam_hyps.push_back(BeamHypotheses(num_beams, length_penalty, early_stopping));
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
bool BeamSearchScorer<T>::IsDone() {
|
||||
for (int batch = 0; batch < batch_size_; batch++) {
|
||||
if (!done_[batch])
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BeamSearchScorer<T>::Initialize(AllocatorPtr& allocator, int sequence_length){
|
||||
ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once.
|
||||
|
||||
size_t batch_beam_size = static_cast<size_t>(batch_size_ * num_beams_);
|
||||
const bool no_fill = false; // do not fill values after allocation
|
||||
next_beam_scores_ = Allocate<T>(allocator, batch_beam_size, next_beam_scores_ptr_, no_fill);
|
||||
next_beam_tokens_ = Allocate<int64_t>(allocator, batch_beam_size, next_beam_tokens_ptr_, no_fill);
|
||||
next_beam_indices_ = Allocate<int64_t>(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill);
|
||||
|
||||
// Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
|
||||
int buffer_per_beam = (max_length_ * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2;
|
||||
hypothesis_buffer_length_ = batch_beam_size * static_cast<size_t>(buffer_per_beam);
|
||||
hypothesis_buffer_ = Allocate<int64_t>(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill);
|
||||
|
||||
done_ = Allocate<bool>(allocator, static_cast<size_t>(batch_size_), done_ptr_, no_fill);
|
||||
std::fill_n(done_.data(), done_.size(), false);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BeamSearchScorer<T>::Process(ISequences* sequences,
|
||||
gsl::span<const T>& next_scores,
|
||||
gsl::span<const int64_t>& next_tokens,
|
||||
gsl::span<const int64_t>& next_indices) {
|
||||
// Sequences shape is (batch_size * num_beams, total_sequence_length)
|
||||
// It contains word ID of whole sequence generated so far.
|
||||
// It is different from subgraph input_ids, which only need one word when past state is not empty.
|
||||
|
||||
const int sequence_length = sequences->GetSequenceLength();
|
||||
|
||||
ORT_ENFORCE(next_scores.size() == next_tokens.size());
|
||||
ORT_ENFORCE(next_scores.size() == next_indices.size());
|
||||
|
||||
for (int batch = 0; batch < batch_size_; batch++) {
|
||||
BeamHypotheses<T>& beam_hyp = beam_hyps[batch];
|
||||
if (done_[batch]) {
|
||||
ORT_ENFORCE(beam_hyp.Size() >= num_beams_, "Batch can only be done if all beams have been generated");
|
||||
|
||||
// Pad the batch.
|
||||
for (int j = 0; j < num_beams_; j++) {
|
||||
next_beam_scores_[batch * num_beams_ + j] = 0.0f;
|
||||
next_beam_tokens_[batch * num_beams_ + j] = pad_token_id_;
|
||||
next_beam_indices_[batch * num_beams_ + j] = 0;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// Next tokens for this sentence.
|
||||
int beam_idx = 0;
|
||||
int top_k = 2 * num_beams_;
|
||||
for (int j = 0; j < top_k; j++) {
|
||||
int64_t next_token = next_tokens[batch * top_k + j];
|
||||
T next_score = next_scores[batch * top_k + j];
|
||||
int64_t next_index = next_indices[batch * top_k + j];
|
||||
|
||||
int batch_beam_idx = batch * num_beams_ + static_cast<int>(next_index);
|
||||
// Add to generated hypotheses if end of sentence.
|
||||
if ((eos_token_id_ >= 0) && (next_token == eos_token_id_)) {
|
||||
bool is_beam_token_worse_than_top_num_beams = (j >= num_beams_);
|
||||
if (is_beam_token_worse_than_top_num_beams) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Clone the sequence and append to buffer.
|
||||
gsl::span<const int64_t> src = sequences->GetSequence(batch_beam_idx);
|
||||
auto clone = hypothesis_buffer_.subspan(hypothesis_buffer_offset_, sequence_length);
|
||||
gsl::copy(src, clone);
|
||||
hypothesis_buffer_offset_ += sequence_length;
|
||||
auto sequence = clone.template as_span<const int64_t>();
|
||||
beam_hyp.Add(sequence, next_score);
|
||||
} else {
|
||||
// Add next predicted token since it is not eos_token.
|
||||
next_beam_scores_[batch * num_beams_ + beam_idx] = next_score;
|
||||
next_beam_tokens_[batch * num_beams_ + beam_idx] = next_token;
|
||||
next_beam_indices_[batch * num_beams_ + beam_idx] = batch_beam_idx;
|
||||
++beam_idx;
|
||||
}
|
||||
|
||||
// Once the beam for next step is full, don't add more tokens to it.
|
||||
if (beam_idx == num_beams_)
|
||||
break;
|
||||
}
|
||||
|
||||
ORT_ENFORCE(beam_idx == num_beams_);
|
||||
ORT_ENFORCE(hypothesis_buffer_offset_ <= batch_size_ * num_beams_ * max_length_);
|
||||
|
||||
// Check if we are done so that we can save a pad step if all(done)
|
||||
if (!done_[batch]) {
|
||||
gsl::span<const T> topk_scores = next_scores.subspan(batch * num_beams_, top_k);
|
||||
const T* best_sum_logprobs = std::max_element(topk_scores.begin(), topk_scores.end());
|
||||
if (beam_hyp.IsDone(*best_sum_logprobs, sequence_length)) {
|
||||
done_[batch] = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void BeamSearchScorer<T>::Finalize(ISequences* sequences,
|
||||
gsl::span<const T>& final_beam_scores,
|
||||
Tensor* output_sequences,
|
||||
Tensor* output_sequence_scores) {
|
||||
ORT_ENFORCE(sequences != nullptr);
|
||||
ORT_ENFORCE(output_sequences != nullptr);
|
||||
|
||||
// Finalize all open beam hypotheses and add to generated hypotheses.
|
||||
for (int batch_index = 0; batch_index < batch_size_; batch_index++) {
|
||||
BeamHypotheses<T>& beam_hyp = beam_hyps[batch_index];
|
||||
if (done_[batch_index]) {
|
||||
continue;
|
||||
}
|
||||
|
||||
for (int beam_index = 0; beam_index < num_beams_; beam_index++) {
|
||||
int batch_beam_index = batch_index * num_beams_ + beam_index;
|
||||
T final_score = final_beam_scores[batch_beam_index];
|
||||
auto final_tokens = sequences->GetSequence(batch_beam_index);
|
||||
beam_hyp.Add(final_tokens, final_score);
|
||||
}
|
||||
}
|
||||
|
||||
// Word IDs of each sequence, with shape (batch_size * num_return_sequences, max_sequence_length).
|
||||
gsl::span<int32_t> output = output_sequences->MutableDataAsSpan<int32_t>();
|
||||
|
||||
// Fill output sequences with pad token ID so that we do not need append it later.
|
||||
std::fill_n(output.data(), output.size(), pad_token_id_);
|
||||
|
||||
// Score of each sequence, with shape (batch_size * num_return_sequences).
|
||||
gsl::span<T> sequence_scores;
|
||||
if (output_sequence_scores != nullptr) {
|
||||
sequence_scores = output_sequence_scores->MutableDataAsSpan<T>();
|
||||
}
|
||||
|
||||
// Span is empty when output_sequence_scores is NULL.
|
||||
gsl::span<T> batch_sequence_score;
|
||||
|
||||
// Select the best hypotheses according to number of sequences to return.
|
||||
for (int batch_index = 0; batch_index < batch_size_; batch_index++) {
|
||||
BeamHypotheses<T>& beam_hyp = beam_hyps[batch_index];
|
||||
|
||||
const int num_return_sequences = num_beam_hyps_to_keep_;
|
||||
auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_);
|
||||
|
||||
if (output_sequence_scores != nullptr) {
|
||||
batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences);
|
||||
}
|
||||
|
||||
beam_hyp.Output(
|
||||
num_return_sequences,
|
||||
max_length_,
|
||||
batch_output,
|
||||
batch_sequence_score);
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
template class HypothesisScoreCompare<float>;
|
||||
template class BeamHypotheses<float>;
|
||||
template class BeamSearchScorer<float>;
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
144
onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
Normal file
144
onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// The implementation is based on huggingface transformers generation_beam_search.py
|
||||
|
||||
#pragma once
|
||||
#include <queue>
|
||||
#include <math.h>
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "sequences.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// Interface for all scorers for beam search or beam sample.
|
||||
template <typename T>
|
||||
class IBeamScorer {
|
||||
public:
|
||||
virtual ~IBeamScorer() {}
|
||||
|
||||
virtual void Initialize(AllocatorPtr& allocator, int sequence_length) = 0;
|
||||
|
||||
virtual void Process(ISequences* sequences,
|
||||
gsl::span<const T>& next_scores,
|
||||
gsl::span<const int64_t>& next_tokens,
|
||||
gsl::span<const int64_t>& next_indices) = 0;
|
||||
|
||||
virtual void Finalize(ISequences* sequences,
|
||||
gsl::span<const T>& final_beam_scores,
|
||||
Tensor* output_sequences,
|
||||
Tensor* output_sequence_scores) = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct HypothesisScore {
|
||||
HypothesisScore(gsl::span<const int64_t>& _hypothesis, T _score)
|
||||
: hypothesis(_hypothesis), score(_score) {}
|
||||
|
||||
gsl::span<const int64_t> hypothesis;
|
||||
T score;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class HypothesisScoreCompare {
|
||||
public:
|
||||
bool operator()(const HypothesisScore<T>& a, const HypothesisScore<T>& b) {
|
||||
return a.score > b.score;
|
||||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BeamHypotheses {
|
||||
public:
|
||||
BeamHypotheses(int num_beams, T length_penalty, bool early_stopping);
|
||||
|
||||
// Number of hypotheses
|
||||
int Size() { return static_cast<int>(beams_.size()); }
|
||||
|
||||
// Add a new hypothesis
|
||||
void Add(gsl::span<const int64_t>& hypothesis, T sum_logprobs);
|
||||
|
||||
bool IsDone(T best_sum_logprobs, int current_length);
|
||||
|
||||
// Output results. Note that it will clear all beams.
|
||||
void Output(int top_k, // number of sequences to return
|
||||
int max_length, // max sequence length
|
||||
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length)
|
||||
gsl::span<T>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
|
||||
|
||||
private:
|
||||
int num_beams_;
|
||||
T length_penalty_;
|
||||
bool early_stopping_;
|
||||
T worst_score_;
|
||||
std::priority_queue<HypothesisScore<T>, std::vector<HypothesisScore<T>>, HypothesisScoreCompare<T>> beams_; // min-heap for top k
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BeamSearchScorer : public IBeamScorer<T> {
|
||||
public:
|
||||
BeamSearchScorer(int batch_size,
|
||||
int num_beams,
|
||||
int max_length,
|
||||
T length_penalty,
|
||||
bool early_stopping,
|
||||
int num_return_sequences,
|
||||
int pad_token_id,
|
||||
int eos_token_id);
|
||||
|
||||
void Initialize(AllocatorPtr& allocator, int sequence_length) override;
|
||||
|
||||
void Process(ISequences* sequences,
|
||||
gsl::span<const T>& next_scores,
|
||||
gsl::span<const int64_t>& next_tokens,
|
||||
gsl::span<const int64_t>& next_indices) override;
|
||||
|
||||
void Finalize(ISequences* sequences,
|
||||
gsl::span<const T>& final_beam_scores,
|
||||
Tensor* output_sequences,
|
||||
Tensor* output_sequence_scores) override;
|
||||
|
||||
bool IsDone();
|
||||
|
||||
gsl::span<T>& GetNextScores() { return next_beam_scores_; }
|
||||
gsl::span<int64_t>& GetNextTokens() { return next_beam_tokens_; }
|
||||
gsl::span<int64_t>& GetNextIndices() { return next_beam_indices_; }
|
||||
|
||||
private:
|
||||
int batch_size_;
|
||||
int num_beams_;
|
||||
int max_length_;
|
||||
int num_beam_hyps_to_keep_;
|
||||
int pad_token_id_;
|
||||
int eos_token_id_;
|
||||
|
||||
// TODO: use ORT allocator to avoid allocating from heap directly
|
||||
std::vector<BeamHypotheses<T>> beam_hyps; // List of batch result of beam search. Its shape is (batch_size)
|
||||
|
||||
IAllocatorUniquePtr<bool> done_ptr_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size).
|
||||
gsl::span<bool> done_;
|
||||
|
||||
IAllocatorUniquePtr<T> next_beam_scores_ptr_;
|
||||
gsl::span<T> next_beam_scores_;
|
||||
|
||||
IAllocatorUniquePtr<int64_t> next_beam_tokens_ptr_;
|
||||
gsl::span<int64_t> next_beam_tokens_;
|
||||
|
||||
IAllocatorUniquePtr<int64_t> next_beam_indices_ptr_;
|
||||
gsl::span<int64_t> next_beam_indices_;
|
||||
|
||||
IAllocatorUniquePtr<int64_t> hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses
|
||||
gsl::span<int64_t> hypothesis_buffer_; // Span of the allocated buffer
|
||||
size_t hypothesis_buffer_length_; // Total number of elements
|
||||
int hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer.
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
73
onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
Normal file
73
onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "dump_tensor.h"
|
||||
#include "core/platform/env.h"
|
||||
#include "core/platform/env_var_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
namespace dump_tensor_env_vars {
|
||||
constexpr const char* kDumpBeamSearch = "ORT_DUMP_BEAM_SEARCH";
|
||||
}
|
||||
|
||||
#ifdef NDEBUG
|
||||
bool g_enable_tensor_dump = false;
|
||||
#else
|
||||
bool g_enable_tensor_dump = true;
|
||||
#endif
|
||||
|
||||
void DumpOrtValue(const char* name, const OrtValue& value) {
|
||||
if (!g_enable_tensor_dump)
|
||||
return;
|
||||
std::cout << std::string(name) << "\n";
|
||||
const Tensor& tensor = value.Get<Tensor>();
|
||||
MLDataType dataType = tensor.DataType();
|
||||
if (dataType == DataTypeImpl::GetType<float>()) {
|
||||
DumpTensor<float>(nullptr, tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<int32_t>()) {
|
||||
DumpTensor<int32_t>(nullptr, tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
|
||||
DumpTensor<int64_t>(nullptr, tensor);
|
||||
} else {
|
||||
std::cout << "not float/int32/int64";
|
||||
}
|
||||
}
|
||||
|
||||
void ConfigureTensorDump() {
|
||||
const auto parsed = ParseEnvironmentVariable<bool>(dump_tensor_env_vars::kDumpBeamSearch);
|
||||
if (parsed.has_value()) {
|
||||
g_enable_tensor_dump = *parsed;
|
||||
}
|
||||
}
|
||||
|
||||
void DisableTensorDump() {
|
||||
g_enable_tensor_dump = false;
|
||||
}
|
||||
|
||||
void DumpString(const char* name, int index, bool end_line) {
|
||||
if (!g_enable_tensor_dump)
|
||||
return;
|
||||
std::cout << std::string(name) << "[" << index << "]";
|
||||
|
||||
if (end_line) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void DumpString(const char* name, std::string value, bool end_line) {
|
||||
if (!g_enable_tensor_dump)
|
||||
return;
|
||||
|
||||
std::cout << std::string(name) << "=" << value;
|
||||
|
||||
if (end_line) {
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
148
onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
Normal file
148
onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h
Normal file
|
|
@ -0,0 +1,148 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <iomanip>
|
||||
#include <string>
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
|
||||
#ifndef NDEBUG
|
||||
//#define DEBUG_BEAM_SEARCH 1 // uncomment it for debugging beam search
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
#define MAX_ROW_OR_COLUMN 8
|
||||
|
||||
#define SKIP_IF_MORE_THAN(row_or_column_size, i, max_n, new_line) \
|
||||
if (row_or_column_size > max_n && i >= max_n / 2 && i + max_n / 2 < row_or_column_size) { \
|
||||
if (i == max_n / 2) { \
|
||||
std::cout << ", ..."; \
|
||||
if (new_line) \
|
||||
std::cout << std::endl; \
|
||||
} \
|
||||
continue; \
|
||||
}
|
||||
|
||||
#define SKIP_IF_TOO_MANY(row_or_column_size, i, new_line) SKIP_IF_MORE_THAN(row_or_column_size, i, MAX_ROW_OR_COLUMN, new_line)
|
||||
|
||||
extern bool g_enable_tensor_dump; // global variance to turn on/off dump
|
||||
|
||||
template <typename T>
|
||||
void PrintValue(const T& value) {
|
||||
if (std::is_floating_point<T>::value)
|
||||
std::cout << std::setprecision(8) << value;
|
||||
else
|
||||
std::cout << value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void DumpTensor(const char* name, const Tensor& tensor) {
|
||||
if (!g_enable_tensor_dump)
|
||||
return;
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
}
|
||||
|
||||
const auto& shape = tensor.Shape();
|
||||
auto num_items = shape.Size();
|
||||
|
||||
if (num_items == 0) {
|
||||
std::cout << "no data";
|
||||
return;
|
||||
}
|
||||
|
||||
size_t num_dims = shape.NumDimensions();
|
||||
size_t num_rows = 1;
|
||||
if (num_dims > 1) {
|
||||
num_rows = static_cast<size_t>(shape[0]);
|
||||
}
|
||||
|
||||
size_t row_size = num_items / num_rows;
|
||||
|
||||
auto data = tensor.DataAsSpan<T>();
|
||||
|
||||
for (size_t row = 0; row < num_rows; ++row) {
|
||||
SKIP_IF_TOO_MANY(num_rows, row, true);
|
||||
std::cout << "[" << row << "]:";
|
||||
for (size_t i = 0; i < row_size; ++i) {
|
||||
SKIP_IF_TOO_MANY(row_size, i, false);
|
||||
|
||||
if (i > 0)
|
||||
std::cout << ", ";
|
||||
|
||||
PrintValue(data[row * row_size + i]);
|
||||
}
|
||||
std::cout << "\n";
|
||||
}
|
||||
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void DumpOrtValue(const char* name, const OrtValue& value);
|
||||
|
||||
template <typename T>
|
||||
void DumpTensor(const char* name, const T* tensor, int dim0, int dim1) {
|
||||
if (!g_enable_tensor_dump)
|
||||
return;
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
}
|
||||
|
||||
for (int i = 0; i < dim0; i++) {
|
||||
SKIP_IF_TOO_MANY(dim0, i, true);
|
||||
std::cout << "[" << i << "]:";
|
||||
for (int j = 0; j < dim1; j++) {
|
||||
SKIP_IF_TOO_MANY(dim1, j, false);
|
||||
if (j > 0)
|
||||
std::cout << ", ";
|
||||
T value = tensor[i * dim1 + j];
|
||||
PrintValue<T>(value);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
void DumpString(const char* name, int index, bool end_line);
|
||||
|
||||
void DumpString(const char* name, std::string value, bool end_line);
|
||||
|
||||
template <typename T>
|
||||
void DumpTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) {
|
||||
if (!g_enable_tensor_dump)
|
||||
return;
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
}
|
||||
|
||||
for (int i = 0; i < dim0; i++) {
|
||||
SKIP_IF_TOO_MANY(dim0, i, true);
|
||||
for (int j = 0; j < dim1; j++) {
|
||||
SKIP_IF_TOO_MANY(dim1, j, true);
|
||||
std::cout << "[" << i << "][" << j << "]:";
|
||||
for (int k = 0; k < dim2; k++) {
|
||||
SKIP_IF_TOO_MANY(dim2, k, false);
|
||||
if (k > 0)
|
||||
std::cout << ", ";
|
||||
T value = tensor[i * dim1 * dim2 + j * dim2 + k];
|
||||
PrintValue<T>(value);
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
std::cout << std::endl;
|
||||
}
|
||||
|
||||
void ConfigureTensorDump();
|
||||
|
||||
void DisableTensorDump();
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
451
onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc
Normal file
451
onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc
Normal file
|
|
@ -0,0 +1,451 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// there's no way to use a raw pointer as the copy destination with std::copy_n
|
||||
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
|
||||
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "gsl/gsl"
|
||||
#include "gpt_subgraph.h"
|
||||
#include "dump_tensor.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(pop)
|
||||
#endif
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
GptSubgraph::GptSubgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in)
|
||||
: node(node_in), attribute(attribute_name), subgraph(subgraph_in), allocator_(nullptr) {
|
||||
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
|
||||
|
||||
auto& subgraph_inputs = subgraph.GetInputs();
|
||||
auto& subgraph_outputs = subgraph.GetOutputs();
|
||||
|
||||
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
|
||||
// outputs: logits, present_0, present_1, ...
|
||||
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
|
||||
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
|
||||
|
||||
// CheckSubgraph will verify inputs and outputs later.
|
||||
subgraph_input_names.reserve(num_subgraph_inputs);
|
||||
for (int i = 0; i < num_subgraph_inputs; ++i) {
|
||||
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
|
||||
}
|
||||
|
||||
subgraph_output_names.reserve(num_subgraph_outputs);
|
||||
for (int i = 0; i < num_subgraph_outputs; ++i) {
|
||||
subgraph_output_names.push_back(subgraph_outputs[i]->Name());
|
||||
}
|
||||
}
|
||||
|
||||
Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) {
|
||||
ORT_RETURN_IF(num_subgraph_outputs <= 1,
|
||||
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs).");
|
||||
|
||||
ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2,
|
||||
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2");
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ",
|
||||
subgraph_inputs[0]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ",
|
||||
subgraph_inputs[1]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ",
|
||||
subgraph_inputs[2]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ",
|
||||
subgraph_inputs[3]->Name());
|
||||
|
||||
// Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape();
|
||||
ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ",
|
||||
past_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
|
||||
"subgraph past state dimension 0 shall have length of 2");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for number of heads");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
|
||||
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
|
||||
|
||||
// check subgraph outputs
|
||||
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ",
|
||||
subgraph_outputs[0]->Name());
|
||||
|
||||
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ",
|
||||
subgraph_outputs[1]->Name());
|
||||
|
||||
// Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
|
||||
ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ",
|
||||
logits_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
|
||||
|
||||
// Save parameters related to the subgraph.
|
||||
num_heads = static_cast<int>(past_shape->dim(2).dim_value());
|
||||
head_size = static_cast<int>(past_shape->dim(4).dim_value());
|
||||
vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
|
||||
num_layers = static_cast<int>(subgraph_outputs.size()) - 1;
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
|
||||
"subgraph input 0 (input_ids) shall have int64 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT64,
|
||||
"subgraph input 1 (position_ids) shall have int64 type");
|
||||
// TODO: support float16
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
"subgraph input 2 (attention_mask) shall have float type");
|
||||
ORT_RETURN_IF(subgraph_inputs[3]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
"subgraph input 3 (past_0) shall have float type");
|
||||
ORT_RETURN_IF(subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
"subgraph output 0 (logits) shall have float type");
|
||||
ORT_RETURN_IF(subgraph_outputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT,
|
||||
"subgraph output 1 (present_0) shall have float type");
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GptSubgraph::Setup(const SessionState& session_state,
|
||||
const SessionState& subgraph_session_state) {
|
||||
session_state_ = &session_state;
|
||||
subgraph_session_state_ = &subgraph_session_state;
|
||||
|
||||
std::vector<std::string> feed_names;
|
||||
feed_names.reserve(num_subgraph_inputs + num_implicit_inputs);
|
||||
|
||||
// First, get the location of input_ids of current operator.
|
||||
const auto& node_inputs = node.InputDefs();
|
||||
const OrtMemoryInfo& input_ids_location = utils::FindMemoryInfoForValue(session_state, node_inputs[0]->Name());
|
||||
|
||||
// position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
|
||||
// as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids
|
||||
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
|
||||
|
||||
for (auto& entry : node.ImplicitInputDefs()) {
|
||||
feed_names.push_back(entry->Name());
|
||||
}
|
||||
|
||||
std::vector<OrtDevice> feed_locations;
|
||||
feed_locations.resize(feed_names.size());
|
||||
|
||||
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
|
||||
if (i >= subgraph_input_names.size()) { // implicit inputs
|
||||
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
|
||||
feed_locations[i] = location.device;
|
||||
} else {
|
||||
feed_locations[i] = input_ids_location.device;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<FeedsFetchesManager> ffm;
|
||||
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
|
||||
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
|
||||
|
||||
// setup the locations where we want the subgraph output to end up on
|
||||
std::vector<const OrtMemoryInfo*> fetch_locations;
|
||||
fetch_locations.reserve(num_subgraph_outputs);
|
||||
|
||||
// past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location.
|
||||
for (int i = 0; i < num_subgraph_outputs; ++i) {
|
||||
fetch_locations.push_back(&input_ids_location);
|
||||
}
|
||||
|
||||
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
|
||||
|
||||
feeds_fetches_manager_ = std::move(ffm);
|
||||
|
||||
// Check subgraph only need once so put in Setup function.
|
||||
auto& inputs = subgraph.GetInputs();
|
||||
auto& outputs = subgraph.GetOutputs();
|
||||
ORT_RETURN_IF_ERROR(Validate(inputs, outputs));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void GptSubgraph::CreateInitialFeeds(
|
||||
const Tensor& input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int64_t>& next_positions,
|
||||
std::vector<OrtValue>& feeds) {
|
||||
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
|
||||
|
||||
// Subgraph inputs:
|
||||
// input_ids: shape (B, S) wher B is batch size, and S is sequence length
|
||||
// position_ids: shape (B, S)
|
||||
// attention_mask: shape (B, P+S), where past_sequence_length (P) is 0
|
||||
// After expansion, their shapes will become (B, M*S), where M is num_beams.
|
||||
|
||||
// Allocate subgraph inputs to be same device as input_ids
|
||||
AllocatorPtr alloactor = session_state_->GetAllocator(input_ids.Location());
|
||||
|
||||
// Store allocator, which is needed in ExpandInputs.
|
||||
allocator_ = alloactor;
|
||||
|
||||
const TensorShape& input_ids_shape = input_ids.Shape();
|
||||
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
|
||||
const int64_t& batch_size = input_ids_shape[0];
|
||||
const int64_t& sequence_length = input_ids_shape[1];
|
||||
|
||||
// Allocate position_ids and attention_mask based on shape of input_ids
|
||||
auto element_type = DataTypeImpl::GetType<int64_t>();
|
||||
|
||||
// input_ids for subgraph is int64, so we need Cast input_ids from int32 to int64.
|
||||
OrtValue subgraph_input_ids;
|
||||
// Current shape is (batch_size, sequence_length)
|
||||
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, subgraph_input_ids);
|
||||
|
||||
int64_t* subgraph_input_data = subgraph_input_ids.GetMutable<Tensor>()->MutableData<int64_t>();
|
||||
const int32_t* source = input_ids.Data<int32_t>();
|
||||
int64_t* target = subgraph_input_data;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = 0; j < sequence_length; j++, source++, target++) {
|
||||
*target = static_cast<int64_t>(*source);
|
||||
}
|
||||
}
|
||||
|
||||
OrtValue position_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, position_ids);
|
||||
|
||||
OrtValue attention_mask;
|
||||
auto mask_type = DataTypeImpl::GetType<float>();
|
||||
Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask);
|
||||
|
||||
// Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
|
||||
// Set position id to be 0 for pad tokens, and cumulated sum of mask in a batch for other tokens
|
||||
float* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<float>();
|
||||
int64_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int64_t>();
|
||||
source = input_ids.Data<int32_t>();
|
||||
float* mask = mask_data;
|
||||
int64_t* position = position_data;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
int64_t abs_position = 0;
|
||||
for (int j = 0; j < sequence_length; j++, source++, mask++, position++) {
|
||||
if (*source == pad_token_id) {
|
||||
*mask = 0.0f;
|
||||
*position = 0;
|
||||
} else {
|
||||
*mask = 1.0f;
|
||||
*position = abs_position;
|
||||
abs_position++;
|
||||
}
|
||||
}
|
||||
for (int k = 0; k < num_beams; k++) {
|
||||
next_positions[i * num_beams + k] = abs_position;
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize empty past state
|
||||
auto past_type = DataTypeImpl::GetType<float>();
|
||||
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size};
|
||||
TensorShape past_shape(&past_state_dims[0], 5);
|
||||
OrtValue empty_past;
|
||||
Tensor::InitOrtValue(past_type, past_shape, allocator_, empty_past);
|
||||
|
||||
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask
|
||||
// TODO: Try expand inputs/outputs after first subgraph call instead. That may get better peroformance, but more complex to implement.
|
||||
OrtValue expanded_input_ids = ExpandInputs(subgraph_input_ids, num_beams);
|
||||
OrtValue expanded_position_ids = ExpandInputs(position_ids, num_beams);
|
||||
OrtValue expanded_attention_mask = ExpandInputs(attention_mask, num_beams);
|
||||
|
||||
// The ordering is the same as used in Setup
|
||||
feeds.reserve(num_subgraph_inputs + num_implicit_inputs);
|
||||
feeds.push_back(expanded_input_ids);
|
||||
feeds.push_back(expanded_position_ids);
|
||||
feeds.push_back(expanded_attention_mask);
|
||||
|
||||
// The remaing inputs are past state.
|
||||
for (int i = 3; i < num_subgraph_inputs; ++i) {
|
||||
feeds.push_back(empty_past);
|
||||
}
|
||||
|
||||
// pass in implicit inputs
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
feeds.push_back(*entry);
|
||||
}
|
||||
}
|
||||
|
||||
OrtValue GptSubgraph::ExpandInputs(const OrtValue& input, int num_beams) const {
|
||||
// Input shape (batch_size, sequence_length)
|
||||
// Output shape (batch_size * num_beams, sequence_length)
|
||||
if (num_beams == 1)
|
||||
return input;
|
||||
|
||||
const TensorShape& input_shape = input.Get<Tensor>().Shape();
|
||||
const int64_t& batch_size = input_shape[0];
|
||||
const int64_t& sequence_length = input_shape[1];
|
||||
|
||||
int64_t dims[] = {batch_size * num_beams, sequence_length};
|
||||
TensorShape expanded_shape(&dims[0], 2);
|
||||
|
||||
OrtValue expanded;
|
||||
MLDataType element_type = input.Get<Tensor>().DataType();
|
||||
Tensor::InitOrtValue(element_type, expanded_shape, allocator_, expanded);
|
||||
|
||||
if (element_type == DataTypeImpl::GetType<int64_t>()) {
|
||||
const int64_t* input_data = input.Get<Tensor>().Data<int64_t>();
|
||||
int64_t* expanded_data = expanded.GetMutable<Tensor>()->MutableData<int64_t>();
|
||||
int64_t* target = expanded_data;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = 0; j < num_beams; j++) {
|
||||
memcpy(target, input_data + i * sequence_length, sizeof(int64_t) * sequence_length);
|
||||
target += sequence_length;
|
||||
}
|
||||
}
|
||||
} else if (element_type == DataTypeImpl::GetType<float>()) {
|
||||
const float* input_data = input.Get<Tensor>().Data<float>();
|
||||
float* expanded_data = expanded.GetMutable<Tensor>()->MutableData<float>();
|
||||
float* target = expanded_data;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = 0; j < num_beams; j++) {
|
||||
memcpy(target, input_data + i * sequence_length, sizeof(float) * sequence_length);
|
||||
target += sequence_length;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return expanded;
|
||||
}
|
||||
|
||||
// TODO: support float16
|
||||
void GptSubgraph::PickPastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
gsl::span<const int64_t>& beam_indices) {
|
||||
for (int i = 3; i < num_subgraph_inputs; ++i) {
|
||||
const OrtValue& present = last_outputs[i - 2]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
|
||||
const TensorShape& past_shape = present.Get<Tensor>().Shape();
|
||||
|
||||
// Create a tensor with same shape.
|
||||
OrtValue past;
|
||||
auto past_type = DataTypeImpl::GetType<float>();
|
||||
Tensor::InitOrtValue(past_type, past_shape, allocator_, past);
|
||||
|
||||
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
|
||||
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
|
||||
|
||||
gsl::span<float> past_span = past.GetMutable<Tensor>()->MutableDataAsSpan<float>();
|
||||
gsl::span<const float> present_span = present.Get<Tensor>().DataAsSpan<float>();
|
||||
for (gsl::index j = 0; j < beam_indices.length(); j++) {
|
||||
int64_t beam_index = beam_indices[j];
|
||||
gsl::span<const float> present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<const float> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
|
||||
|
||||
gsl::span<float> past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<float> past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
|
||||
gsl::copy(present_key, past_key);
|
||||
gsl::copy(present_value, past_value);
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
if (i == 3) // only dump past_0
|
||||
{
|
||||
DumpString("past_key of beam", static_cast<int>(j), true);
|
||||
DumpTensor<float>(nullptr, past_key.data(), 1, static_cast<int>(block_size_per_beam));
|
||||
|
||||
DumpString("past_value of beam", static_cast<int>(j), true);
|
||||
DumpTensor<float>(nullptr, past_value.data(), 1, static_cast<int>(block_size_per_beam));
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
next_inputs[i] = past;
|
||||
}
|
||||
}
|
||||
|
||||
Status GptSubgraph::UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
gsl::span<int64_t>& next_positions,
|
||||
gsl::span<const int64_t> beam_next_tokens,
|
||||
gsl::span<const int64_t> beam_indices,
|
||||
int num_beams) {
|
||||
// last_outputs: logits, present_0, present_1, ...
|
||||
// next_inputs: input_ids, position_id, attention_mask, past_0, past_1
|
||||
|
||||
// The following updates inputs for subgraph
|
||||
// TODO: Reuse buffer for input_ids and position_ids to reduce memory allocation.
|
||||
|
||||
// Update input_ids with next tokens.
|
||||
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
|
||||
int64_t dims[] = {batch_beam_size, 1};
|
||||
TensorShape input_ids_shape(&dims[0], 2);
|
||||
auto element_type = DataTypeImpl::GetType<int64_t>();
|
||||
OrtValue input_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, input_ids);
|
||||
int64_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int64_t>();
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
input_ids_data[i] = beam_next_tokens[i];
|
||||
}
|
||||
next_inputs[0] = input_ids;
|
||||
|
||||
// Update position IDs
|
||||
OrtValue position_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, allocator_, position_ids);
|
||||
int64_t* position_data = position_ids.GetMutable<Tensor>()->MutableData<int64_t>();
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
position_data[i] = next_positions[i];
|
||||
next_positions[i]++;
|
||||
}
|
||||
next_inputs[1] = position_ids;
|
||||
|
||||
// Update attention mask
|
||||
const OrtValue& old_mask = next_inputs[2];
|
||||
const float* old_mask_data = old_mask.Get<Tensor>().Data<float>();
|
||||
int64_t mask_dims[] = {batch_beam_size, current_length};
|
||||
TensorShape mask_shape(&mask_dims[0], 2);
|
||||
OrtValue attention_mask;
|
||||
auto mask_type = DataTypeImpl::GetType<float>();
|
||||
Tensor::InitOrtValue(mask_type, mask_shape, allocator_, attention_mask);
|
||||
float* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<float>();
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
for (int j = 0; j < current_length - 1; j++) {
|
||||
mask_data[i * current_length + j] = old_mask_data[i * (current_length - 1) + j];
|
||||
}
|
||||
mask_data[i * current_length + current_length - 1] = 1.0f;
|
||||
}
|
||||
next_inputs[2] = attention_mask;
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpOrtValue("input_ids", input_ids);
|
||||
DumpOrtValue("position_ids", position_ids);
|
||||
DumpOrtValue("attention_mask", attention_mask);
|
||||
#endif
|
||||
|
||||
// Update past state
|
||||
if (num_beams == 1) {
|
||||
// feed present_* output to past_* inputs one by one
|
||||
for (int i = 3; i < num_subgraph_inputs; ++i) {
|
||||
next_inputs[i] = last_outputs[i - 2];
|
||||
}
|
||||
} else {
|
||||
PickPastState(last_outputs, next_inputs, beam_indices);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
82
onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h
Normal file
82
onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h
Normal file
|
|
@ -0,0 +1,82 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
//#include <functional>
|
||||
#include "gsl/gsl"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/feeds_fetches_manager.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// A class for GPT-2 subgraph inputs and outputs preparation.
|
||||
struct GptSubgraph {
|
||||
GptSubgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in);
|
||||
|
||||
const onnxruntime::Node& node; // node that contains the subgraph
|
||||
const std::string& attribute; // attribute of th node that contains the subgraph. Not used yet.
|
||||
const GraphViewer& subgraph; // the subgraph
|
||||
|
||||
int num_implicit_inputs;
|
||||
|
||||
int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience.
|
||||
int num_subgraph_outputs; // same as subgraph_output_names.size()
|
||||
|
||||
std::vector<std::string> subgraph_input_names;
|
||||
std::vector<std::string> subgraph_output_names;
|
||||
|
||||
// Parameters deduced from the subgraph
|
||||
int num_heads;
|
||||
int head_size;
|
||||
int vocab_size;
|
||||
int num_layers;
|
||||
|
||||
// Setup exectuion
|
||||
Status Setup(const SessionState& session_state,
|
||||
const SessionState& subgraph_session_state);
|
||||
|
||||
// Create inputs for first inference of subgraph.
|
||||
void CreateInitialFeeds(
|
||||
const Tensor& input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int64_t>& next_positions,
|
||||
std::vector<OrtValue>& feeds);
|
||||
|
||||
Status UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
gsl::span<int64_t>& next_positions,
|
||||
gsl::span<const int64_t> beam_next_tokens,
|
||||
gsl::span<const int64_t> beam_indices,
|
||||
int num_beams);
|
||||
|
||||
FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); }
|
||||
|
||||
protected:
|
||||
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs);
|
||||
|
||||
OrtValue ExpandInputs(const OrtValue& input, int num_beams) const;
|
||||
|
||||
void PickPastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
gsl::span<const int64_t>& beam_indices);
|
||||
|
||||
AllocatorPtr allocator_;
|
||||
const SessionState* session_state_;
|
||||
const SessionState* subgraph_session_state_;
|
||||
std::unique_ptr<FeedsFetchesManager> feeds_fetches_manager_;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
192
onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
Normal file
192
onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
#include <assert.h>
|
||||
#include "logits_processor.h"
|
||||
#include "dump_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
gsl::span<T> NextTokenScores<T>::GetScores(int batch_beam_index) {
|
||||
assert(batch_beam_index >= 0 && batch_beam_index < batch_beam_size);
|
||||
return scores.subspan(batch_beam_index * vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void NextTokenScores<T>::SetScore(int token_id, T score) {
|
||||
assert(token_id >= 0 && token_id < vocab_size);
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
scores[i * vocab_size + token_id] = score;
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
template <typename T>
|
||||
void DumpScores(const char* name, gsl::span<T>& scores) {
|
||||
DumpString(name, 0, true);
|
||||
ORT_UNUSED_PARAMETER(scores);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Interface for all scorers for beam search or beam sample.
|
||||
template <typename T>
|
||||
MinLengthLogitsProcessor<T>::MinLengthLogitsProcessor(int min_length, int eos_token_id)
|
||||
: min_length_(min_length), eos_token_id_(eos_token_id) {}
|
||||
|
||||
template <typename T>
|
||||
void MinLengthLogitsProcessor<T>::Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
if (sequences->GetSequenceLength() < min_length_) {
|
||||
next_token_scores.SetScore(eos_token_id_, std::numeric_limits<T>::lowest());
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpScores("MinLengthLogitsProcessor", next_token_scores.scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
RepetitionPenaltyLogitsProcessor<T>::RepetitionPenaltyLogitsProcessor(float penalty) : penalty_(penalty) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void RepetitionPenaltyLogitsProcessor<T>::Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
const int batch_beam_size = next_token_scores.batch_beam_size;
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
|
||||
gsl::span<const int64_t> sequence = sequences->GetSequence(i);
|
||||
|
||||
// Find unique word IDs in sequence.
|
||||
std::unordered_set<int64_t> unique_word_ids;
|
||||
for (const auto& word_id : sequence) {
|
||||
unique_word_ids.insert(word_id);
|
||||
}
|
||||
|
||||
for (const int64_t word_id : unique_word_ids) {
|
||||
T score = beam_token_scores[word_id];
|
||||
|
||||
// If score < 0, then repetition penalty > 1.0 has to multiplied to reduce the previous token probability,
|
||||
// This assumes that scores are either positive (like ctrl) or negative (like GPT-2), but not a mixture.
|
||||
beam_token_scores[word_id] = (score < 0 ? score * penalty_ : score / penalty_);
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpScores("RepetitionPenaltyLogitsProcessor", next_token_scores.scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
NoRepeatNGramLogitsProcessor<T>::NoRepeatNGramLogitsProcessor(int ngram_size) : ngram_size_(ngram_size) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void NoRepeatNGramLogitsProcessor<T>::Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
if (ngram_size_ == 0 || ngram_size_ > sequences->GetSequenceLength()) {
|
||||
return;
|
||||
}
|
||||
|
||||
const gsl::index prefix_length = static_cast<gsl::index>(ngram_size_ - 1);
|
||||
int batch_beam_size = next_token_scores.batch_beam_size;
|
||||
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
gsl::span<T> beam_token_scores = next_token_scores.GetScores(i);
|
||||
gsl::span<const int64_t> sequence = sequences->GetSequence(i);
|
||||
|
||||
gsl::span<const int64_t> prefix = sequence.subspan(sequence.length() - prefix_length);
|
||||
ORT_ENFORCE(prefix.length() == prefix_length);
|
||||
|
||||
std::unordered_set<int64_t> blocked_word_ids;
|
||||
for (int j = 0; j <= static_cast<int>(sequence.length()) - ngram_size_; j++) {
|
||||
// Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length)
|
||||
// TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching.
|
||||
if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) {
|
||||
blocked_word_ids.insert(sequence[j + prefix_length]);
|
||||
}
|
||||
}
|
||||
|
||||
for (const int64_t word_id : blocked_word_ids) {
|
||||
beam_token_scores[word_id] = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpScores("NoRepeatNGramLogitsProcessor", next_token_scores.scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
VocabMaskLogitsProcessor<T>::VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask) : vocab_mask_(vocab_mask) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void VocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
assert(!vocab_mask_.empty());
|
||||
|
||||
// Process vocabulary mask and set tokens with mask value 0 to -inf.
|
||||
T* p = next_token_scores.scores.data();
|
||||
// next_token_scores shape (batch_size * num_beams, vocab_size)
|
||||
// vocab_mask shape (vocab_size). TODO: support shape (batch_size, vocab_size)
|
||||
for (int i = 0; i < next_token_scores.batch_beam_size; i++) {
|
||||
for (int j = 0; j < next_token_scores.vocab_size; j++, p++) {
|
||||
if (vocab_mask_[j] == 0) {
|
||||
*p = std::numeric_limits<T>::lowest();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
DumpScores("VocabMaskLogitsProcessor", next_token_scores.scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LogitsProcessorList<T>::Init(const BeamSearchParameters& parameters) {
|
||||
processor_list_.clear();
|
||||
|
||||
if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty
|
||||
repetition_penalty_processor_ = std::make_unique<RepetitionPenaltyLogitsProcessor<T>>(parameters.repetition_penalty);
|
||||
processor_list_.push_back(repetition_penalty_processor_.get());
|
||||
}
|
||||
|
||||
if (parameters.no_repeat_ngram_size > 0) {
|
||||
no_repeat_ngram_processor_ = std::make_unique<NoRepeatNGramLogitsProcessor<T>>(parameters.no_repeat_ngram_size);
|
||||
processor_list_.push_back(no_repeat_ngram_processor_.get());
|
||||
}
|
||||
|
||||
if (!parameters.vocab_mask.empty()) {
|
||||
vocab_mask_processor_ = std::make_unique<VocabMaskLogitsProcessor<T>>(parameters.vocab_mask);
|
||||
processor_list_.push_back(vocab_mask_processor_.get());
|
||||
}
|
||||
|
||||
if (parameters.min_length > 0) {
|
||||
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<T>>(parameters.min_length, parameters.eos_token_id);
|
||||
processor_list_.push_back(min_length_processor_.get());
|
||||
}
|
||||
|
||||
batch_beam_size_ = parameters.BatchBeamSize();
|
||||
vocab_size_ = parameters.vocab_size;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LogitsProcessorList<T>::Process(const ISequences* sequences,
|
||||
gsl::span<T>& next_token_scores) {
|
||||
NextTokenScores<T> input_scores = {next_token_scores, batch_beam_size_, vocab_size_};
|
||||
for (size_t i = 0; i < processor_list_.size(); i++) {
|
||||
processor_list_[i]->Process(sequences, input_scores);
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
template class MinLengthLogitsProcessor<float>;
|
||||
template class RepetitionPenaltyLogitsProcessor<float>;
|
||||
template class NoRepeatNGramLogitsProcessor<float>;
|
||||
template class VocabMaskLogitsProcessor<float>;
|
||||
template class LogitsProcessorList<float>;
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
99
onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
Normal file
99
onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
#pragma once
|
||||
#include "sequences.h"
|
||||
#include "beam_search_parameters.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
struct NextTokenScores {
|
||||
gsl::span<T>& scores;
|
||||
int batch_beam_size;
|
||||
int vocab_size;
|
||||
|
||||
gsl::span<T> GetScores(int batch_beam_index);
|
||||
|
||||
void SetScore(int token_id, T score);
|
||||
};
|
||||
|
||||
// Interface for all scorers for beam search or beam sample.
|
||||
template <typename T>
|
||||
class ILogitsProcessor {
|
||||
public:
|
||||
virtual ~ILogitsProcessor() {}
|
||||
|
||||
virtual void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) = 0;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class MinLengthLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
MinLengthLogitsProcessor(int min_length, int eos_token_id);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
int min_length_;
|
||||
int eos_token_id_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class RepetitionPenaltyLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
RepetitionPenaltyLogitsProcessor(float penalty);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
float penalty_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class NoRepeatNGramLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
NoRepeatNGramLogitsProcessor(int ngram_size);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
int ngram_size_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class VocabMaskLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
gsl::span<const int32_t> vocab_mask_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class LogitsProcessorList {
|
||||
public:
|
||||
LogitsProcessorList() = default ;
|
||||
void Init(const BeamSearchParameters& parameters);
|
||||
void Process(const ISequences* sequences, gsl::span<T>& next_token_scores);
|
||||
|
||||
private:
|
||||
int batch_beam_size_;
|
||||
int vocab_size_;
|
||||
std::vector<ILogitsProcessor<T>*> processor_list_;
|
||||
|
||||
std::unique_ptr<RepetitionPenaltyLogitsProcessor<T>> repetition_penalty_processor_;
|
||||
std::unique_ptr<NoRepeatNGramLogitsProcessor<T>> no_repeat_ngram_processor_;
|
||||
std::unique_ptr<VocabMaskLogitsProcessor<T>> vocab_mask_processor_;
|
||||
std::unique_ptr<MinLengthLogitsProcessor<T>> min_length_processor_;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
76
onnxruntime/contrib_ops/cpu/transformers/sequences.cc
Normal file
76
onnxruntime/contrib_ops/cpu/transformers/sequences.cc
Normal file
|
|
@ -0,0 +1,76 @@
|
|||
#include "sequences.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
void Sequences::Init(AllocatorPtr allocator, const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length) {
|
||||
size_t sequences_size = SafeInt<size_t>(batch_beam_size) * max_length;
|
||||
size_t buffer_size = sequences_size + sequences_size;
|
||||
gsl::span<int64_t> buffer = AllocateBuffer<int64_t>(allocator, sequences_space_buffer_, buffer_size, true, static_cast<int64_t>(0));
|
||||
|
||||
sequences[0] = buffer.subspan(0, sequences_size);
|
||||
sequences[1] = buffer.subspan(sequences_size);
|
||||
|
||||
// Copy input_ids to sequences[0].
|
||||
gsl::span<const int64_t> input = input_ids.Get<Tensor>().DataAsSpan<int64_t>();
|
||||
gsl::span<int64_t> output = sequences[0];
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
gsl::span<const int64_t> source = input.subspan(i * sequence_length, sequence_length);
|
||||
gsl::span<int64_t> target = output.subspan(i * max_length, sequence_length);
|
||||
gsl::copy(source, target);
|
||||
}
|
||||
current_sequences_buffer = 0;
|
||||
|
||||
batch_beam_size_ = batch_beam_size;
|
||||
max_length_ = max_length;
|
||||
current_length_ = sequence_length;
|
||||
}
|
||||
|
||||
gsl::span<const int64_t> Sequences::GetSequence(int beam_index) const {
|
||||
gsl::span<const int64_t> buffer(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
|
||||
gsl::span<const int64_t> sequence = buffer.subspan(beam_index * max_length_, current_length_);
|
||||
return sequence;
|
||||
}
|
||||
|
||||
int Sequences::GetSequenceLength() const {
|
||||
return current_length_;
|
||||
}
|
||||
|
||||
void Sequences::PrintSequences() {
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
for (int i = 0; i < batch_beam_size_; i++) {
|
||||
gsl::span<const int64_t> sequence = GetSequence(i);
|
||||
DumpString("sequences", i, false);
|
||||
DumpTensor<int64_t>(nullptr, sequence.data(), 1, current_length_);
|
||||
}
|
||||
#endif
|
||||
}
|
||||
|
||||
void Sequences::AppendNextTokenToSequences(
|
||||
gsl::span<int64_t>& beam_indices,
|
||||
gsl::span<int64_t>& beam_next_tokens) {
|
||||
gsl::span<const int64_t> input(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
|
||||
gsl::span<int64_t> output = sequences[1 - current_sequences_buffer];
|
||||
|
||||
for (int i = 0; i < batch_beam_size_; i++) {
|
||||
int beam_index = static_cast<int>(beam_indices[i]);
|
||||
gsl::span<const int64_t> source = input.subspan(beam_index * max_length_, current_length_);
|
||||
gsl::span<int64_t> target = output.subspan(i * max_length_, current_length_);
|
||||
gsl::copy(source, target);
|
||||
}
|
||||
|
||||
// Append next token to each beam.
|
||||
for (int i = 0; i < batch_beam_size_; i++) {
|
||||
output[i * max_length_ + current_length_] = beam_next_tokens[i];
|
||||
}
|
||||
|
||||
++current_length_;
|
||||
|
||||
// Rotate buffer for next round.
|
||||
current_sequences_buffer = 1 - current_sequences_buffer;
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
80
onnxruntime/contrib_ops/cpu/transformers/sequences.h
Normal file
80
onnxruntime/contrib_ops/cpu/transformers/sequences.h
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
class ISequences {
|
||||
public:
|
||||
virtual ~ISequences() {}
|
||||
virtual gsl::span<const int64_t> GetSequence(int beam_index) const = 0;
|
||||
virtual int GetSequenceLength() const = 0;
|
||||
};
|
||||
|
||||
// This class keeps track of sequences generated.
|
||||
class Sequences : public ISequences {
|
||||
public:
|
||||
Sequences() {}
|
||||
|
||||
// Initialize the sequence with initial input_ids and related parameters.
|
||||
void Init(AllocatorPtr allocator, const OrtValue& input_ids, int batch_beam_size, int sequence_length, int max_length);
|
||||
|
||||
// Returns a sequence of word IDs for a given beam index ( beam_index < batch_beam_size).
|
||||
gsl::span<const int64_t> GetSequence(int beam_index) const override;
|
||||
|
||||
// Returns current sequence length.
|
||||
int GetSequenceLength() const override;
|
||||
|
||||
// Print the sequences to StdOut in debug mode
|
||||
void PrintSequences();
|
||||
|
||||
// Select sequences based on beam indices, then append next token to selected sequences.
|
||||
void AppendNextTokenToSequences(
|
||||
gsl::span<int64_t>& beam_indices,
|
||||
gsl::span<int64_t>& beam_next_tokens);
|
||||
|
||||
private:
|
||||
gsl::span<int64_t> sequences_space; // shape (2, batch_size, num_beams, max_seq_length)
|
||||
BufferUniquePtr sequences_space_buffer_;
|
||||
|
||||
// Two buffers of shape (batch_size, num_beams, max_seq_length) to store sequences.
|
||||
// At each time, there is only one buffer is active. The other one will be active in next token.
|
||||
// Each AppendNextTokenToSequences call will trigger a rotation of active buffer.
|
||||
gsl::span<int64_t> sequences[2];
|
||||
|
||||
// Index (either 0 or 1) of two buffers that is currently is active.
|
||||
int current_sequences_buffer;
|
||||
|
||||
int batch_beam_size_;
|
||||
int max_length_;
|
||||
int current_length_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
|
||||
BufferUniquePtr& buffer,
|
||||
size_t elements,
|
||||
bool fill = false,
|
||||
T fill_value = T{}) {
|
||||
size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
|
||||
void* data = allocator->Alloc(bytes);
|
||||
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
|
||||
buffer = std::move(temp_buffer);
|
||||
T* first = reinterpret_cast<T*>(buffer.get());
|
||||
auto span = gsl::make_span(first, elements);
|
||||
|
||||
if (fill) {
|
||||
std::fill_n(first, elements, fill_value);
|
||||
}
|
||||
|
||||
return span;
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -568,6 +568,139 @@ void DecoderAttentionTypeAndShapeInference(ONNX_NAMESPACE::InferenceContext& ctx
|
|||
}
|
||||
}
|
||||
|
||||
bool ParseScalar(const TensorProto* initializer, int& value) {
|
||||
std::vector<int32_t> parsed_data;
|
||||
if (initializer->data_type() == TensorProto::INT32) {
|
||||
const auto& data = ParseData<int32_t>(initializer);
|
||||
parsed_data.insert(parsed_data.end(), data.begin(), data.end());
|
||||
|
||||
if (parsed_data.size() == 1) {
|
||||
value = parsed_data[0];
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
void BeamSearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Type inference
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
// Here we assume that the third output exist only if second output exists.
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 1);
|
||||
if (ctx.getNumOutputs() > 2) {
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 5, 2);
|
||||
}
|
||||
}
|
||||
|
||||
// Shape inference
|
||||
// input 0 (input_ids) shape: (batch_size, sequence_length)
|
||||
// output 0 (sequences) shape: (batch_size, num_return_sequences, max_length)
|
||||
// output 1 (sequences_scores) shape: (batch_size, num_return_sequences)
|
||||
// output 2 (scores) shape: (max_length - sequence_length, batch_size, num_beams, vocab_size)
|
||||
if (!hasInputShape(ctx, 0)) {
|
||||
return;
|
||||
}
|
||||
auto& input_ids_shape = getInputShape(ctx, 0);
|
||||
auto& input_ids_dims = input_ids_shape.dim();
|
||||
if (input_ids_dims.size() != 2) {
|
||||
fail_shape_inference("Inputs 0 shall be 2 dimensions");
|
||||
}
|
||||
if (!(input_ids_dims[0].has_dim_value() && input_ids_dims[1].has_dim_value())) {
|
||||
return;
|
||||
}
|
||||
|
||||
int64_t batch_size = input_ids_dims[0].dim_value();
|
||||
int64_t sequence_length = input_ids_dims[1].dim_value();
|
||||
|
||||
const auto max_length = ctx.getInputData(1);
|
||||
const auto num_beams = ctx.getInputData(3);
|
||||
const auto num_return_sequences = ctx.getInputData(4);
|
||||
if (num_beams == nullptr || max_length == nullptr || num_return_sequences == nullptr) { // not initializer
|
||||
return;
|
||||
}
|
||||
|
||||
int max_length_value = 0;
|
||||
if (!ParseScalar(max_length, max_length_value) || max_length_value <= 0) {
|
||||
fail_shape_inference("Failed to parse max_length or it is not positive integer scalar");
|
||||
}
|
||||
|
||||
int num_beams_value = 0;
|
||||
if (!ParseScalar(num_beams, num_beams_value) || num_beams_value <= 0) {
|
||||
fail_shape_inference("Failed to parse num_beams or it is not positive integer scalar");
|
||||
}
|
||||
|
||||
int num_return_sequences_value = 0;
|
||||
if (!ParseScalar(num_return_sequences, num_return_sequences_value) || num_return_sequences_value <= 0) {
|
||||
fail_shape_inference("Failed to parse num_return_sequences or it is not positive integer scalar");
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::TensorShapeProto sequences_shape;
|
||||
sequences_shape.add_dim()->set_dim_value(batch_size);
|
||||
sequences_shape.add_dim()->set_dim_value(num_return_sequences_value);
|
||||
sequences_shape.add_dim()->set_dim_value(max_length_value);
|
||||
updateOutputShape(ctx, 0, sequences_shape);
|
||||
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
ONNX_NAMESPACE::TensorShapeProto sequences_scores_shape;
|
||||
sequences_shape.add_dim()->set_dim_value(batch_size);
|
||||
sequences_shape.add_dim()->set_dim_value(num_return_sequences_value);
|
||||
updateOutputShape(ctx, 1, sequences_scores_shape);
|
||||
|
||||
if (ctx.getNumOutputs() > 2) {
|
||||
ONNX_NAMESPACE::TensorShapeProto scores_shape;
|
||||
scores_shape.add_dim()->set_dim_value(max_length_value - sequence_length);
|
||||
scores_shape.add_dim()->set_dim_value(batch_size);
|
||||
scores_shape.add_dim()->set_dim_value(num_beams_value);
|
||||
scores_shape.add_dim(); // vocab_size is unknown
|
||||
updateOutputShape(ctx, 2, scores_shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void RegisterTextGenerationSchemas() {
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(BeamSearch)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetDoc("Beam Search for text generation. Supports GPT-2 decoder.")
|
||||
.Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
|
||||
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
|
||||
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr(
|
||||
"body",
|
||||
"The GPT-2 subgraph with input_ids, position_ids, attention_mask, past_0, past_1, ... as inputs, and logits, present_0, present_1, ... as output",
|
||||
AttributeProto::GRAPH)
|
||||
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
|
||||
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
|
||||
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
|
||||
.Input(3, "num_beams", "Number of beams for beam search. 1 means no beam search. Shape is (1)", "I")
|
||||
.Input(4, "num_return_sequences", "The number of returned sequences in the batch. Shape is (1)", "I")
|
||||
.Input(5, "temperature", "The value used to module the next token probabilities. Accepts value > 0.0. Shape is (1)", "T")
|
||||
.Input(6, "length_penalty",
|
||||
"Exponential penalty to the length. Default value 1.0 means no penalty."
|
||||
"Value > 1.0 encourages longer sequences, while values < 1.0 produces shorter sequences."
|
||||
"Shape is (1,)",
|
||||
"T", OpSchema::Optional)
|
||||
.Input(7, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
|
||||
.Input(8, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "M", OpSchema::Optional)
|
||||
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, num_return_sequences, max_sequence_length)", "I")
|
||||
.Output(1, "sequences_scores", "Final beam score of the generated sequences. Shape is (batch_size, num_return_sequences)", "T", OpSchema::Optional)
|
||||
.Output(2, "scores",
|
||||
"Processed beam scores for each vocabulary token at each generation step."
|
||||
"Beam scores consisting of log softmax scores for each vocabulary token and sum of log softmax of previously generated tokens in this beam."
|
||||
"Shape is (max_length - sequence_length, batch_size, num_beams, vocab_size)",
|
||||
"T", OpSchema::Optional)
|
||||
.TypeConstraint("T", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
|
||||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask to integer types")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
BeamSearchShapeInference(ctx);
|
||||
});
|
||||
}
|
||||
|
||||
void RegisterBertSchemas() {
|
||||
static const char* Attention_ver1_doc = R"DOC(
|
||||
Multi-Head Self Attention that can be either unidirectional (like GPT-2) or bidirectional (like BERT).
|
||||
|
|
@ -750,8 +883,7 @@ Some boolean parameters are passed by runtime input for generic purpose
|
|||
.TypeConstraint("B", {"tensor(bool)"}, "Constrain key_padding_mask to bool tensors.")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
DecoderAttentionTypeAndShapeInference(ctx);
|
||||
});
|
||||
|
||||
});
|
||||
|
||||
static const char* EmbedLayerNormalization_ver1_doc = R"DOC(
|
||||
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
|
||||
|
|
@ -3146,6 +3278,7 @@ It's an extension of Gelu. It takes the sum of input A and bias input B as the i
|
|||
|
||||
RegisterNhwcSchemas();
|
||||
RegisterBertSchemas();
|
||||
RegisterTextGenerationSchemas();
|
||||
|
||||
#ifdef BUILD_MS_EXPERIMENTAL_OPS
|
||||
onnxruntime::signal::RegisterSignalSchemas();
|
||||
|
|
|
|||
|
|
@ -2587,6 +2587,15 @@ Status Graph::VerifyNodeAndOpMatch(const ResolveOptions& options) {
|
|||
}
|
||||
}
|
||||
|
||||
// verify subgraphs
|
||||
for (auto node_index : nodes_in_topological_order_) {
|
||||
auto& node = *GetNode(node_index);
|
||||
for (auto& entry : node.GetAttributeNameToMutableSubgraphMap()) {
|
||||
Graph* subgraph = entry.second;
|
||||
ORT_RETURN_IF_ERROR(subgraph->VerifyNodeAndOpMatch(options));
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -370,6 +370,56 @@ static Status TopKImpl(OpKernelContext* p_op_kernel_context, const Tensor* input
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Wrapper over core TopK implementation
|
||||
template <typename T>
|
||||
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices) {
|
||||
const TensorShape& input_shape = input->Shape();
|
||||
|
||||
// Will return axis_ as is if positive or fixes it in case it is negative
|
||||
const auto axis_parsed = HandleNegativeAxis(axis, static_cast<int64_t>(input_shape.NumDimensions()));
|
||||
|
||||
// Check to ensure k is within the bounds of what is available in that specific axis
|
||||
if (input_shape[axis_parsed] < k) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "k argument [", k,
|
||||
"] should not be greater than specified axis dim value [", input_shape[axis_parsed], "]");
|
||||
}
|
||||
|
||||
// Resize output tensors to be the same shape as the input except
|
||||
// for the specified dimension ((i.e.) axis_parsed), which will be of size k. E.x. for an input tensor
|
||||
// of shape [3, 4, 5] and k=2 with axis_parsed=1, both of the outputs will be shape [3, 2, 5]
|
||||
TensorShape output_shape = input_shape;
|
||||
output_shape[axis_parsed] = k;
|
||||
|
||||
output_values = Tensor::Create(input->DataType(), output_shape, allocator);
|
||||
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
|
||||
|
||||
// no-op - no output buffers to fill - return silently
|
||||
if (k == 0) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
if (largest) {
|
||||
FindTopKElements<GreaterValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
|
||||
} else {
|
||||
FindTopKElements<LesserValueCmp<T>>(input, input_shape, output_values.get(), output_indices.get(), output_shape, k, sorted,
|
||||
gsl::narrow_cast<unsigned>(axis_parsed), threadpool);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// explicit instantiation
|
||||
template Status GetTopK<float>(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
|
||||
// Opset ver - 1 to 9
|
||||
|
||||
static void TopkOpset9ConstructorCommon(const OpKernelInfo& op_kernel_info, int& axis, unsigned int& k) {
|
||||
|
|
|
|||
|
|
@ -19,4 +19,11 @@ class TopK final : public OpKernel {
|
|||
bool largest_; // opset-11 only
|
||||
bool sorted_; // opset-11 only
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status GetTopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
} // namespace onnxruntime
|
||||
459
onnxruntime/python/tools/transformers/convert_beam_search.py
Normal file
459
onnxruntime/python/tools/transformers/convert_beam_search.py
Normal file
|
|
@ -0,0 +1,459 @@
|
|||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License.
|
||||
|
||||
import os
|
||||
import time
|
||||
import onnx
|
||||
import logging
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from onnx import helper
|
||||
import numpy as np
|
||||
from typing import List
|
||||
import torch
|
||||
from transformers import GPT2Config
|
||||
from gpt2_helper import PRETRAINED_GPT2_MODELS
|
||||
from convert_to_onnx import main as convert_gpt2_to_onnx
|
||||
from benchmark_helper import Precision
|
||||
"""
|
||||
This converts GPT2 model to onnx with beam search operator.
|
||||
|
||||
Examples:
|
||||
python convert_beam_search.py -m gpt2 --gpt2_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
|
||||
"""
|
||||
|
||||
config: GPT2Config = None
|
||||
|
||||
logger = logging.getLogger('')
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument('-m',
|
||||
'--model_name_or_path',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Model path, or pretrained model name in the list: ' + ', '.join(PRETRAINED_GPT2_MODELS))
|
||||
|
||||
parser.add_argument('--cache_dir',
|
||||
required=False,
|
||||
type=str,
|
||||
default=os.path.join('.', 'cache_models'),
|
||||
help='Directory to cache pre-trained models')
|
||||
|
||||
parser.add_argument('--gpt2_onnx',
|
||||
required=True,
|
||||
type=str,
|
||||
help='Output directory for GPT-2 onnx model, or model path ends with .onnx')
|
||||
|
||||
parser.add_argument('--output',
|
||||
required=False,
|
||||
type=str,
|
||||
help='Output directory for beam search model, or model path ends with .onnx')
|
||||
|
||||
parser.add_argument("-p",
|
||||
"--precision",
|
||||
required=False,
|
||||
type=Precision,
|
||||
default=Precision.FLOAT32,
|
||||
choices=[Precision.FLOAT32, Precision.FLOAT16],
|
||||
help="Precision of model to run. fp32 for full precision, fp16 for half or mixed precision")
|
||||
|
||||
parser.add_argument('--use_gpu', required=False, action='store_true', help="use GPU for inference")
|
||||
parser.set_defaults(use_gpu=False)
|
||||
|
||||
parser.add_argument('-e', '--use_external_data_format', required=False, action='store_true')
|
||||
parser.set_defaults(use_external_data_format=False)
|
||||
|
||||
parser.add_argument('--disable_parity', required=False, action='store_true', help="do not run parity test")
|
||||
parser.set_defaults(disable_parity=False)
|
||||
|
||||
parser.add_argument('--total_runs',
|
||||
required=False,
|
||||
type=int,
|
||||
default=1,
|
||||
help='Number of times of inference for latency measurement')
|
||||
|
||||
beam_search_group = parser.add_argument_group("beam search options")
|
||||
|
||||
beam_search_group.add_argument('--output_sequences_scores',
|
||||
required=False,
|
||||
action='store_true',
|
||||
help="output sequences scores")
|
||||
beam_search_group.set_defaults(output_sequences_scores=False)
|
||||
|
||||
beam_search_group.add_argument('--output_token_scores',
|
||||
required=False,
|
||||
action='store_true',
|
||||
help="output token scores")
|
||||
beam_search_group.set_defaults(output_token_scores=False)
|
||||
|
||||
beam_search_group.add_argument('--early_stopping', required=False, action='store_true')
|
||||
beam_search_group.set_defaults(early_stopping=False)
|
||||
|
||||
beam_search_group.add_argument('--min_length', type=int, required=False, default=1, help='Min sequence length')
|
||||
|
||||
beam_search_group.add_argument('--max_length', type=int, required=False, default=50, help='Max sequence length')
|
||||
|
||||
beam_search_group.add_argument('--no_repeat_ngram_size',
|
||||
type=int,
|
||||
required=False,
|
||||
default=0,
|
||||
help='No repeat ngram size')
|
||||
|
||||
beam_search_group.add_argument('--num_beams', type=int, required=False, default=4, help='Beam size')
|
||||
|
||||
beam_search_group.add_argument('--num_return_sequences',
|
||||
type=int,
|
||||
required=False,
|
||||
default=1,
|
||||
help='Number of return sequence <= num_beams')
|
||||
|
||||
beam_search_group.add_argument('--temperature',
|
||||
type=float,
|
||||
required=False,
|
||||
default=1,
|
||||
help='Softmax temperature for output logits.')
|
||||
|
||||
beam_search_group.add_argument('--length_penalty',
|
||||
type=float,
|
||||
required=False,
|
||||
default=1,
|
||||
help='Positive. >1 to penalize and <1 to encorage short sentence.')
|
||||
|
||||
beam_search_group.add_argument('--repetition_penalty',
|
||||
type=float,
|
||||
required=False,
|
||||
default=1,
|
||||
help='Positive. >1 to penalize and <1 to encorage.')
|
||||
|
||||
mixed_precision_option_group = parser.add_argument_group(
|
||||
"mixed precision conversion parameters that works when \"--precision fp16\" is specified")
|
||||
|
||||
mixed_precision_option_group.add_argument('--io_block_list',
|
||||
nargs='+',
|
||||
required=False,
|
||||
default=[],
|
||||
help='List of inputs or outputs in float32')
|
||||
|
||||
mixed_precision_option_group.add_argument(
|
||||
'--op_block_list',
|
||||
nargs='+',
|
||||
required=False,
|
||||
default=[],
|
||||
help='List of operators (like Add LayerNormalization FastGelu) to compute in float32.')
|
||||
|
||||
mixed_precision_option_group.add_argument('--node_block_list',
|
||||
nargs='+',
|
||||
required=False,
|
||||
default=[],
|
||||
help='List of node names to compute in float32.')
|
||||
|
||||
mixed_precision_option_group.add_argument('--force_fp16_initializers',
|
||||
required=False,
|
||||
action='store_true',
|
||||
help='Convert all float initializers to float16.')
|
||||
mixed_precision_option_group.set_defaults(force_fp16_initializers=False)
|
||||
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def gpt2_to_onnx(args):
|
||||
model_name = args.model_name_or_path
|
||||
|
||||
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.gpt2_onnx} ...")
|
||||
arguments = [
|
||||
'--model_name_or_path', model_name, '--output', args.gpt2_onnx, '--optimize_onnx', '--precision',
|
||||
'fp32' if args.precision == Precision.FLOAT32 else 'fp16', '--test_runs', '1', '--test_cases', '10'
|
||||
]
|
||||
if args.use_gpu:
|
||||
arguments.append('--use_gpu')
|
||||
if args.use_external_data_format:
|
||||
arguments.append('--use_external_data_format')
|
||||
|
||||
# mixed precision conversion options
|
||||
if args.precision == Precision.FLOAT16:
|
||||
assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
|
||||
if args.io_block_list:
|
||||
arguments.append('--io_block_list')
|
||||
arguments.extend(args.io_block_list)
|
||||
if args.op_block_list:
|
||||
arguments.append('--op_block_list')
|
||||
arguments.extend(args.op_block_list)
|
||||
if args.node_block_list:
|
||||
arguments.append('--node_block_list')
|
||||
arguments.extend(args.node_block_list)
|
||||
if args.force_fp16_initializers:
|
||||
arguments.append('--force_fp16_initializers')
|
||||
|
||||
convert_gpt2_to_onnx(arguments)
|
||||
|
||||
|
||||
def shape_inference(gpt2_onnx_path):
|
||||
# Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
out = SymbolicShapeInference.infer_shapes(onnx.load(gpt2_onnx_path), auto_merge=True, guess_output_rank=False)
|
||||
if out:
|
||||
# TODO: Use external format if input has extra data.
|
||||
onnx.save(out, gpt2_onnx_path)
|
||||
else:
|
||||
print("Failed to run symbolic shape inference on the model.")
|
||||
|
||||
|
||||
def create_ort_session(model_path, use_gpu):
|
||||
from onnxruntime import SessionOptions, InferenceSession, __version__ as ort_version, GraphOptimizationLevel
|
||||
sess_options = SessionOptions()
|
||||
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
execution_providers = ['CUDAExecutionProvider', 'CPUExecutionProvider'] if use_gpu else ['CPUExecutionProvider']
|
||||
|
||||
ort_session = InferenceSession(model_path, sess_options, providers=execution_providers)
|
||||
return ort_session
|
||||
|
||||
|
||||
def convert_model(args):
|
||||
if os.path.exists(args.gpt2_onnx):
|
||||
print(f"skip convert_to_onnx since path existed: {args.gpt2_onnx}")
|
||||
else:
|
||||
gpt2_to_onnx(args)
|
||||
|
||||
print(f"Run symbolic shape inference on {args.gpt2_onnx}. The file will be overwritten.")
|
||||
shape_inference(args.gpt2_onnx)
|
||||
|
||||
global config
|
||||
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
print(config)
|
||||
|
||||
eos_token_id = config.eos_token_id
|
||||
pad_token_id = config.eos_token_id
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
model = onnx.load(args.gpt2_onnx)
|
||||
model.graph.name = "gpt2 subgraph"
|
||||
inputs = [
|
||||
"input_ids", "max_length", "min_length", "num_beams", "num_return_sequences", "temperature", "length_penalty",
|
||||
"repetition_penalty", "vocab_mask"
|
||||
]
|
||||
|
||||
outputs = ["sequences"]
|
||||
if args.output_sequences_scores:
|
||||
outputs.append("sequences_scores")
|
||||
|
||||
if args.output_token_scores:
|
||||
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
|
||||
outputs.append("scores")
|
||||
|
||||
node = helper.make_node('BeamSearch', inputs=inputs, outputs=outputs, name='BeamSearch_GPT2')
|
||||
node.domain = "com.microsoft"
|
||||
node.attribute.extend([
|
||||
helper.make_attribute("eos_token_id", eos_token_id),
|
||||
helper.make_attribute("pad_token_id", pad_token_id),
|
||||
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
|
||||
helper.make_attribute("body", model.graph),
|
||||
])
|
||||
|
||||
from onnx import TensorProto
|
||||
|
||||
# graph inputs
|
||||
input_ids = helper.make_tensor_value_info('input_ids', TensorProto.INT32, ['batch_size', 'sequence_length'])
|
||||
max_length = helper.make_tensor_value_info('max_length', TensorProto.INT32, [1])
|
||||
min_length = helper.make_tensor_value_info('min_length', TensorProto.INT32, [1])
|
||||
num_beams = helper.make_tensor_value_info('num_beams', TensorProto.INT32, [1])
|
||||
num_return_sequences = helper.make_tensor_value_info('num_return_sequences', TensorProto.INT32, [1])
|
||||
temperature = helper.make_tensor_value_info('temperature', TensorProto.FLOAT, [1])
|
||||
length_penalty = helper.make_tensor_value_info('length_penalty', TensorProto.FLOAT, [1])
|
||||
repetition_penalty = helper.make_tensor_value_info('repetition_penalty', TensorProto.FLOAT, [1])
|
||||
vocab_mask = helper.make_tensor_value_info('vocab_mask', TensorProto.INT32, [vocab_size])
|
||||
|
||||
graph_inputs = [
|
||||
input_ids, max_length, min_length, num_beams, num_return_sequences, temperature, length_penalty,
|
||||
repetition_penalty, vocab_mask
|
||||
]
|
||||
|
||||
# graph outputs
|
||||
sequences = helper.make_tensor_value_info('sequences', TensorProto.INT32,
|
||||
['batch_size', 'num_return_sequences', 'max_length'])
|
||||
|
||||
sequences_scores = helper.make_tensor_value_info('sequences_scores', TensorProto.FLOAT,
|
||||
['batch_size', 'num_return_sequences'])
|
||||
|
||||
scores = helper.make_tensor_value_info('scores', TensorProto.FLOAT,
|
||||
['max_length - sequence_length', 'batch_size', 'num_beams', vocab_size])
|
||||
|
||||
initializers = []
|
||||
|
||||
graph_outputs = [sequences]
|
||||
|
||||
if args.output_sequences_scores:
|
||||
graph_outputs.append(sequences_scores)
|
||||
|
||||
if args.output_token_scores:
|
||||
graph_outputs.append(scores)
|
||||
|
||||
new_graph = helper.make_graph([node], 'gpt2-beam-search', graph_inputs, graph_outputs, initializers)
|
||||
|
||||
# Create the model
|
||||
new_model = helper.make_model(new_graph, producer_name='onnxruntime.transformers', opset_imports=model.opset_import)
|
||||
onnx.save(new_model, args.output)
|
||||
|
||||
|
||||
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
||||
from transformers import GPT2Tokenizer, GPT2LMHeadModel
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
tokenizer.padding_side = "left"
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
model = GPT2LMHeadModel.from_pretrained(args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
pad_token_id=tokenizer.eos_token_id)
|
||||
|
||||
# Use different length sentences to test batching
|
||||
if sentences is None:
|
||||
sentences = ["The product is released", "I enjoy walking in the park", "Test best way to invest"]
|
||||
|
||||
inputs = tokenizer(sentences, return_tensors='pt', padding=True)
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
|
||||
bad_words = "walk in park"
|
||||
bad_words_ids = tokenizer.encode(bad_words, add_prefix_space=True)
|
||||
bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
|
||||
if use_vocab_mask:
|
||||
print("bad_words_ids", bad_words_ids)
|
||||
else:
|
||||
bad_words_ids = None
|
||||
|
||||
global config
|
||||
config = model.config
|
||||
eos_token_id = config.eos_token_id
|
||||
pad_token_id = config.eos_token_id
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
torch_decoded_sequences = []
|
||||
if not args.disable_parity:
|
||||
print('-' * 50)
|
||||
print("Test PyTorch model and beam search with huggingface transformers...")
|
||||
beam_outputs = model.generate(input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=args.max_length,
|
||||
min_length=args.min_length,
|
||||
num_beams=args.num_beams,
|
||||
early_stopping=args.early_stopping,
|
||||
no_repeat_ngram_size=args.no_repeat_ngram_size,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
num_return_sequences=args.num_return_sequences,
|
||||
temperature=args.temperature,
|
||||
length_penalty=args.length_penalty,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
bad_words_ids=bad_words_ids,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True)
|
||||
print("input_ids", input_ids)
|
||||
print("huggingface transformers outputs:")
|
||||
print("sequences", beam_outputs.sequences)
|
||||
if args.output_sequences_scores:
|
||||
print("sequences_scores", beam_outputs.sequences_scores)
|
||||
if args.output_token_scores:
|
||||
print("scores", beam_outputs.scores)
|
||||
for i, sequence in enumerate(beam_outputs.sequences):
|
||||
decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
|
||||
torch_decoded_sequences.append(decoded_sequence)
|
||||
print("{}: {}".format(i, decoded_sequence))
|
||||
|
||||
print('-' * 50)
|
||||
print("Test ONNX model and bream search with onnxruntime...")
|
||||
|
||||
ort_session = create_ort_session(args.output, args.use_gpu)
|
||||
|
||||
vocab_mask = np.ones((vocab_size), dtype=np.int32)
|
||||
if use_vocab_mask:
|
||||
for bad_word_id in bad_words_ids:
|
||||
vocab_mask[bad_word_id] = 0
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids.cpu().numpy().astype(np.int32),
|
||||
"max_length": np.array([args.max_length], dtype=np.int32),
|
||||
"min_length": np.array([args.min_length], dtype=np.int32),
|
||||
"num_beams": np.array([args.num_beams], dtype=np.int32),
|
||||
"num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
|
||||
"temperature": np.array([args.temperature], dtype=np.float32),
|
||||
"length_penalty": np.array([args.length_penalty], dtype=np.float32),
|
||||
"repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
|
||||
"vocab_mask": vocab_mask
|
||||
}
|
||||
|
||||
test_data_dir = Path(args.output).parent.as_posix()
|
||||
print("test_data_dir", test_data_dir)
|
||||
from bert_test_data import output_test_data
|
||||
all_inputs = [inputs]
|
||||
for i, inputs in enumerate(all_inputs):
|
||||
dir = os.path.join(test_data_dir, 'test_data_set_' + str(i))
|
||||
output_test_data(dir, inputs)
|
||||
|
||||
print("inputs", inputs)
|
||||
|
||||
# Test performance
|
||||
latency = []
|
||||
for _ in range(args.total_runs):
|
||||
start = time.time()
|
||||
result = ort_session.run(None, inputs)
|
||||
latency.append(time.time() - start)
|
||||
batch_size = input_ids.shape[0]
|
||||
from benchmark_helper import get_latency_result
|
||||
output = get_latency_result(latency, batch_size)
|
||||
|
||||
print("ORT outputs:")
|
||||
sequences = result[0]
|
||||
print("sequences", sequences)
|
||||
if args.output_sequences_scores:
|
||||
print("sequences_scores", result[1])
|
||||
if args.output_token_scores:
|
||||
print("scores", result[2])
|
||||
|
||||
(batch_size, num_sequences, max_length) = sequences.shape
|
||||
ort_decoded_sequences = []
|
||||
for i in range(batch_size):
|
||||
for j in range(num_sequences):
|
||||
decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
|
||||
ort_decoded_sequences.append(decoded_sequence)
|
||||
print(f"batch {i} sequence {j}: {decoded_sequence}")
|
||||
|
||||
if not args.disable_parity:
|
||||
torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
|
||||
ort_sequences = torch.LongTensor(sequences)
|
||||
print("-" * 50)
|
||||
print("Torch Sequences:")
|
||||
print(torch_sequences)
|
||||
print(torch_decoded_sequences)
|
||||
print("-" * 50)
|
||||
print("ORT Sequences:")
|
||||
print(ort_sequences)
|
||||
print(ort_decoded_sequences)
|
||||
print("-" * 50)
|
||||
# Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
|
||||
is_same = (torch_decoded_sequences == ort_decoded_sequences)
|
||||
print("Torch and ORT result is ", "same" if is_same else "different")
|
||||
output["parity"] = is_same
|
||||
|
||||
print(output)
|
||||
return output
|
||||
|
||||
|
||||
def main(argv=None, sentences=None):
|
||||
args = parse_arguments(argv)
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print(f"skip conversion since path existed: {args.output}")
|
||||
else:
|
||||
convert_model(args)
|
||||
|
||||
return test_model(args, use_vocab_mask=True, sentences=sentences)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
|
@ -17,7 +17,6 @@ This converts GPT2 model to onnx. Examples:
|
|||
|
||||
import os
|
||||
import argparse
|
||||
import coloredlogs
|
||||
import logging
|
||||
import torch
|
||||
import numpy
|
||||
|
|
|
|||
70
onnxruntime/test/python/transformers/test_beam_search.py
Normal file
70
onnxruntime/test/python/transformers/test_beam_search.py
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
#!/usr/bin/env python
|
||||
# coding: utf-8
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
# --------------------------------------------------------------------------
|
||||
|
||||
import unittest
|
||||
import os
|
||||
import pytest
|
||||
|
||||
from parity_utilities import find_transformers_source
|
||||
if find_transformers_source():
|
||||
from convert_beam_search import main as run
|
||||
else:
|
||||
from onnxruntime.transformers.convert_beam_search import main as run
|
||||
|
||||
|
||||
class TestBeamSearch(unittest.TestCase):
|
||||
def setUp(self):
|
||||
#TODO: use a smaller model and enable tests in CI pipeline
|
||||
self.model_name = "gpt2"
|
||||
self.gpt2_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_past_fp32_shape.onnx')
|
||||
self.beam_search_onnx_path = os.path.join('.', 'onnx_models', 'gpt2_beam_search.onnx')
|
||||
self.cpu_params = f'-m {self.model_name} --gpt2_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0'
|
||||
|
||||
def run_beam_search(self, arguments: str, sentences=None):
|
||||
return run(arguments.split(), sentences=sentences)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cpu(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --num_return_sequences 2",
|
||||
sentences=["The product is released"])
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_early_stopping(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --early_stopping")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_temperature(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --temperature 0.5")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_length_penalty(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --length_penalty 0.5")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_no_repeat_ngram(self):
|
||||
for ngram_size in [1, 2]:
|
||||
result = self.run_beam_search(self.cpu_params + f' --no_repeat_ngram_size {ngram_size}')
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
|
@ -3,6 +3,10 @@
|
|||
"Affine ai.onnx CPUExecutionProvider",
|
||||
7811918192248490408
|
||||
],
|
||||
[
|
||||
"BeamSearch com.microsoft CPUExecutionProvider",
|
||||
6968087233460196528
|
||||
],
|
||||
[
|
||||
"Crop ai.onnx CPUExecutionProvider",
|
||||
6914973556202621376
|
||||
|
|
|
|||
Loading…
Reference in a new issue