Support T5 in BeamSearch operator (#11450)

(1) Support T5 in BeamSearch operator, and add both CPU and CUDA implementation.
(2) Change BeamSearch op: rename encoder_decoder_init attribute to encoder, and add decoder_start_token_id attribute
(3) Update convert_to_onnx for T5 to use int32 instead of int64 inputs as default.
(4) Add more tests in best_beam_search.py
(5) fix ORT_ENFORCE of hypothesis_buffer_offset_
(6) Improve ONNX conversion:
   (a) Change encoder some dynamic axes to fixed dim value
   (b) add --separate_encoder_and_decoder_init
   (c) correct name t5-3B => t5-3b, t5-11B => t5-11b
   (d) Add --use_int32_inputs in convert t5 to onnx
   (e) Allow t5 beam search conversion in one step
This commit is contained in:
Tianlei Wu 2022-06-10 15:06:57 -07:00 committed by GitHub
parent 768b9cfb60
commit def78a1b81
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
47 changed files with 3708 additions and 1465 deletions

View file

@ -355,10 +355,12 @@ This version of the operator has been available since version 1 of the 'com.micr
<dl>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dt><tt>early_stopping</tt> : int</dt>
<dd>early stop or not</dd>
<dt><tt>encoder_decoder_init</tt> : graph</dt>
<dd>subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>encoder</tt> : graph</dt>
<dd>The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token</dd>
<dt><tt>model_type</tt> : int</dt>

View file

@ -9,6 +9,7 @@
#pragma warning(disable : 4996)
#endif
#include <memory>
#include <assert.h>
#include <functional>
#include "core/common/safeint.h"
@ -26,11 +27,13 @@
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
#include "gsl/gsl"
#include "beam_search.h"
#include "logits_processor.h"
#include "sequences.h"
#include "dump_tensor.h"
#include "beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/beam_search.h"
#include "contrib_ops/cpu/transformers/logits_processor.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h"
#include "contrib_ops/cpu/transformers/beam_search_impl_t5.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
@ -53,590 +56,156 @@ REGISTER_KERNEL_TYPED(float)
namespace transformers {
template <typename T>
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
BufferUniquePtr& buffer,
size_t elements,
bool fill = false,
T fill_value = T{}) {
size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
void* data = allocator->Alloc(bytes);
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
buffer = std::move(temp_buffer);
T* first = reinterpret_cast<T*>(buffer.get());
auto span = gsl::make_span(first, elements);
if (fill) {
std::fill_n(first, elements, fill_value);
}
return span;
}
template <typename T>
struct BeamSearchState : public IBeamSearchState<T> {
void Init(AllocatorPtr allocator,
int batch_size,
int num_beams,
int vocab_size,
int sequence_length,
int max_length,
bool output_scores) {
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
size_t next_token_size = SafeInt<size_t>(batch_beam_size) * vocab_size;
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size);
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size);
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size);
this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size);
if (output_scores) {
size_t elements = SafeInt<size_t>(max_length - sequence_length) * batch_size * num_beams * vocab_size;
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
this->remaining_scores = this->scores;
}
}
private:
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;
BufferUniquePtr next_tokens_buffer_;
BufferUniquePtr next_indices_buffer_;
BufferUniquePtr next_positions_buffer_;
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr scores_buffer_;
};
struct BeamSearchCpuState : public IBeamSearchCpuState {
Sequences sequences;
void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, bool is_cuda) {
this->sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size);
this->sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, SafeInt<size_t>(2) * batch_beam_size * max_length);
if (is_cuda) {
// buffers used by CUDA operator but not by CPU operator.
this->topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * batch_beam_size);
this->topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * batch_beam_size);
this->topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * batch_beam_size);
this->final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size);
}
}
private:
BufferUniquePtr final_beam_scores_buffer_;
BufferUniquePtr sequence_lengths_buffer_;
BufferUniquePtr topk_scores_buffer_;
BufferUniquePtr topk_tokens_buffer_;
BufferUniquePtr topk_indices_buffer_;
BufferUniquePtr sequences_space_buffer_;
};
template <typename T>
class BeamSearchImpl {
public:
BeamSearchImpl(OpKernelContextInternal& context,
const SessionState& session_state,
GptSubgraph& gpt_subgraph,
concurrency::ThreadPool* thread_pool,
void* cuda_stream,
IConsoleDumper* cuda_dumper,
BeamSearchParameters& params,
const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const BeamSearchDeviceHelper::TopkFunc& topk_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::UpdateFeedsFunc<T>& update_feeds_func)
: context_(context),
session_state_(session_state),
gpt_subgraph_(gpt_subgraph),
thread_pool_(thread_pool),
implicit_inputs_(context_.GetImplicitInputs()),
cuda_stream_(cuda_stream),
cuda_dumper_(cuda_dumper),
parameters_(&params),
cpu_allocator_(nullptr),
temp_space_allocator_(nullptr),
create_inputs_func_(create_inputs_func),
add_to_feeds_func_(add_to_feeds_func),
topk_func_(topk_func),
process_logits_func_(process_logits_func),
init_beam_state_func_(init_beam_state_func),
device_copy_func_(device_copy_func),
update_feeds_func_(update_feeds_func) {
parameters_->ParseFromInputs(&context);
cpu_allocator_ = session_state.GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
}
// Initialize by validating all the inputs, and allocating the output tensors.
Status Initialize();
// Execute beam search in iterations util stopping criteria is reached.
// In each iteration, GPT subgraph is called, and next token for each sequence is generated.
Status Execute(const FeedsFetchesManager& feeds_fetches_manager);
private:
bool IsCuda() const { return cuda_stream_ != nullptr; }
// Validate inputs.
Status CheckInputs(const OpKernelContextInternal& context);
// Prepare the inputs for first inference of subgraph
Status CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths, OrtValue& expanded_input_ids, std::vector<OrtValue>& feeds, IAllocatorUniquePtr<char>& buffer);
// Update the input for next iteration.
Status UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices);
// Process logits and append next tokens to sequences.
Status GenerateNextToken(const OrtValue& logits,
gsl::span<int32_t>& beam_next_tokens,
gsl::span<int32_t>& beam_indices,
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
int counter);
// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
AllocatorPtr& allocator,
int counter);
const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); }
OpKernelContextInternal& context_;
const SessionState& session_state_;
GptSubgraph& gpt_subgraph_;
concurrency::ThreadPool* thread_pool_;
const std::vector<const OrtValue*>& implicit_inputs_;
void* cuda_stream_;
IConsoleDumper* cuda_dumper_;
CpuTensorConsoleDumper cpu_dumper_;
BeamSearchParameters* parameters_;
LogitsProcessorList logits_processors_;
std::unique_ptr<BeamSearchScorer> beam_scorer_;
AllocatorPtr cpu_allocator_;
AllocatorPtr temp_space_allocator_;
// Device specific functions
BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
BeamSearchDeviceHelper::TopkFunc topk_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<T> process_logits_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
BeamSearchDeviceHelper::UpdateFeedsFunc<T> update_feeds_func_;
};
void BeamSearch::Init(const OpKernelInfo& info) {
// Make sure the decoder attribute was present even though we don't need it here.
parameters_.ParseFromAttributes(info);
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5).
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
ONNX_NAMESPACE::GraphProto proto;
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
}
// Make sure the decoder attribute was present even though we don't need it here.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("decoder", &proto).IsOK());
ORT_IGNORE_RETURN_VALUE(proto);
parameters_.ParseFromAttributes(info);
}
Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
// TODO: handle another subgraph with attribute name "encoder_decode_init"
if (attribute_name == "decoder") {
const auto& node = Node();
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
gpt_subgraph_->num_heads,
gpt_subgraph_->head_size,
gpt_subgraph_->num_layers);
const auto& node = Node();
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
gpt_subgraph_->num_heads,
gpt_subgraph_->head_size,
gpt_subgraph_->num_layers);
}
} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
t5_encoder_subgraph_ = std::make_unique<T5EncoderSubgraph>(node,
attribute_name,
subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
if (parameters_.decoder_start_token_id < 0) {
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
"Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
} else {
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
"Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
}
} else if (attribute_name == "decoder") {
ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
t5_decoder_subgraph_ = std::make_unique<T5DecoderSubgraph>(node,
attribute_name,
subgraph_session_state.GetGraphViewer());
ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state));
decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager();
parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size,
t5_decoder_subgraph_->num_heads,
t5_decoder_subgraph_->head_size,
t5_decoder_subgraph_->num_layers);
}
}
return Status::OK();
}
Status BeamSearch::Compute(OpKernelContext* ctx) const {
if (parameters_.model_type != 0) {
// TODO: support encoder decoder model like T5
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented");
}
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto* session_state = ctx_internal->SubgraphSessionState("decoder");
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder");
ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later
// Make a copy of parameters since we will update it based on inputs later
BeamSearchParameters parameters = parameters_;
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
BeamSearchGpt<float> impl{
*ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
BeamSearchCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateGptFeeds<float>};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(*decoder_feeds_fetches_manager_);
} else { // Output float16
BeamSearchGpt<MLFloat16> impl{
*ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
BeamSearchCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
update_gpt_feeds_fp16_func_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(*decoder_feeds_fetches_manager_);
}
}
auto* encoder_session_state = ctx_internal->SubgraphSessionState("encoder");
ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute.");
ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
// Subgraph has constraint that the output is either float or float16
if (!gpt_subgraph_->IsOutputFloat16()) {
BeamSearchImpl<float> impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
update_feeds_func_ ? update_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateFeeds<float>};
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
BeamSearchT5<float> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState<float>,
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(*feeds_fetches_manager_);
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
BeamSearchImpl<MLFloat16> impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
update_feeds_fp16_func_};
BeamSearchT5<MLFloat16> impl{
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
device_copy_func_,
device_copy_int32_func_,
create_encoder_inputs_func_,
update_decoder_feeds_fp16_func_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(*feeds_fetches_manager_);
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
}
}
template <typename T>
Status BeamSearchImpl<T>::CheckInputs(const OpKernelContextInternal& context) {
// Input shapes:
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
const Tensor* input_ids = context.Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ",
dims.size());
}
const Tensor* vocab_mask = context.Input<Tensor>(8);
if (vocab_mask != nullptr) { // vocab_mask is optional
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
if (vocab_mask_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ",
vocab_mask_dims.size());
}
// There is dependency on vocab_size parameter, which shall be set before calling this function.
if (static_cast<int>(vocab_mask_dims[0]) != parameters_->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ",
vocab_mask_dims[0]);
}
// store vocab mask in parameters.
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
}
const Tensor* prefix_vocab_mask = context.Input<Tensor>(9);
if (prefix_vocab_mask != nullptr) {
// prefix_vocab_mask is optional
const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
if (vocab_mask_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ",
vocab_mask_dims.size());
}
// prefix_vocab_mask first dimension should be same as the first dimension of input_ids
if (static_cast<int>(vocab_mask_dims[0]) != static_cast<int>(dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size");
}
// There is dependency on vocab_size parameter, which shall be set before calling this function.
if (static_cast<int>(vocab_mask_dims[1]) != parameters_->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ",
vocab_mask_dims[0]);
}
// store prefix vocab mask in parameters.
parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan<int32_t>();
}
return Status::OK();
}
template <typename T>
Status BeamSearchImpl<T>::Initialize() {
ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_));
#define CHECK_SCALAR_INPUT(name, index, required) \
auto* name##_tensor = context_.Input<Tensor>(index); \
if (name##_tensor) { \
if (!name##_tensor->Shape().IsScalar()) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
name##_tensor->Shape()); \
} \
} else if (required) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
}
CHECK_SCALAR_INPUT(min_length, 1, false);
CHECK_SCALAR_INPUT(max_length, 2, true);
CHECK_SCALAR_INPUT(num_beams, 3, true);
CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
CHECK_SCALAR_INPUT(temperature, 5, true);
CHECK_SCALAR_INPUT(length_penalty, 6, true);
ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'.");
ORT_RETURN_IF_ERROR(CheckInputs(context_));
// This flag will be updated later when the scores output exists.
parameters_->output_scores = false;
if (!IsCuda()) {
// Logits processor is used in CPU only. In CUDA, cuda kernels are used instead.
// Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready.
logits_processors_.Init(*parameters_);
}
return Status::OK();
}
template <typename T>
Status BeamSearchImpl<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths, OrtValue& expanded_input_ids, std::vector<OrtValue>& feeds, IAllocatorUniquePtr<char>& buffer) {
const OrtValue* input_ids_value = context_.GetInputOrtValue(0);
const Tensor& input_ids = input_ids_value->Get<Tensor>();
return gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, sequence_lengths, expanded_input_ids, feeds, create_inputs_func_, add_to_feeds_func_, buffer);
}
template <typename T>
Status BeamSearchImpl<T>::ProcessLogits(
const OrtValue& logits,
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
AllocatorPtr& allocator,
int counter) {
return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator,
thread_pool_, &logits_processors_, beam_scorer_.get(),
parameters_, counter, cuda_stream_, GetConsoleDumper());
}
template <typename T>
Status BeamSearchImpl<T>::GenerateNextToken(
const OrtValue& logits,
gsl::span<int32_t>& beam_next_tokens,
gsl::span<int32_t>& beam_indices,
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
int counter) {
// Process logits to get next token scores
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter));
gsl::span<float>& beam_scores = beam_scorer_->GetNextScores();
// It is optional to clone beam_scores. Change it to use same buffer also works for CPU:
// beam_state.beam_scores = beam_scores
// Here we make a copy to reduce the coupling with little cost (the buffer size is small).
ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores, beam_scores, cuda_stream_, DeviceCopyDirection::hostToDevice));
beam_next_tokens = beam_scorer_->GetNextTokens();
beam_indices = beam_scorer_->GetNextIndices();
#ifdef DEBUG_BEAM_SEARCH
cpu_dumper_.Print("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
cpu_dumper_.Print("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
cpu_dumper_.Print("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
#endif
cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
#ifdef DEBUG_BEAM_SEARCH
cpu_state.sequences.PrintSequences(&cpu_dumper_);
#endif
return Status::OK();
}
template <typename T>
Status BeamSearchImpl<T>::UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices) {
return update_feeds_func_(temp_space_allocator_, cuda_stream_, last_outputs, next_inputs, current_length, position_ids,
beam_next_tokens, beam_indices, parameters_->num_beams, GetConsoleDumper());
}
template <typename T>
Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& feeds_fetches_manager) {
auto status = Status::OK();
int64_t sequences_dims[] = {parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length};
TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
Tensor* output_sequences = context_.Output(0, sequences_shape);
int64_t sequences_scores_dims[] = {parameters_->batch_size, parameters_->num_return_sequences};
TensorShape sequences_scores_shape(&sequences_scores_dims[0], sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0]));
Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape);
int64_t scores_dims[] = {
static_cast<int64_t>(parameters_->max_length) - static_cast<int64_t>(parameters_->sequence_length),
parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size};
TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
Tensor* output_scores = context_.Output(2, scores_shape);
// Update the flag to indicate whether scores exists in output
parameters_->output_scores = (output_scores != nullptr);
std::vector<OrtValue> feeds;
// TODO: allocate fetches. use ping-pong buffers for past state.
std::vector<OrtValue> fetches;
// Initialize resources
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(cpu_allocator_);
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(cpu_allocator_);
beam_scorer_ = std::make_unique<BeamSearchScorer>(static_cast<size_t>(parameters_->batch_size),
static_cast<size_t>(parameters_->num_beams),
static_cast<size_t>(parameters_->max_length),
parameters_->length_penalty,
parameters_->early_stopping,
static_cast<size_t>(parameters_->num_return_sequences),
parameters_->pad_token_id,
parameters_->eos_token_id,
hypothesis_score_allocator,
beam_hyps_allocator);
beam_scorer_->Initialize(cpu_allocator_, parameters_->sequence_length);
BeamSearchCpuState cpu_state;
cpu_state.Init(cpu_allocator_, static_cast<size_t>(parameters_->BatchBeamSize()), parameters_->max_length, IsCuda());
// buffer in GPU for input_ids, position_ids and attention_mask
// size_t buffer_bytes = SafeInt<size_t>(sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * parameters_->batch_size * parameters_->num_beams * parameters_->sequence_length;
// IAllocatorUniquePtr<char> buffer = gpt_subgraph_.GetProvider()->GetScratchBuffer<char>(buffer_bytes);
IAllocatorUniquePtr<char> buffer;
OrtValue expanded_input_ids_in_cpu;
ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
BeamSearchState<T> beam_state;
beam_state.Init(temp_space_allocator_,
parameters_->batch_size,
parameters_->num_beams,
parameters_->vocab_size,
parameters_->sequence_length,
parameters_->max_length,
parameters_->output_scores);
cpu_state.sequences.Init(cpu_state.sequences_space,
parameters_->BatchBeamSize(),
parameters_->sequence_length,
parameters_->max_length);
gsl::span<const int32_t> input_ids = expanded_input_ids_in_cpu.Get<Tensor>().DataAsSpan<int32_t>();
init_beam_state_func_(&beam_state,
&cpu_state,
cpu_state.sequence_lengths,
parameters_->batch_size,
parameters_->num_beams,
input_ids,
parameters_->sequence_length,
parameters_->max_length,
cuda_stream_);
#ifdef DEBUG_BEAM_SEARCH
const IConsoleDumper* dumper = GetConsoleDumper();
dumper->Print("input_ids", feeds[0]);
dumper->Print("position_ids", feeds[1]);
dumper->Print("attention_mask", feeds[2]);
#endif
// position ids for all iterations except the first. It uses memory buffer owned by next_positions.
OrtValue position_ids;
int64_t dims[] = {parameters_->BatchBeamSize(), 1};
TensorShape shape(&dims[0], 2);
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, beam_state.next_positions.data(), temp_space_allocator_->Info(), position_ids);
int current_length = parameters_->sequence_length;
int iteration_counter = 0;
while (current_length < parameters_->max_length) {
iteration_counter++;
#ifdef DEBUG_BEAM_SEARCH
auto cur_len = std::to_string(current_length);
dumper->Print("***CurrentLength", cur_len, true);
#endif
status = utils::ExecuteSubgraph(session_state_, feeds_fetches_manager, feeds, fetches, {},
ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger());
ORT_RETURN_IF_ERROR(status);
const OrtValue& logits = fetches[0];
gsl::span<int32_t> beam_next_tokens;
gsl::span<int32_t> beam_indices;
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, cpu_state, iteration_counter));
// When all batches are finished, stop earlier to avoid wasting computation.
if (beam_scorer_->IsDone()) {
break;
}
// Increase sequence length after a new token is generated.
++current_length;
// Prepare inputs for next round of subgraph call.
if (current_length < parameters_->max_length) {
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
position_ids,
beam_next_tokens.as_span<const int32_t>(),
beam_indices.as_span<const int32_t>()));
}
fetches.clear();
}
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
if (IsCuda()) {
ORT_RETURN_IF_ERROR(device_copy_func_(cpu_state.final_beam_scores, final_beam_scores, nullptr, DeviceCopyDirection::deviceToHost));
final_beam_scores = gsl::make_span<const float>(cpu_state.final_beam_scores.data(), cpu_state.final_beam_scores.size());
}
beam_scorer_->Finalize(&(cpu_state.sequences),
final_beam_scores,
output_sequences,
output_sequences_scores);
// Output per token scores
if (output_scores != nullptr) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
gsl::span<const float> source = gsl::span<const float>(beam_state.scores.data(), beam_state.scores.size());
assert(target.length() == source.length());
ORT_RETURN_IF_ERROR(device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
}
return status;
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -2,12 +2,16 @@
// Licensed under the MIT License.
#pragma once
#include <memory>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "beam_search_parameters.h"
#include "gpt_subgraph.h"
#include "beam_search_device_helper.h"
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
namespace onnxruntime {
class FeedsFetchesManager;
@ -20,7 +24,11 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
class BeamSearch : public IControlFlowKernel {
public:
BeamSearch(const OpKernelInfo& info)
: IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
: IControlFlowKernel(info),
encoder_feeds_fetches_manager_(nullptr),
decoder_feeds_fetches_manager_(nullptr),
cuda_stream_(nullptr),
dumper_(nullptr) {
Init(info);
}
@ -36,54 +44,76 @@ class BeamSearch : public IControlFlowKernel {
void SetComputeStream(void* stream) { cuda_stream_ = stream; }
void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; }
// device helpers that is same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
// const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const BeamSearchDeviceHelper::TopkFunc& topk_func) {
// create_inputs_func_ = create_inputs_func;
const BeamSearchDeviceHelper::TopkFunc& topk_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<float>& process_logits_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
}
// Type dependent helpers: float
void SetDeviceHelpers(
const BeamSearchDeviceHelper::ProcessLogitsFunc<float>& process_logits_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::UpdateFeedsFunc<float>& update_feeds_func) {
process_logits_func_ = process_logits_func;
init_beam_state_func_ = init_beam_state_func;
device_copy_func_ = device_copy_func;
update_feeds_func_ = update_feeds_func;
device_copy_int32_func_ = device_copy_int32_func;
process_logits_func_ = process_logits_func;
process_logits_fp16_func_ = process_logits_fp16_func;
init_beam_state_func_ = init_beam_state_func;
init_beam_state_fp16_func_ = init_beam_state_fp16_func;
}
// Type dependent helpers: MLFloat16
void SetDeviceHelpers(
const BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_func,
const BeamSearchDeviceHelper::UpdateFeedsFunc<MLFloat16>& update_feeds_func) {
process_logits_fp16_func_ = process_logits_func;
init_beam_state_fp16_func_ = init_beam_state_func;
update_feeds_fp16_func_ = update_feeds_func;
void SetDeviceHelpers_Gpt(
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<float>& update_gpt_feeds_func,
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<MLFloat16>& update_gpt_feeds_fp16_func) {
update_gpt_feeds_func_ = update_gpt_feeds_func;
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}
// device helpers for encoder-decoder model like T5
void SetDeviceHelpers_EncoderDecoder(
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
update_decoder_feeds_func_ = update_decoder_feeds_func;
update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
}
private:
// Device specific functions
BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
BeamSearchDeviceHelper::TopkFunc topk_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<float> process_logits_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<float> init_beam_state_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
BeamSearchDeviceHelper::UpdateFeedsFunc<float> update_feeds_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<int32_t> device_copy_int32_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<float> process_logits_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16> process_logits_fp16_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
BeamSearchDeviceHelper::UpdateFeedsFunc<MLFloat16> update_feeds_fp16_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<float> init_beam_state_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
//------------------------------------------------------------
// Device specific functions for GPT
//------------------------------------------------------------
BeamSearchDeviceHelper::UpdateGptFeedsFunc<float> update_gpt_feeds_func_;
BeamSearchDeviceHelper::UpdateGptFeedsFunc<MLFloat16> update_gpt_feeds_fp16_func_;
//------------------------------------------------------------
// Device specific functions for encoder-decoder model like T5
//------------------------------------------------------------
BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;
//------------------------------------------------------------
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
//------------------------------------------------------------
std::unique_ptr<GptSubgraph> gpt_subgraph_;
FeedsFetchesManager* feeds_fetches_manager_;
std::unique_ptr<T5EncoderSubgraph> t5_encoder_subgraph_;
std::unique_ptr<T5DecoderSubgraph> t5_decoder_subgraph_;
FeedsFetchesManager* encoder_feeds_fetches_manager_;
FeedsFetchesManager* decoder_feeds_fetches_manager_;
void* cuda_stream_;

View file

@ -1,10 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <vector>
#include <algorithm>
#include "core/providers/cpu/math/top_k.h"
#include "core/providers/cpu/math/softmax_shared.h"
#include "core/common/safeint.h"
#include "gsl/gsl"
#include "sequences.h"
#include "beam_search_scorer.h"
#include "beam_search_device_helper.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
namespace onnxruntime {
namespace contrib {
@ -25,11 +32,10 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
input->DataType(), " is not supported yet");
}
OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator) {
// Input shape (batch_size, sequence_length)
template <typename T>
void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
// Input shape (batch_size, sequence_length). The input is required with data type T.
// Output shape (batch_size * num_beams, sequence_length)
if (num_beams == 1)
return input;
const TensorShape& input_shape = input.Get<Tensor>().Shape();
const int64_t& batch_size = input_shape[0];
@ -38,31 +44,28 @@ OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocat
int64_t dims[] = {batch_size * num_beams, sequence_length};
TensorShape expanded_shape(&dims[0], 2);
OrtValue expanded;
MLDataType element_type = input.Get<Tensor>().DataType();
ORT_ENFORCE(element_type == DataTypeImpl::GetType<int32_t>(), "input_ids, position_ids and attention_mask is required to be int32 data type");
ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
const int32_t* input_data = input.Get<Tensor>().Data<int32_t>();
int32_t* expanded_data = expanded.GetMutable<Tensor>()->MutableData<int32_t>();
int32_t* target = expanded_data;
const T* input_data = input.Get<Tensor>().Data<T>();
T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
T* target = expanded_data;
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
memcpy(target, input_data + i * sequence_length, sizeof(int32_t) * sequence_length);
memcpy(target, input_data + i * sequence_length, sizeof(T) * sequence_length);
target += sequence_length;
}
}
return expanded;
}
Status CreateInputs(
Status CreateGptInputs(
const Tensor* original_input_ids,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
AllocatorPtr alloactor,
AllocatorPtr allocator,
OrtValue& expanded_input_ids,
OrtValue& expanded_position_ids,
OrtValue& expanded_attention_mask) {
@ -74,21 +77,22 @@ Status CreateInputs(
// Allocate position_ids and attention_mask based on shape of input_ids
auto element_type = DataTypeImpl::GetType<int32_t>();
const OrtMemoryInfo& location = alloactor->Info();
const OrtMemoryInfo& location = allocator->Info();
// Use original input_ids. This requires the input_ids for subgraph is also int32.
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
OrtValue input_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(original_input_ids)->MutableData<int32_t>(), location, input_ids);
Tensor::InitOrtValue(element_type, input_ids_shape,
const_cast<Tensor*>(original_input_ids)->MutableData<int32_t>(), location, input_ids);
OrtValue position_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, position_ids);
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, position_ids);
OrtValue attention_mask;
auto mask_type = DataTypeImpl::GetType<int32_t>();
Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask);
Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, attention_mask);
// Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
@ -115,43 +119,45 @@ Status CreateInputs(
}
}
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask
// TODO: Try expand outputs after first subgraph call instead. That may get better performance, but more complex to implement.
expanded_input_ids = ExpandInputs(input_ids, num_beams, alloactor);
expanded_position_ids = ExpandInputs(position_ids, num_beams, alloactor);
expanded_attention_mask = ExpandInputs(attention_mask, num_beams, alloactor);
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
return Status::OK();
}
Status AddToFeeds(const IExecutionProvider* /*execution_provider*/,
OrtValue& input_ids,
OrtValue& position_ids,
OrtValue& attention_mask,
std::initializer_list<OrtValue> inputs,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& /*buffer*/) {
feeds.push_back(input_ids);
feeds.push_back(position_ids);
feeds.push_back(attention_mask);
for (auto& input : inputs) {
if (input.IsAllocated()) {
feeds.push_back(input);
}
}
return Status::OK();
}
template <typename T>
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* /*stream*/) {
memset(beam_state->beam_scores.data(), 0, beam_state->beam_scores.size_bytes());
memset(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes());
memset(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes());
memset(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes());
memset(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes());
memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
// T5 does not need position, so next_positions is empty for T5.
if (!beam_state->next_positions.empty()) {
memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
gsl::copy(sequence_lengths, beam_state->next_positions);
}
// Initialize score of first beam of each group with 0 and the rest with -1e9.
// This ensures that the beams in the same group don't produce same tokens every time.
@ -161,19 +167,6 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
beam_scores[SafeInt<gsl::index>(i) * num_beams + j] = -1e9;
}
}
gsl::copy(sequence_lengths, beam_state->next_positions);
memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes());
// Copy input_ids to sequences[0].
gsl::span<int32_t> sequences_0 = cpu_state->sequences_space;
int batch_beam_size = batch_size * num_beams;
for (int i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < sequence_length; j++) {
sequences_0[SafeInt<gsl::index>(i) * max_length + j] = static_cast<int32_t>(input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length + j]);
}
}
}
template <typename T>
@ -216,7 +209,8 @@ Status ProcessLogits(const OrtValue& logits, //
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const T> source(current_logits, vocab_size);
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size, static_cast<gsl::index>(vocab_size));
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
static_cast<gsl::index>(vocab_size));
gsl::copy(source, target);
current_logits += input_length * vocab_size;
}
@ -224,7 +218,9 @@ Status ProcessLogits(const OrtValue& logits, //
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("logits", logits);
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
if (input_length > 1) {
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
}
#endif
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
@ -244,12 +240,12 @@ Status ProcessLogits(const OrtValue& logits, //
logits_processors->Process(sequences, next_token_scores, step);
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size);
dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size);
#endif
// Add beam score to next token scores. Corresponding python code is like:
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
// TODO: use thread pool to parrellel
// TODO(tianleiwu): use thread pool to parallel
int offset = 0;
int batch_beam_index = 0;
for (int i = 0; i < batch_size; i++) {
@ -261,7 +257,7 @@ Status ProcessLogits(const OrtValue& logits, //
}
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
#endif
if (output_scores) {
@ -277,7 +273,8 @@ Status ProcessLogits(const OrtValue& logits, //
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType<T>();
OrtValue next_token_scores_value;
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(),
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();
constexpr int axis = 1;
@ -287,7 +284,8 @@ Status ProcessLogits(const OrtValue& logits, //
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices));
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
topk_scores, topk_indices));
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("topk_scores", *(topk_scores.get()));
@ -331,32 +329,34 @@ Status DeviceCopy(gsl::span<T> target, gsl::span<const T> source, void* /*stream
return Status::OK();
}
// Copy present state to past state for GPT model
template <typename T>
void PickPastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
gsl::span<const int32_t>& beam_indices,
AllocatorPtr allocator,
void* /*stream*/) {
void PickGptPastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
gsl::span<const int32_t>& beam_indices,
AllocatorPtr allocator) {
int num_present_tensors = static_cast<int>(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex;
for (int i = 0; i < num_present_tensors; ++i) {
const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i];
for (size_t i = 1; i < last_outputs.size(); ++i) {
const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
// shape is like (2, batch_beam_size, 12, past_seq_len, 64)
const TensorShape& past_shape = present.Get<Tensor>().Shape();
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
// Create a tensor with same shape.
// TODO: allocate one buffer for all layers
// TODO(tianleiwu): allocate one buffer for all layers
OrtValue past;
auto past_type = DataTypeImpl::GetType<T>();
Tensor::InitOrtValue(past_type, past_shape, allocator, past);
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
for (gsl::index j = 0; j < beam_indices.length(); j++) {
int32_t beam_index = beam_indices[j];
gsl::span<const T> present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam,
block_size_per_beam);
gsl::span<T> past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
gsl::span<T> past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
@ -364,12 +364,12 @@ void PickPastState(const std::vector<OrtValue>& last_outputs,
gsl::copy(present_value, past_value);
}
next_inputs[i + 2] = past;
next_inputs[transformers::GptSubgraph::kFirstPastInputIndex + i] = past;
}
}
template <typename T>
Status UpdateFeeds(
Status UpdateGptFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -382,6 +382,7 @@ Status UpdateFeeds(
const transformers::IConsoleDumper* dumper) {
// last_outputs: logits, present_0, present_1, ...
// next_inputs: input_ids, position_id, attention_mask, past_0, past_1
ORT_UNUSED_PARAMETER(stream);
// The following updates inputs for subgraph
@ -391,7 +392,7 @@ Status UpdateFeeds(
TensorShape input_ids_shape(&dims[0], 2);
auto int32_type = DataTypeImpl::GetType<int32_t>();
OrtValue input_ids;
// TODO: Reuse buffer for input_ids to reduce memory allocation.
// TODO(tianleiwu): Reuse buffer for input_ids to reduce memory allocation.
Tensor::InitOrtValue(int32_type, input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_beam_size; i++) {
@ -433,25 +434,187 @@ Status UpdateFeeds(
// Update past state
if (num_beams == 1) {
// feed present_* output to past_* inputs one by one
for (size_t i = 1; i < last_outputs.size(); ++i) {
next_inputs[i + 2] = last_outputs[i];
const int k = transformers::GptSubgraph::kFirstPastInputIndex - transformers::GptSubgraph::kFirstPresentOutputIndex;
for (size_t i = transformers::GptSubgraph::kFirstPresentOutputIndex; i < last_outputs.size(); ++i) {
next_inputs[i + k] = last_outputs[i];
}
} else {
PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream);
PickGptPastState<T>(last_outputs, next_inputs, beam_indices, allocator);
}
return Status::OK();
}
// ---------------------------------------------------------------
// The following functions are for encoder-decoder model like T5
// ---------------------------------------------------------------
Status CreateEncoderInputs(
const Tensor* original_encoder_input_ids,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids) {
const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
const int64_t& batch_size = input_ids_shape[0];
const int64_t& sequence_length = input_ids_shape[1];
// Allocate attention_mask based on shape of input_ids
auto element_type = DataTypeImpl::GetType<int32_t>();
// Use original encoder_input_ids. This requires the input_ids for subgraph is also int32.
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
OrtValue encoder_input_ids;
Tensor::InitOrtValue(element_type,
input_ids_shape,
const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
allocator->Info(),
encoder_input_ids);
OrtValue encoder_attention_mask;
auto mask_type = DataTypeImpl::GetType<int32_t>();
Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, encoder_attention_mask);
// Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
int32_t* mask_data = encoder_attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
const int32_t* word_id = original_encoder_input_ids->Data<int32_t>();
int32_t* mask = mask_data;
for (int i = 0; i < batch_size; i++) {
int32_t abs_position = 0;
for (int j = 0; j < sequence_length; j++, word_id++, mask++) {
// T5Tokenizer might add one EOS pad token at the end.
// That EOS token shall have attention mask 1 even when EOS token is same as pad token.
// Here we only set attention mask to be 0 for left padding only, so as to be parity with huggingface.
if (*word_id == pad_token_id && abs_position == 0) {
*mask = 0;
} else {
*mask = 1;
abs_position++;
}
}
}
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
// for encoder_input_ids and encoder_attention_mask
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
ExpandInputs<int32_t>(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
ExpandInputs<int32_t>(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);
// decoder_input_ids is optional.
if (start_token_id >= 0) {
// Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
int64_t dims[] = {batch_size * num_beams, 1};
TensorShape decoder_input_ids_shape(&dims[0], 2);
Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
for (int i = 0; i < batch_size * num_beams; i++, data++) {
*data = start_token_id;
}
}
return Status::OK();
}
// Copy present state to past state for T5 model
template <typename T>
void PickT5PastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t>& beam_indices,
AllocatorPtr allocator) {
for (int i = 0; i < num_present_tensors; ++i) {
const OrtValue& present = last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
// shape is like (batch_beam_size, 12, past_seq_len, 64)
const TensorShape& past_shape = present.Get<Tensor>().Shape();
auto block_size_per_beam = past_shape[1] * past_shape[2] * past_shape[3];
// Create a tensor with same shape.
// TODO(tianleiwu): allocate one buffer for all layers
OrtValue past;
Tensor::InitOrtValue(DataTypeImpl::GetType<T>(), past_shape, allocator, past);
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
for (gsl::index j = 0; j < beam_indices.length(); j++) {
int32_t beam_index = beam_indices[j];
gsl::span<const T> present_beam = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<T> past_beam = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
gsl::copy(present_beam, past_beam);
}
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = past;
}
}
// Update decoder inputs given decoder outputs of last iteration.
template <typename T>
Status UpdateDecoderFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper) {
ORT_UNUSED_PARAMETER(stream);
// last_outputs: logits, present_key_self_0, present_value_self_0, ...
// next_inputs: input_ids,
// encoder_attention_mask, encoder_hidden_states,
// past_key_self_0, past_value_self_0, ...
// past_key_cross_0, past_value_cross_0, ...
// Only need copy beam next tokens to input_ids, and copy present_*_self_* to past_*_self_*,
// Update input_ids with next tokens.
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
int64_t dims[] = {batch_beam_size, 1};
TensorShape input_ids_shape(&dims[0], 2);
// TODO(tianleiwu): Reuse buffer for input_ids to reduce memory allocation.
OrtValue input_ids;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
gsl::copy(beam_next_tokens, input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>());
next_inputs[0] = input_ids;
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("input_ids", input_ids);
#else
ORT_UNUSED_PARAMETER(dumper);
#endif
// Update past state
ORT_ENFORCE(last_outputs.size() >= static_cast<size_t>(1 + num_present_tensors));
// TODO(tianleiwu): remove num_beams==1 once GreedySearch operator is available.
if (num_beams == 1) {
// feed present_* output to past_* inputs one by one
for (int i = 0; i < num_present_tensors; ++i) {
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] =
last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
}
} else {
PickT5PastState<T>(last_outputs, next_inputs, num_present_tensors, beam_indices, allocator);
}
return Status::OK();
}
//------------------------------------------------
// Explicit template instantiations of functions
//------------------------------------------------
template void InitBeamState<float>(
transformers::IBeamSearchState<float>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream);
template Status ProcessLogits<float>(
@ -472,9 +635,15 @@ template Status DeviceCopy<float>(
gsl::span<float> target,
gsl::span<const float> source,
void* stream,
int copyDirectionn);
int copyDirection);
template Status UpdateFeeds<float>(
template Status DeviceCopy<int32_t>(
gsl::span<int32_t> target,
gsl::span<const int32_t> source,
void* stream,
int copyDirection);
template Status UpdateGptFeeds<float>(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -486,6 +655,19 @@ template Status UpdateFeeds<float>(
int num_beams,
const transformers::IConsoleDumper* dumper);
template Status UpdateDecoderFeeds<float>(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper);
template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#ifndef SHARED_PROVIDER
@ -6,9 +9,10 @@
#include "core/framework/allocator.h"
#endif
#include <vector>
#include "gsl/gsl"
#include "logits_processor.h"
#include "beam_search_shared.h"
#include "contrib_ops/cpu/transformers/logits_processor.h"
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
namespace onnxruntime {
class IExecutionProvider;
@ -31,40 +35,34 @@ namespace BeamSearchDeviceHelper {
using TopkFunc = std::function<Status(
const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
void* stream, // cudaStream_t stream,
void* stream, // cudaStream_t
onnxruntime::concurrency::ThreadPool* threadpool,
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices)>;
// Create subgraph inputs: input_ids, position_ids and attention_mask
using CreateInputsFunc = std::function<Status(
// Create subgraph inputs: input_ids, position_ids and attention_mask (for GPT-2).
using CreateGptInputsFunc = std::function<Status(
const Tensor* original_input_ids,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
AllocatorPtr alloactor,
AllocatorPtr allocator,
OrtValue& expanded_input_ids,
OrtValue& expanded_position_ids,
OrtValue& expanded_attention_mask)>;
using AddToFeedsFunc = std::function<Status(
const IExecutionProvider* provider,
OrtValue& input_ids,
OrtValue& position_ids,
OrtValue& attention_mask,
std::initializer_list<OrtValue> inputs,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer)>;
template <typename T>
using InitBeamStateFunc = std::function<void(
transformers::IBeamSearchState<T>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream)>;
template <typename T>
@ -89,8 +87,9 @@ using DeviceCopyFunc = std::function<Status(
void* stream,
int copyDirection)>;
// Update subgraph inputs given outputs of last iteration (for GPT-2).
template <typename T>
using UpdateFeedsFunc = std::function<Status(
using UpdateGptFeedsFunc = std::function<Status(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -102,6 +101,29 @@ using UpdateFeedsFunc = std::function<Status(
int num_beams,
const transformers::IConsoleDumper* dumper)>;
// Create encoder inputs (for encoder-decoder model like T5).
using CreateEncoderInputsFunc = std::function<Status(
const Tensor* original_encoder_input_ids,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids)>;
// Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
template <typename T>
using UpdateDecoderFeedsFunc = std::function<Status(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper)>;
} // namespace BeamSearchDeviceHelper
// These are CPU specific device helper implementations
@ -114,33 +136,17 @@ Status TopK(
std::unique_ptr<Tensor>& output_values,
std::unique_ptr<Tensor>& output_indices);
Status CreateInputs(
const Tensor* original_input_ids,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
AllocatorPtr alloactor,
OrtValue& expanded_input_ids,
OrtValue& expanded_position_ids,
OrtValue& expanded_attention_mask);
Status AddToFeeds(
const IExecutionProvider* execution_provider,
OrtValue& input_ids,
OrtValue& position_ids,
OrtValue& attention_mask,
std::initializer_list<OrtValue> inputs,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer);
template <typename T>
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream);
template <typename T>
@ -163,8 +169,22 @@ Status DeviceCopy(gsl::span<T> target,
void* stream,
int copyDirectionn);
// ---------------------------------------------------------------
// Functions for GPT model only
// ---------------------------------------------------------------
Status CreateGptInputs(
const Tensor* original_input_ids,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
AllocatorPtr allocator,
OrtValue& expanded_input_ids,
OrtValue& expanded_position_ids,
OrtValue& expanded_attention_mask);
template <typename T>
Status UpdateFeeds(
Status UpdateGptFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -176,6 +196,38 @@ Status UpdateFeeds(
int num_beams,
const transformers::IConsoleDumper* dumper);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5
// ---------------------------------------------------------------
Status CreateEncoderInputs(
const Tensor* original_encoder_input_ids,
int num_beams,
int pad_token_id,
int start_token_id,
AllocatorPtr allocator,
OrtValue& expanded_encoder_input_ids,
OrtValue& expanded_encoder_attention_mask,
OrtValue& expanded_decoder_input_ids);
// Update decoder inputs given decoder outputs of last iteration.
template <typename T>
Status UpdateDecoderFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper);
// ---------------------------------------------------------------
// Utility Functions
// ---------------------------------------------------------------
template <typename T>
void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
} // namespace BeamSearchCpuDeviceHelper
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -0,0 +1,364 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <vector>
#include <utility>
namespace onnxruntime {
namespace contrib {
namespace transformers {
template <typename T>
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
BufferUniquePtr& buffer,
size_t elements,
bool fill = false,
T fill_value = T{}) {
size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
void* data = allocator->Alloc(bytes);
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
buffer = std::move(temp_buffer);
T* first = reinterpret_cast<T*>(buffer.get());
auto span = gsl::make_span(first, elements);
if (fill) {
std::fill_n(first, elements, fill_value);
}
return span;
}
template <typename T>
struct BeamSearchState : public IBeamSearchState<T> {
void Init(AllocatorPtr allocator,
int batch_size,
int num_beams,
int vocab_size,
int sequence_length,
int max_length,
bool output_scores,
bool use_position) {
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
size_t next_token_size = SafeInt<size_t>(batch_beam_size) * vocab_size;
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size);
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size);
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size);
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size);
if (use_position) {
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size);
}
this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size);
if (output_scores) {
size_t elements = SafeInt<size_t>(max_length - sequence_length) * batch_size * num_beams * vocab_size;
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
this->remaining_scores = this->scores;
}
}
private:
BufferUniquePtr next_token_logits_buffer_;
BufferUniquePtr next_token_scores_buffer_;
BufferUniquePtr next_tokens_buffer_;
BufferUniquePtr next_indices_buffer_;
BufferUniquePtr next_positions_buffer_;
BufferUniquePtr beam_scores_buffer_;
BufferUniquePtr scores_buffer_;
};
struct BeamSearchCpuState : public IBeamSearchCpuState {
Sequences sequences;
void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, int sequence_length, bool is_cuda) {
this->sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size);
size_t sequences_bytes = SafeInt<size_t>(2) * batch_beam_size * max_length;
this->sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, sequences_bytes);
memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes());
if (is_cuda) {
// buffers used by CUDA operator but not by CPU operator.
this->topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * batch_beam_size);
this->topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * batch_beam_size);
this->topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * batch_beam_size);
this->final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size);
}
this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
}
// Copy input_ids to sequences[0]
void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
size_t batch_beam_size,
int max_length,
int sequence_length) {
gsl::span<int32_t> sequences_0 = sequences_space;
for (size_t i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < sequence_length; j++) {
const size_t index = SafeInt<gsl::index>(i) * max_length + j;
sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length + j];
}
}
}
private:
BufferUniquePtr final_beam_scores_buffer_;
BufferUniquePtr sequence_lengths_buffer_;
BufferUniquePtr topk_scores_buffer_;
BufferUniquePtr topk_tokens_buffer_;
BufferUniquePtr topk_indices_buffer_;
BufferUniquePtr sequences_space_buffer_;
};
// Base class of beam search implementation that is common for both GPT-2 and T5.
template <typename T>
class BeamSearchBase {
public:
BeamSearchBase(OpKernelContextInternal& context,
const SessionState& decoder_session_state,
concurrency::ThreadPool* thread_pool,
void* cuda_stream,
IConsoleDumper* cuda_dumper,
BeamSearchParameters& params,
const BeamSearchDeviceHelper::TopkFunc& topk_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func)
: context_(context),
decoder_session_state_(decoder_session_state),
thread_pool_(thread_pool),
implicit_inputs_(context_.GetImplicitInputs()),
cuda_stream_(cuda_stream),
cuda_dumper_(cuda_dumper),
parameters_(&params),
cpu_allocator_(nullptr),
temp_space_allocator_(nullptr),
topk_func_(topk_func),
process_logits_func_(process_logits_func),
device_copy_func_(device_copy_func),
device_copy_int32_func_(device_copy_int32_func) {
parameters_->ParseFromInputs(&context);
cpu_allocator_ = decoder_session_state.GetExecutionProviders()
.Get(onnxruntime::kCpuExecutionProvider)
->GetAllocator(0, OrtMemTypeDefault);
}
// Initialize by validating all the inputs, and allocating the output tensors.
Status Initialize();
// Validate inputs.
Status CheckInputs(const OpKernelContextInternal& context);
protected:
// Process logits and append next tokens to sequences.
Status GenerateNextToken(const OrtValue& logits,
gsl::span<int32_t>& beam_next_tokens,
gsl::span<int32_t>& beam_indices,
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
int counter);
// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
AllocatorPtr& allocator,
int counter);
bool IsCuda() const { return cuda_stream_ != nullptr; }
const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); }
OpKernelContextInternal& context_;
const SessionState& decoder_session_state_;
concurrency::ThreadPool* thread_pool_;
const std::vector<const OrtValue*>& implicit_inputs_;
void* cuda_stream_;
IConsoleDumper* cuda_dumper_;
CpuTensorConsoleDumper cpu_dumper_;
BeamSearchParameters* parameters_;
LogitsProcessorList logits_processors_;
std::unique_ptr<BeamSearchScorer> beam_scorer_;
AllocatorPtr cpu_allocator_;
AllocatorPtr temp_space_allocator_;
// Device specific functions
BeamSearchDeviceHelper::TopkFunc topk_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc<T> process_logits_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
BeamSearchDeviceHelper::DeviceCopyFunc<int32_t> device_copy_int32_func_;
};
template <typename T>
Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
// Input shapes:
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
const Tensor* input_ids = context.Input<Tensor>(0);
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
}
const Tensor* vocab_mask = context.Input<Tensor>(8);
if (vocab_mask != nullptr) { // vocab_mask is optional
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
if (vocab_mask_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'vocab_mask' is expected to have 1 dimension, got ", vocab_mask_dims.size());
}
// There is dependency on vocab_size parameter, which shall be set before calling this function.
if (static_cast<int>(vocab_mask_dims[0]) != parameters_->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'vocab_mask' shape does not match with vocab_size, got ", vocab_mask_dims[0]);
}
// store vocab mask in parameters.
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
}
const Tensor* prefix_vocab_mask = context.Input<Tensor>(9);
if (prefix_vocab_mask != nullptr) {
// prefix_vocab_mask is optional
const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
if (vocab_mask_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'prefix_vocab_mask' is expected to be 2 dimensions, got ", vocab_mask_dims.size());
}
// prefix_vocab_mask first dimension should be same as the first dimension of input_ids
if (static_cast<int>(vocab_mask_dims[0]) != static_cast<int>(dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input_ids and prefix_vocab_mask must have the same batch_size");
}
// There is dependency on vocab_size parameter, which shall be set before calling this function.
if (static_cast<int>(vocab_mask_dims[1]) != parameters_->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'prefix_vocab_mask' shape[1] shall be vocab_size, got ", vocab_mask_dims[1]);
}
// store prefix vocab mask in parameters.
parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan<int32_t>();
}
return Status::OK();
}
template <typename T>
Status BeamSearchBase<T>::Initialize() {
ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_));
#define CHECK_SCALAR_INPUT(name, index, required) \
auto* name##_tensor = context_.Input<Tensor>(index); \
if (name##_tensor) { \
if (!name##_tensor->Shape().IsScalar()) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
name##_tensor->Shape()); \
} \
} else if (required) { \
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
}
CHECK_SCALAR_INPUT(min_length, 1, false);
CHECK_SCALAR_INPUT(max_length, 2, true);
CHECK_SCALAR_INPUT(num_beams, 3, true);
CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
CHECK_SCALAR_INPUT(temperature, 5, true);
CHECK_SCALAR_INPUT(length_penalty, 6, true);
ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams,
"'num_return_sequences' has to be smaller or equal to 'num_beams'.");
ORT_RETURN_IF_ERROR(CheckInputs(context_));
// This flag will be updated later when the scores output exists.
parameters_->output_scores = false;
if (!IsCuda()) {
// Logits processor is used in CPU only. In CUDA, cuda kernels are used instead.
// Initialize processors after CheckInputs so that parameters_->vocab_mask is ready.
logits_processors_.Init(*parameters_);
}
return Status::OK();
}
template <typename T>
Status BeamSearchBase<T>::ProcessLogits(
const OrtValue& logits,
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
AllocatorPtr& allocator,
int counter) {
return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator,
thread_pool_, &logits_processors_, beam_scorer_.get(),
parameters_, counter, cuda_stream_, GetConsoleDumper());
}
template <typename T>
Status BeamSearchBase<T>::GenerateNextToken(
const OrtValue& logits,
gsl::span<int32_t>& beam_next_tokens,
gsl::span<int32_t>& beam_indices,
BeamSearchState<T>& beam_state,
BeamSearchCpuState& cpu_state,
int counter) {
// Process logits to get next token scores
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter));
gsl::span<float>& beam_scores = beam_scorer_->GetNextScores();
// It is optional to clone beam_scores. Change it to use same buffer also works for CPU:
// beam_state.beam_scores = beam_scores
// Here we make a copy to reduce the coupling with little cost (the buffer size is small).
ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores,
beam_scores,
cuda_stream_,
DeviceCopyDirection::hostToDevice));
beam_next_tokens = beam_scorer_->GetNextTokens();
beam_indices = beam_scorer_->GetNextIndices();
#ifdef DEBUG_BEAM_SEARCH
cpu_dumper_.Print("beam_scores from scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
cpu_dumper_.Print("beam_next_tokens", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
cpu_dumper_.Print("beam_indices", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
#endif
cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
#ifdef DEBUG_BEAM_SEARCH
cpu_state.sequences.PrintSequences(&cpu_dumper_);
#endif
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,278 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/transformers/beam_search_impl_base.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// Beam search implementation for GPT-2 model.
template <typename T>
class BeamSearchGpt : public BeamSearchBase<T> {
public:
BeamSearchGpt(OpKernelContextInternal& context,
const SessionState& decoder_session_state,
GptSubgraph& gpt_subgraph,
concurrency::ThreadPool* thread_pool,
void* cuda_stream,
IConsoleDumper* cuda_dumper,
BeamSearchParameters& params,
const BeamSearchDeviceHelper::CreateGptInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const BeamSearchDeviceHelper::TopkFunc& topk_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func)
: BeamSearchBase<T>(context, decoder_session_state, thread_pool,
cuda_stream, cuda_dumper, params,
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
gpt_subgraph_(gpt_subgraph),
create_inputs_func_(create_inputs_func),
add_to_feeds_func_(add_to_feeds_func),
init_beam_state_func_(init_beam_state_func),
update_feeds_func_(update_feeds_func) {
}
// Execute beam search in iterations util stopping criteria is reached.
// In each iteration, GPT subgraph is called, and next token for each sequence is generated.
Status Execute(const FeedsFetchesManager& feeds_fetches_manager);
private:
// Prepare the inputs for first inference of subgraph
Status CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer);
// Update the input for next iteration.
Status UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices);
GptSubgraph& gpt_subgraph_;
// Device specific functions
BeamSearchDeviceHelper::CreateGptInputsFunc create_inputs_func_;
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
BeamSearchDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
};
template <typename T>
Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer) {
const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0);
const Tensor& input_ids = input_ids_value->Get<Tensor>();
return gpt_subgraph_.CreateInitialFeeds(input_ids,
this->implicit_inputs_,
this->parameters_->num_beams,
this->parameters_->pad_token_id,
sequence_lengths,
expanded_input_ids,
feeds,
this->create_inputs_func_,
this->add_to_feeds_func_,
buffer);
}
template <typename T>
Status BeamSearchGpt<T>::UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
OrtValue& position_ids,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices) {
return update_feeds_func_(this->temp_space_allocator_,
this->cuda_stream_,
last_outputs,
next_inputs,
current_length,
position_ids,
beam_next_tokens,
beam_indices,
this->parameters_->num_beams,
this->GetConsoleDumper());
}
template <typename T>
Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manager) {
auto status = Status::OK();
const BeamSearchParameters* parameters = this->parameters_;
int64_t sequences_dims[] = {parameters->batch_size, parameters->num_return_sequences, parameters->max_length};
TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
Tensor* output_sequences = this->context_.Output(0, sequences_shape);
int64_t sequences_scores_dims[] = {parameters->batch_size, parameters->num_return_sequences};
TensorShape sequences_scores_shape(&sequences_scores_dims[0], 2);
Tensor* output_sequences_scores = this->context_.Output(1, sequences_scores_shape);
int64_t scores_dims[] = {
static_cast<int64_t>(parameters->max_length) - static_cast<int64_t>(parameters->sequence_length),
parameters->batch_size, parameters->num_beams, parameters->vocab_size};
TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
Tensor* output_scores = this->context_.Output(2, scores_shape);
// Update the flag to indicate whether scores exists in output
this->parameters_->output_scores = (output_scores != nullptr);
std::vector<OrtValue> feeds;
// TODO(tianleiwu): allocate fetches. use ping-pong buffers for past state.
std::vector<OrtValue> fetches;
// Initialize resources
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(this->cpu_allocator_);
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(this->cpu_allocator_);
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(static_cast<size_t>(parameters->batch_size),
static_cast<size_t>(parameters->num_beams),
static_cast<size_t>(parameters->max_length),
parameters->length_penalty,
parameters->early_stopping,
static_cast<size_t>(parameters->num_return_sequences),
parameters->pad_token_id,
parameters->eos_token_id,
hypothesis_score_allocator,
beam_hyps_allocator);
this->beam_scorer_->Initialize(this->cpu_allocator_, parameters->sequence_length);
BeamSearchCpuState cpu_state;
cpu_state.Init(this->cpu_allocator_,
static_cast<size_t>(parameters->BatchBeamSize()),
parameters->max_length,
parameters->sequence_length,
this->IsCuda());
// buffer in GPU for input_ids, position_ids and attention_mask
IAllocatorUniquePtr<char> buffer;
OrtValue expanded_input_ids_in_cpu;
ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
BeamSearchState<T> beam_state;
constexpr bool use_position = true;
beam_state.Init(this->temp_space_allocator_,
parameters->batch_size,
parameters->num_beams,
parameters->vocab_size,
parameters->sequence_length,
parameters->max_length,
parameters->output_scores,
use_position);
init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
parameters->batch_size,
parameters->num_beams,
this->cuda_stream_);
gsl::span<const int32_t> input_ids = expanded_input_ids_in_cpu.Get<Tensor>().DataAsSpan<int32_t>();
cpu_state.SetSequence(input_ids,
static_cast<size_t>(parameters->BatchBeamSize()),
parameters->max_length,
parameters->sequence_length);
#ifdef DEBUG_BEAM_SEARCH
const IConsoleDumper* dumper = this->GetConsoleDumper();
dumper->Print("input_ids", feeds[0]);
dumper->Print("position_ids", feeds[1]);
dumper->Print("attention_mask", feeds[2]);
#endif
// Position ids for all iterations except the first. It uses memory buffer owned by next_positions.
OrtValue position_ids;
int64_t dims[] = {parameters->BatchBeamSize(), 1};
TensorShape shape(&dims[0], 2);
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(),
shape,
beam_state.next_positions.data(),
this->temp_space_allocator_->Info(),
position_ids);
int current_length = parameters->sequence_length;
int iteration_counter = 0;
while (current_length < parameters->max_length) {
iteration_counter++;
#ifdef DEBUG_BEAM_SEARCH
auto cur_len = std::to_string(current_length);
dumper->Print("***CurrentLength", cur_len, true);
#endif
status = utils::ExecuteSubgraph(this->decoder_session_state_,
feeds_fetches_manager,
feeds,
fetches,
{},
ExecutionMode::ORT_SEQUENTIAL,
this->context_.GetTerminateFlag(),
this->context_.Logger());
ORT_RETURN_IF_ERROR(status);
const OrtValue& logits = fetches[0];
gsl::span<int32_t> beam_next_tokens;
gsl::span<int32_t> beam_indices;
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
beam_next_tokens,
beam_indices,
beam_state,
cpu_state,
iteration_counter));
// When all batches are finished, stop earlier to avoid wasting computation.
if (this->beam_scorer_->IsDone()) {
break;
}
// Increase sequence length after a new token is generated.
++current_length;
// Prepare inputs for next round of subgraph call.
if (current_length < parameters->max_length) {
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
position_ids,
beam_next_tokens.as_span<const int32_t>(),
beam_indices.as_span<const int32_t>()));
}
fetches.clear();
}
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
if (this->IsCuda()) {
ORT_RETURN_IF_ERROR(this->device_copy_func_(cpu_state.final_beam_scores,
final_beam_scores,
nullptr,
DeviceCopyDirection::deviceToHost));
final_beam_scores = gsl::make_span<const float>(cpu_state.final_beam_scores.data(),
cpu_state.final_beam_scores.size());
}
this->beam_scorer_->Finalize(&(cpu_state.sequences),
final_beam_scores,
output_sequences,
output_sequences_scores);
// Output per token scores
if (output_scores != nullptr) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
gsl::span<const float> source = gsl::span<const float>(beam_state.scores.data(), beam_state.scores.size());
assert(target.length() == source.length());
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
}
return status;
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,307 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/transformers/beam_search_shared.h" // for DEBUG_BEAM_SEARCH
#include "contrib_ops/cpu/transformers/beam_search_impl_base.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// Beam search implementation for T5 model.
template <typename T>
class BeamSearchT5 : public BeamSearchBase<T> {
public:
BeamSearchT5(OpKernelContextInternal& context,
const SessionState& encoder_session_state,
const SessionState& decoder_session_state,
T5EncoderSubgraph& encoder_subgraph,
T5DecoderSubgraph& decoder_subgraph,
concurrency::ThreadPool* thread_pool,
void* cuda_stream,
IConsoleDumper* cuda_dumper,
BeamSearchParameters& params,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const BeamSearchDeviceHelper::TopkFunc& topk_func,
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const BeamSearchDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T>& update_decoder_feeds_func)
: BeamSearchBase<T>(context, decoder_session_state, thread_pool,
cuda_stream, cuda_dumper, params,
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
encoder_session_state_(encoder_session_state),
encoder_subgraph_(encoder_subgraph),
decoder_subgraph_(decoder_subgraph),
add_to_feeds_func_(add_to_feeds_func),
init_beam_state_func_(init_beam_state_func),
create_encoder_inputs_func_(create_encoder_inputs_func),
update_decoder_feeds_func_(update_decoder_feeds_func) {
}
// Execute beam search in iterations util stopping criteria is reached.
Status Execute(const FeedsFetchesManager& encoder_feeds_fetches_manager,
const FeedsFetchesManager& decoder_feeds_fetches_manager);
private:
const SessionState& encoder_session_state_;
T5EncoderSubgraph& encoder_subgraph_;
T5DecoderSubgraph& decoder_subgraph_;
// Device specific functions
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
BeamSearchDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T> update_decoder_feeds_func_;
};
template <typename T>
Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches_manager,
const FeedsFetchesManager& decoder_feeds_fetches_manager) {
auto status = Status::OK();
const BeamSearchParameters* parameters = this->parameters_;
ORT_ENFORCE(parameters->sequence_length == 1);
// Allocate output tensors.
int64_t sequences_dims[] = {parameters->batch_size, parameters->num_return_sequences, parameters->max_length};
TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
Tensor* output_sequences = this->context_.Output(0, sequences_shape);
int64_t sequences_scores_dims[] = {parameters->batch_size, parameters->num_return_sequences};
constexpr int64_t dims = sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0]);
TensorShape sequences_scores_shape(&sequences_scores_dims[0], dims);
Tensor* output_sequences_scores = this->context_.Output(1, sequences_scores_shape);
int64_t scores_dims[] = {
static_cast<int64_t>(parameters->max_length) - static_cast<int64_t>(parameters->sequence_length),
parameters->batch_size, parameters->num_beams, parameters->vocab_size};
TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
Tensor* output_scores = this->context_.Output(2, scores_shape);
// Update the flag to indicate whether scores exists in output
this->parameters_->output_scores = (output_scores != nullptr);
// ------------------------------------
// Call encoder subgraph.
// ------------------------------------
std::vector<OrtValue> encoder_feeds;
std::vector<OrtValue> encoder_fetches;
const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0);
const Tensor& encoder_input_ids = encoder_input_ids_value->Get<Tensor>();
BeamSearchCpuState cpu_state;
cpu_state.Init(this->cpu_allocator_,
static_cast<size_t>(parameters->BatchBeamSize()),
parameters->max_length,
parameters->sequence_length,
this->IsCuda());
IAllocatorUniquePtr<char> buffer;
OrtValue expanded_decoder_input_ids; // Tensor in CPU, and it will be used to initialize sequence in cpu_state
ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds(
encoder_input_ids,
this->implicit_inputs_,
parameters->num_beams,
parameters->pad_token_id,
parameters->decoder_start_token_id,
encoder_feeds,
this->create_encoder_inputs_func_,
this->add_to_feeds_func_,
buffer,
expanded_decoder_input_ids));
ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(this->encoder_session_state_,
encoder_feeds_fetches_manager,
encoder_feeds,
encoder_fetches,
{},
ExecutionMode::ORT_SEQUENTIAL,
this->context_.GetTerminateFlag(),
this->context_.Logger()));
#ifdef DEBUG_BEAM_SEARCH
const IConsoleDumper* dumper = this->GetConsoleDumper();
for (size_t i = 0; i < encoder_feeds.size(); i++) {
dumper->Print("encoder_feeds", static_cast<int>(i), true);
dumper->Print("", encoder_feeds[i]);
}
for (int i = 0; i <= T5EncoderSubgraph::kFirstPresentOutputIndex; i++) {
dumper->Print("encoder_fetches", i, true);
dumper->Print("", encoder_fetches[i]);
}
#endif
// ------------------------------------
// Initialize resources
// ------------------------------------
// Copy expanded_decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
cpu_state.SetSequence(expanded_decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
static_cast<size_t>(parameters->BatchBeamSize()),
parameters->max_length,
parameters->sequence_length);
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(this->cpu_allocator_);
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(this->cpu_allocator_);
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(static_cast<size_t>(parameters->batch_size),
static_cast<size_t>(parameters->num_beams),
static_cast<size_t>(parameters->max_length),
parameters->length_penalty,
parameters->early_stopping,
static_cast<size_t>(parameters->num_return_sequences),
parameters->pad_token_id,
parameters->eos_token_id,
hypothesis_score_allocator,
beam_hyps_allocator);
this->beam_scorer_->Initialize(this->cpu_allocator_, parameters->sequence_length);
BeamSearchState<T> beam_state;
constexpr bool use_position = false;
beam_state.Init(this->temp_space_allocator_,
parameters->batch_size,
parameters->num_beams,
parameters->vocab_size,
parameters->sequence_length,
parameters->max_length,
parameters->output_scores,
use_position);
init_beam_state_func_(&beam_state,
cpu_state.sequence_lengths,
parameters->batch_size,
parameters->num_beams,
this->cuda_stream_);
// ------------------------------------------------------------------------------
// Generate next token from logits output from encoder, and initialize decoder inputs.
// ------------------------------------------------------------------------------
gsl::span<int32_t> beam_next_tokens;
gsl::span<int32_t> beam_indices;
int iteration_counter = 0;
std::vector<OrtValue> decoder_feeds;
int current_length = parameters->sequence_length;
if (current_length + 1 < parameters->max_length) {
++iteration_counter;
ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
beam_next_tokens,
beam_indices,
beam_state,
cpu_state,
iteration_counter));
++current_length; // Increase sequence length after a new token is generated.
ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(beam_next_tokens.as_span<const int32_t>(),
this->implicit_inputs_,
encoder_feeds,
encoder_fetches,
decoder_feeds,
this->device_copy_int32_func_,
this->cuda_stream_));
}
// TODO(tianleiwu): allocate fetches. use ping-pong buffers for past state.
std::vector<OrtValue> decoder_fetches;
while (current_length < parameters->max_length) {
iteration_counter++;
#ifdef DEBUG_BEAM_SEARCH
auto cur_len = std::to_string(current_length);
dumper->Print("***CurrentLength", cur_len, true);
for (int i = 0; i <= T5DecoderSubgraph::kFirstPastInputIndex; i++) {
dumper->Print("decoder_feeds", i, true);
dumper->Print("", decoder_feeds[i]);
}
#endif
status = utils::ExecuteSubgraph(this->decoder_session_state_,
decoder_feeds_fetches_manager,
decoder_feeds,
decoder_fetches,
{},
ExecutionMode::ORT_SEQUENTIAL,
this->context_.GetTerminateFlag(),
this->context_.Logger());
ORT_RETURN_IF_ERROR(status);
#ifdef DEBUG_BEAM_SEARCH
for (int i = 0; i <= T5DecoderSubgraph::kFirstPresentOutputIndex; i++) {
dumper->Print("decoder_fetches", i, true);
dumper->Print("", decoder_fetches[i]);
}
#endif
const OrtValue& logits = decoder_fetches[0];
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
beam_next_tokens,
beam_indices,
beam_state,
cpu_state,
iteration_counter));
// When all batches are finished, stop earlier to avoid wasting computation.
if (this->beam_scorer_->IsDone()) {
break;
}
// Increase sequence length after a new token is generated.
++current_length;
// Prepare inputs for next round of subgraph call.
if (current_length < parameters->max_length) {
const int num_present_outputs = 2 * parameters->num_layers; // number of outputs with name like present_*
ORT_RETURN_IF_ERROR(this->update_decoder_feeds_func_(
this->temp_space_allocator_,
this->cuda_stream_,
decoder_fetches,
decoder_feeds,
num_present_outputs,
beam_next_tokens.as_span<const int32_t>(),
beam_indices.as_span<const int32_t>(),
parameters->num_beams,
this->GetConsoleDumper()));
}
decoder_fetches.clear();
}
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
if (this->IsCuda()) {
ORT_RETURN_IF_ERROR(this->device_copy_func_(cpu_state.final_beam_scores,
final_beam_scores,
nullptr,
DeviceCopyDirection::deviceToHost));
final_beam_scores = gsl::make_span<const float>(cpu_state.final_beam_scores.data(),
cpu_state.final_beam_scores.size());
}
this->beam_scorer_->Finalize(&(cpu_state.sequences),
final_beam_scores,
output_sequences,
output_sequences_scores);
// Output per token scores
if (output_scores != nullptr) {
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
gsl::span<const float> source = gsl::span<const float>(beam_state.scores.data(), beam_state.scores.size());
assert(target.length() == source.length());
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
}
return status;
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,6 +1,7 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "beam_search_parameters.h"
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
namespace onnxruntime {
namespace contrib {
@ -17,10 +18,11 @@ Status BeamSearchParameters::Validate() const {
}
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", 0));
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IBeamSearchParameters::kModelTypeGpt));
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
decoder_start_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("decoder_start_token_id", -1));
no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
}
@ -30,25 +32,32 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
const auto& dims = input_ids->Shape().GetDims();
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
batch_size = static_cast<int>(dims[0]);
sequence_length = static_cast<int>(dims[1]);
// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
auto* max_length_tensor = context->Input<Tensor>(1);
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
ORT_ENFORCE(max_length > sequence_length,
"max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
ORT_ENFORCE(max_length <= kMaxSequenceLength,
"max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
auto* min_length_tensor = context->Input<Tensor>(2);
min_length = min_length_tensor ? static_cast<int>(*min_length_tensor->Data<int32_t>()) : 0;
auto* num_beams_tensor = context->Input<Tensor>(3);
num_beams = num_beams_tensor ? static_cast<int>(*num_beams_tensor->Data<int32_t>()) : 1;
// TODO: limit num_beams > 1 when we can have another operator for greedy search.
ORT_ENFORCE(num_beams >= 1 && num_beams <= kMaxNumBeams, "num_beams shall be a positive integer no more than ", kMaxNumBeams, ", got ", num_beams);
// TODO(tianleiwu): limit num_beams > 1 when we can have another operator for greedy search.
ORT_ENFORCE(num_beams >= 1 && num_beams <= kMaxNumBeams,
"num_beams shall be a positive integer no more than ", kMaxNumBeams, ", got ", num_beams);
auto* num_return_sequences_tensor = context->Input<Tensor>(4);
num_return_sequences = num_return_sequences_tensor ? static_cast<int>(*num_return_sequences_tensor->Data<int32_t>()) : 1;
ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences);
ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
num_return_sequences = num_return_sequences_tensor ? *num_return_sequences_tensor->Data<int32_t>() : 1;
ORT_ENFORCE(num_return_sequences >= 1,
"num_return_sequences shall be a positive integer, got ", num_return_sequences);
ORT_ENFORCE(num_beams >= num_return_sequences,
"num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
auto* temperature_tensor = context->Input<Tensor>(5);
temperature = temperature_tensor ? static_cast<float>(*temperature_tensor->Data<float>()) : 1;

View file

@ -4,7 +4,7 @@
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "beam_search_shared.h"
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
namespace onnxruntime {
namespace contrib {

View file

@ -10,7 +10,7 @@
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/cpu/rnn/rnn_helpers.h"
#include "beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
namespace onnxruntime {
namespace contrib {
@ -69,7 +69,7 @@ void BeamHypotheses::Output(
}
// Since pop get the worst sequence, so output it in the reverse order.
// The frist (worst) beam shall be put at the last position among top_k sequences.
// The first (worst) beam shall be put at the last position among top_k sequences.
int index = top_k - 1;
while (!beams_.empty()) {
auto item = beams_.top();
@ -124,7 +124,7 @@ void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length)
ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once.
size_t batch_beam_size = batch_size_ * num_beams_;
constexpr bool no_fill = false; // do not fill values after allocation
constexpr bool no_fill = false; // Do not fill values after allocation
done_ = Allocate<bool>(allocator, batch_size_, done_ptr_, no_fill);
std::fill_n(done_.data(), done_.size(), false);
@ -134,8 +134,8 @@ void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length)
next_beam_indices_ = Allocate<int32_t>(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill);
// Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
size_t buffer_per_beam = (SafeInt<size_t>(max_length_) * (max_length_ + 1) - SafeInt<size_t>(sequence_length - 1) * sequence_length) / 2;
hypothesis_buffer_length_ = batch_beam_size * buffer_per_beam;
size_t per_beam = (SafeInt<size_t>(max_length_) * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2;
hypothesis_buffer_length_ = batch_beam_size * per_beam;
hypothesis_buffer_ = Allocate<int32_t>(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill);
}
@ -155,7 +155,8 @@ void BeamSearchScorer::Process(ISequences* sequences,
for (size_t batch = 0; batch < batch_size_; batch++) {
BeamHypotheses& beam_hyp = beam_hyps_[batch];
if (done_[batch]) {
ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast<int>(num_beams_), "Batch can only be done if all beams have been generated");
ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast<int>(num_beams_),
"Batch can only be done if all beams have been generated");
// Pad the batch.
for (size_t j = 0; j < num_beams_; j++) {
@ -203,7 +204,7 @@ void BeamSearchScorer::Process(ISequences* sequences,
}
ORT_ENFORCE(beam_idx == num_beams_);
ORT_ENFORCE(hypothesis_buffer_offset_ <= batch_size_ * num_beams_ * max_length_);
ORT_ENFORCE(hypothesis_buffer_offset_ <= hypothesis_buffer_length_);
// Check if we are done so that we can save a pad step if all(done)
if (!done_[batch]) {
@ -258,7 +259,8 @@ void BeamSearchScorer::Finalize(ISequences* sequences,
BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
const size_t num_return_sequences = num_beam_hyps_to_keep_;
auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_);
auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_,
num_return_sequences * max_length_);
if (output_sequence_scores != nullptr) {
batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences);
@ -274,4 +276,4 @@ void BeamSearchScorer::Finalize(ISequences* sequences,
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -12,8 +12,8 @@
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "core/providers/cpu/containers.h"
#include "sequences.h"
#include "beam_search_shared.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
namespace onnxruntime {
namespace contrib {
@ -52,7 +52,7 @@ class BeamHypotheses {
// Output results. Note that it will clear all beams.
void Output(int top_k, // number of sequences to return
int max_length, // max sequence length
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length)
gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
gsl::span<float>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
private:
@ -60,7 +60,9 @@ class BeamHypotheses {
float length_penalty_;
bool early_stopping_;
float worst_score_;
std::priority_queue<HypothesisScore, onnxruntime::FastAllocVector<HypothesisScore>, HypothesisScoreCompare> beams_; // min-heap for top k
// Min-heap for top k
std::priority_queue<HypothesisScore, onnxruntime::FastAllocVector<HypothesisScore>, HypothesisScoreCompare> beams_;
};
class BeamSearchScorer : public IBeamScorer {
@ -103,7 +105,7 @@ class BeamSearchScorer : public IBeamScorer {
int eos_token_id_;
IAllocatorUniquePtr<bool> done_ptr_; // Allocated buffer for done_
gsl::span<bool> done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size).
gsl::span<bool> done_; // Flags indicates whether each batch is finished or not. Shape is (batch_size).
IAllocatorUniquePtr<float> next_beam_scores_ptr_;
gsl::span<float> next_beam_scores_;
@ -117,11 +119,11 @@ class BeamSearchScorer : public IBeamScorer {
IAllocatorUniquePtr<int32_t> hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses
gsl::span<int32_t> hypothesis_buffer_; // Span of the allocated buffer
size_t hypothesis_buffer_length_; // Total number of elements
size_t hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer.
size_t hypothesis_buffer_offset_; // Offset of available buffer, or length of used buffer.
onnxruntime::FastAllocVector<BeamHypotheses> beam_hyps_;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "gsl/gsl"
@ -23,10 +26,10 @@ struct IBeamSearchState {
gsl::span<float> next_token_scores; // shape (batch_size, num_beams * vocab_size)
gsl::span<int32_t> next_tokens; // shape (batch_size, 2 * num_beams)
gsl::span<int32_t> next_indices; // shape (batch_size, 2 * num_beams)
gsl::span<int32_t> next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
gsl::span<int32_t> next_positions; // shape (batch_size, num_beams), empty for T5. Next position for position_ids.
gsl::span<float> beam_scores; // shape (batch_size, num_beams)
gsl::span<float> scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
gsl::span<float> remaining_scores; // portion of scores that is avaiable for appending next token scores.
gsl::span<float> remaining_scores; // portion of scores that is available for appending next token scores.
};
struct IBeamSearchCpuState {
@ -72,10 +75,14 @@ class IBeamScorer {
};
struct IBeamSearchParameters {
static constexpr int kModelTypeGpt = 0;
static constexpr int kModelTypeT5 = 1;
// Parameters from node attributes
int model_type;
int model_type; // 0 for GPT-2; 1 for encoder-decoder like T5
int eos_token_id;
int pad_token_id;
int decoder_start_token_id;
int no_repeat_ngram_size;
bool early_stopping;
@ -88,7 +95,7 @@ struct IBeamSearchParameters {
float length_penalty;
float repetition_penalty;
int batch_size; // deduce from first dimension of input_ids
int sequence_length; // deduce from second dimension of input_ids
int sequence_length; // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5
gsl::span<const int32_t> vocab_mask;
gsl::span<const int32_t> prefix_vocab_mask;
@ -128,4 +135,4 @@ class IConsoleDumper {
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -70,12 +70,17 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
void DumpCpuTensor(const char* name, const Tensor& tensor) {
const auto& shape = tensor.Shape();
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
}
std::cout << "Shape:" << shape << std::endl;
size_t num_dims = shape.NumDimensions();
if (num_dims >= 3) {
int dim0 = static_cast<int>(shape.SizeToDimension(num_dims - 2));
int dim1 = static_cast<int>(shape[num_dims - 2]);
int dim2 = static_cast<int>(shape[num_dims - 1]);
DumpCpuTensor(name, tensor, dim0, dim1, dim2);
DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2);
return;
}
@ -85,7 +90,7 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) {
num_rows = static_cast<size_t>(shape[0]);
}
size_t row_size = num_items / num_rows;
DumpCpuTensor(name, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
DumpCpuTensor(nullptr, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
}
void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const {
@ -209,4 +214,4 @@ void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -4,7 +4,7 @@
#pragma once
#include <string>
#include "core/framework/tensorprotoutils.h"
#include "beam_search_shared.h"
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
namespace onnxruntime {
namespace contrib {
@ -30,4 +30,4 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,248 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/framework_common.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "gsl/gsl"
#include "gpt_subgraph.h"
#include "dump_tensor.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace transformers {
GptSubgraph::GptSubgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in)
: node(node_in), attribute(attribute_name), subgraph(subgraph_in), allocator_(nullptr), is_output_float16_(false) {
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
auto& subgraph_inputs = subgraph.GetInputs();
auto& subgraph_outputs = subgraph.GetOutputs();
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
// outputs: logits, present_0, present_1, ...
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
// CheckSubgraph will verify inputs and outputs later.
subgraph_input_names.reserve(num_subgraph_inputs);
for (int i = 0; i < num_subgraph_inputs; ++i) {
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
}
subgraph_output_names.reserve(num_subgraph_outputs);
for (int i = 0; i < num_subgraph_outputs; ++i) {
subgraph_output_names.push_back(subgraph_outputs[i]->Name());
}
}
Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_outputs <= 1,
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs).");
ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2,
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2");
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ",
subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ",
subgraph_inputs[1]->Name());
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ",
subgraph_inputs[2]->Name());
ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ",
subgraph_inputs[3]->Name());
// Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads.
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape();
ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ",
past_shape->dim_size());
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
"subgraph past state dimension 0 shall have length of 2");
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for number of heads");
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
// check subgraph outputs
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ",
subgraph_outputs[0]->Name());
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ",
subgraph_outputs[1]->Name());
// Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size.
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ",
logits_shape->dim_size());
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
// Save parameters related to the subgraph.
num_heads = static_cast<int>(past_shape->dim(2).dim_value());
head_size = static_cast<int>(past_shape->dim(4).dim_value());
vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
num_layers = static_cast<int>(subgraph_outputs.size()) - 1;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
"subgraph input 0 (input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
"subgraph input 1 (position_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
"subgraph input 2 (attention_mask) shall have int32 type");
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(output_type != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && output_type != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16,
"subgraph output 0 (logits) shall be float or float16 data type");
ORT_RETURN_IF(subgraph_inputs[3]->TypeAsProto()->tensor_type().elem_type() != output_type,
"subgraph input 3 (past_0) shall shall have same data type of logits output");
ORT_RETURN_IF(subgraph_outputs[1]->TypeAsProto()->tensor_type().elem_type() != output_type,
"subgraph output 1 (present_0) shall shall have same data type of logits output");
is_output_float16_ = (output_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16);
return Status::OK();
}
Status GptSubgraph::Setup(const SessionState& session_state,
const SessionState& subgraph_session_state) {
session_state_ = &session_state;
subgraph_session_state_ = &subgraph_session_state;
std::vector<std::string> feed_names;
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
// Currently, input_ids is in CPU even for CUDA operator, so we have to use logits location as default.
const OrtMemoryInfo& default_location = utils::FindMemoryInfoForValue(subgraph_session_state, "logits");
// position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
// as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
for (auto& entry : node.ImplicitInputDefs()) {
feed_names.push_back(entry->Name());
}
std::vector<OrtDevice> feed_locations;
feed_locations.resize(feed_names.size());
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
if (i >= subgraph_input_names.size()) { // implicit inputs
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
feed_locations[i] = location.device;
} else {
feed_locations[i] = default_location.device;
}
}
std::unique_ptr<FeedsFetchesManager> ffm;
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
// setup the locations where we want the subgraph output to end up on
std::vector<const OrtMemoryInfo*> fetch_locations;
fetch_locations.reserve(num_subgraph_outputs);
// past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location.
for (int i = 0; i < num_subgraph_outputs; ++i) {
fetch_locations.push_back(&default_location);
}
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
feeds_fetches_manager_ = std::move(ffm);
// Check subgraph only need once so put in Setup function.
auto& inputs = subgraph.GetInputs();
auto& outputs = subgraph.GetOutputs();
ORT_RETURN_IF_ERROR(Validate(inputs, outputs));
return Status::OK();
}
const IExecutionProvider* GptSubgraph::GetProvider() const {
const ExecutionProviders& providers = session_state_->GetExecutionProviders();
const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
return provider;
}
Status GptSubgraph::CreateInitialFeeds(
const Tensor& input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
const IExecutionProvider* provider = GetProvider();
const TensorShape& input_ids_shape = input_ids.Shape();
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
const int64_t& batch_size = input_ids_shape[0];
// Subgraph inputs:
// input_ids: shape (B, S) wher B is batch size, and S is sequence length
// position_ids: shape (B, S)
// attention_mask: shape (B, P+S), where past_sequence_length (P) is 0
// After expansion, their shapes will become (B, M*S), where M is num_beams.
// Allocate subgraph inputs to be same device as input_ids
AllocatorPtr cpu_alloactor = session_state_->GetAllocator(input_ids.Location());
// Store allocator, which will be used in remaining feeds
auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault);
allocator_ = default_allocator;
// Initialize empty past state
auto past_type = IsOutputFloat16() ? DataTypeImpl::GetType<MLFloat16>() : DataTypeImpl::GetType<float>();
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size};
TensorShape past_shape(&past_state_dims[0], 5);
OrtValue empty_past;
Tensor::InitOrtValue(past_type, past_shape, default_allocator, empty_past);
// The ordering is the same as used in Setup
feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
OrtValue expanded_position_ids;
OrtValue expanded_attention_mask;
ORT_RETURN_IF_ERROR(create_inputs_func(&input_ids, num_beams, pad_token_id, sequence_lengths, cpu_alloactor, expanded_input_ids, expanded_position_ids, expanded_attention_mask));
ORT_RETURN_IF_ERROR(add_to_feeds_func(provider, expanded_input_ids, expanded_position_ids, expanded_attention_mask, feeds, buffer));
// The remaing inputs are past state.
for (int i = 3; i < num_subgraph_inputs; ++i) {
feeds.push_back(empty_past);
}
// pass in implicit inputs
for (const auto* entry : implicit_inputs) {
feeds.push_back(*entry);
}
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,7 +1,11 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <memory>
#include <assert.h>
#include "logits_processor.h"
#include "dump_tensor.h"
#include "core/common/safeint.h"
#include "contrib_ops/cpu/transformers/logits_processor.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
namespace onnxruntime {
namespace contrib {
@ -102,7 +106,8 @@ void NoRepeatNGramLogitsProcessor<T>::Process(const ISequences* sequences,
std::unordered_set<int32_t> blocked_word_ids;
for (int j = 0; j <= static_cast<int>(sequence.length()) - ngram_size_; j++) {
// Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length)
// TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching.
// TODO(tianleiwu): build N-Gram index (hash table with prefix of length NGram - 1 as key,
// and list of last word of NGram as value) for fast matching.
if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) {
blocked_word_ids.insert(sequence[static_cast<gsl::index>(j) + prefix_length]);
}
@ -119,7 +124,8 @@ void NoRepeatNGramLogitsProcessor<T>::Process(const ISequences* sequences,
}
template <typename T>
VocabMaskLogitsProcessor<T>::VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask) : vocab_mask_(vocab_mask) {
VocabMaskLogitsProcessor<T>::VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask)
: vocab_mask_(vocab_mask) {
}
template <typename T>
@ -145,8 +151,10 @@ void VocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
}
template <typename T>
PrefixVocabMaskLogitsProcessor<T>::PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask, int batch_size)
: prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) {
PrefixVocabMaskLogitsProcessor<T>::PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask,
int batch_size)
: prefix_vocab_mask_(prefix_vocab_mask),
batch_size_(batch_size) {
}
template <typename T>
@ -159,7 +167,7 @@ void PrefixVocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
assert(num_beams * batch_size_ == next_token_scores.batch_beam_size);
// Process prefix vocabulary mask and set tokens with mask value 0 to -inf.
// prefix_vocab_mask shape (batch_szie, vocab_size).
// prefix_vocab_mask shape (batch_size, vocab_size).
T* p = next_token_scores.scores.data();
for (int i = 0; i < batch_size_; i++) {
size_t prefix_vocab_mask_offset = SafeInt<size_t>(i) * next_token_scores.vocab_size;
@ -181,7 +189,8 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
processor_list_.clear();
if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty
repetition_penalty_processor_ = std::make_unique<RepetitionPenaltyLogitsProcessor<float>>(parameters.repetition_penalty);
repetition_penalty_processor_ = std::make_unique<RepetitionPenaltyLogitsProcessor<float>>(
parameters.repetition_penalty);
processor_list_.push_back(repetition_penalty_processor_.get());
}
@ -196,12 +205,14 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
}
if (!parameters.prefix_vocab_mask.empty()) {
prefix_vocab_mask_processor_ = std::make_unique<PrefixVocabMaskLogitsProcessor<float>>(parameters.prefix_vocab_mask, parameters.batch_size);
prefix_vocab_mask_processor_ = std::make_unique<PrefixVocabMaskLogitsProcessor<float>>(parameters.prefix_vocab_mask,
parameters.batch_size);
processor_list_.push_back(prefix_vocab_mask_processor_.get());
}
if (parameters.min_length > 0) {
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<float>>(parameters.min_length, parameters.eos_token_id);
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<float>>(parameters.min_length,
parameters.eos_token_id);
processor_list_.push_back(min_length_processor_.get());
}

View file

@ -1,9 +1,12 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "sequences.h"
#include "beam_search_parameters.h"
#include "core/common/inlined_containers.h"
#include "beam_search_shared.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
namespace onnxruntime {
namespace contrib {

View file

@ -1,5 +1,8 @@
#include "sequences.h"
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/common/safeint.h"
#include "contrib_ops/cpu/transformers/sequences.h"
namespace onnxruntime {
namespace contrib {
@ -20,8 +23,10 @@ void Sequences::Init(gsl::span<int32_t> buffer, int batch_beam_size, int sequenc
}
gsl::span<const int32_t> Sequences::GetSequence(int beam_index) const {
gsl::span<const int32_t> buffer(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
gsl::span<const int32_t> sequence = buffer.subspan(SafeInt<size_t>(beam_index) * max_length_, static_cast<gsl::index>(current_length_));
gsl::span<const int32_t> buffer(sequences[current_sequences_buffer].data(),
sequences[current_sequences_buffer].size());
gsl::span<const int32_t> sequence = buffer.subspan(SafeInt<size_t>(beam_index) * max_length_,
static_cast<gsl::index>(current_length_));
return sequence;
}
@ -42,13 +47,16 @@ void Sequences::PrintSequences(const IConsoleDumper* dumper) const {
void Sequences::AppendNextTokenToSequences(
gsl::span<int32_t>& beam_indices,
gsl::span<int32_t>& beam_next_tokens) {
gsl::span<const int32_t> input(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
gsl::span<const int32_t> input(sequences[current_sequences_buffer].data(),
sequences[current_sequences_buffer].size());
gsl::span<int32_t> output = sequences[1 - current_sequences_buffer];
for (int i = 0; i < batch_beam_size_; i++) {
int beam_index = beam_indices[i];
gsl::span<const int32_t> source = input.subspan(SafeInt<size_t>(beam_index) * max_length_, static_cast<gsl::index>(current_length_));
gsl::span<int32_t> target = output.subspan(SafeInt<size_t>(i) * max_length_, static_cast<gsl::index>(current_length_));
gsl::span<const int32_t> source = input.subspan(SafeInt<size_t>(beam_index) * max_length_,
static_cast<gsl::index>(current_length_));
gsl::span<int32_t> target = output.subspan(SafeInt<size_t>(i) * max_length_,
static_cast<gsl::index>(current_length_));
gsl::copy(source, target);
}
@ -65,4 +73,4 @@ void Sequences::AppendNextTokenToSequences(
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,7 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "gsl/gsl"
#include "beam_search_shared.h"
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
namespace onnxruntime {
namespace contrib {
@ -47,4 +50,4 @@ class Sequences : public ISequences {
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -0,0 +1,163 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <utility>
#include "core/framework/framework_common.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "gsl/gsl"
#include "contrib_ops/cpu/transformers/subgraph_base.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
Subgraph::Subgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in)
: node(node_in),
attribute(attribute_name),
subgraph(subgraph_in),
num_heads(0),
head_size(0),
vocab_size(0),
num_layers(0),
allocator_(nullptr),
is_output_float16_(false) {
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
auto& subgraph_inputs = subgraph.GetInputs();
auto& subgraph_outputs = subgraph.GetOutputs();
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
// outputs: logits, present_0, present_1, ...
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
// CheckSubgraph will verify inputs and outputs later.
subgraph_input_names.reserve(num_subgraph_inputs);
for (int i = 0; i < num_subgraph_inputs; ++i) {
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
}
subgraph_output_names.reserve(num_subgraph_outputs);
for (int i = 0; i < num_subgraph_outputs; ++i) {
subgraph_output_names.push_back(subgraph_outputs[i]->Name());
}
}
Status Subgraph::Setup(const SessionState& session_state,
const SessionState& subgraph_session_state) {
session_state_ = &session_state;
subgraph_session_state_ = &subgraph_session_state;
std::vector<std::string> feed_names;
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
// Use the first output (logits) to find device location.
const OrtMemoryInfo& default_location = utils::FindMemoryInfoForValue(subgraph_session_state,
subgraph_output_names[0]);
// The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
for (auto& entry : node.ImplicitInputDefs()) {
feed_names.push_back(entry->Name());
}
std::vector<OrtDevice> feed_locations;
feed_locations.resize(feed_names.size());
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
if (i >= subgraph_input_names.size()) { // Implicit inputs
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
feed_locations[i] = location.device;
} else {
feed_locations[i] = default_location.device;
}
}
std::unique_ptr<FeedsFetchesManager> ffm;
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
// Setup the locations where we want the subgraph output to end up on
std::vector<const OrtMemoryInfo*> fetch_locations;
fetch_locations.reserve(num_subgraph_outputs);
// Past state need to be where we can feed them in to the next iteration, so set the location to match the feed.
for (int i = 0; i < num_subgraph_outputs; ++i) {
fetch_locations.push_back(&default_location);
}
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
feeds_fetches_manager_ = std::move(ffm);
// Check subgraph only need once so put in Setup function.
auto& inputs = subgraph.GetInputs();
auto& outputs = subgraph.GetOutputs();
ORT_RETURN_IF_ERROR(Validate(inputs, outputs));
return Status::OK();
}
const IExecutionProvider* Subgraph::GetProvider() const {
const ExecutionProviders& providers = session_state_->GetExecutionProviders();
const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
return provider;
}
Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape,
const ONNX_NAMESPACE::TensorShapeProto* logits_shape,
bool merged_past) {
if (merged_past) {
// Merged past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
ORT_RETURN_IF(past_shape->dim_size() != 5,
"subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size());
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
"subgraph past state dimension 0 shall have length of 2");
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for number of heads");
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
this->num_heads = static_cast<int>(past_shape->dim(2).dim_value());
this->head_size = static_cast<int>(past_shape->dim(4).dim_value());
} else {
// Past state shape is like (batch_size, num_heads, past_seq_len, hidden_size/num_heads).
ORT_RETURN_IF(past_shape->dim_size() != 4,
"subgraph output present_key_self_0 is expected to have 4 dimension, got ", past_shape->dim_size());
ORT_RETURN_IF(!past_shape->dim(1).has_dim_value() || past_shape->dim(1).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for number of heads");
ORT_RETURN_IF(!past_shape->dim(3).has_dim_value() || past_shape->dim(3).dim_value() <= 0,
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
this->num_heads = static_cast<int>(past_shape->dim(1).dim_value());
this->head_size = static_cast<int>(past_shape->dim(3).dim_value());
}
// Logits shape is like (batch_size, seq_len, vocabulary_size)
ORT_RETURN_IF(logits_shape->dim_size() != 3,
"subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -2,6 +2,9 @@
// Licensed under the MIT License.
#pragma once
#include <vector>
#include <string>
#include "gsl/gsl"
#include "core/framework/allocator.h"
#include "core/framework/feeds_fetches_manager.h"
@ -15,21 +18,22 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {
// A class for GPT-2 subgraph inputs and outputs preparation.
struct GptSubgraph {
GptSubgraph(
class Subgraph {
public:
Subgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in);
virtual ~Subgraph() {}
const onnxruntime::Node& node; // node that contains the subgraph
const std::string& attribute; // attribute of th node that contains the subgraph. Not used yet.
const GraphViewer& subgraph; // the subgraph
const onnxruntime::Node& node; // Node that contains the subgraph
const std::string& attribute; // Attribute of th node that contains the subgraph. Not used yet.
const GraphViewer& subgraph; // The subgraph
int num_implicit_inputs;
int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience.
int num_subgraph_outputs; // same as subgraph_output_names.size()
int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience.
int num_subgraph_outputs; // Same as subgraph_output_names.size()
std::vector<std::string> subgraph_input_names;
std::vector<std::string> subgraph_output_names;
@ -40,32 +44,23 @@ struct GptSubgraph {
int vocab_size;
int num_layers;
// Setup exectuion
// Setup execution
Status Setup(const SessionState& session_state,
const SessionState& subgraph_session_state);
// Create inputs for first inference of subgraph.
Status CreateInitialFeeds(
const Tensor& input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer);
FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); }
const IExecutionProvider* GetProvider() const;
bool IsOutputFloat16() const { return is_output_float16_; }
virtual Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) = 0;
protected:
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs);
Status GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape,
const ONNX_NAMESPACE::TensorShapeProto* logits_shape,
bool merged_past);
AllocatorPtr allocator_;
const SessionState* session_state_;

View file

@ -0,0 +1,167 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/framework_common.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "gsl/gsl"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
Status GptSubgraph::CreateInitialFeeds(
const Tensor& input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
const BeamSearchDeviceHelper::CreateGptInputsFunc& create_gpt_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
const IExecutionProvider* provider = GetProvider();
const TensorShape& input_ids_shape = input_ids.Shape();
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
const int64_t& batch_size = input_ids_shape[0];
// Subgraph inputs:
// input_ids: shape (B, S) where B is batch size, and S is sequence length
// position_ids: shape (B, S)
// attention_mask: shape (B, P+S), where past_sequence_length (P) is 0
// After expansion, their shapes will become (B, M*S), where M is num_beams.
// Allocate subgraph inputs to be same device as input_ids
AllocatorPtr cpu_allocator = session_state_->GetAllocator(input_ids.Location());
// Store allocator, which will be used in remaining feeds
auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault);
allocator_ = default_allocator;
// Initialize empty past state
auto past_type = IsOutputFloat16() ? DataTypeImpl::GetType<MLFloat16>() : DataTypeImpl::GetType<float>();
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size};
TensorShape past_shape(&past_state_dims[0], 5);
OrtValue empty_past;
Tensor::InitOrtValue(past_type, past_shape, default_allocator, empty_past);
// The ordering is the same as used in Setup
feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
OrtValue expanded_position_ids;
OrtValue expanded_attention_mask;
ORT_RETURN_IF_ERROR(create_gpt_inputs_func(&input_ids,
num_beams,
pad_token_id,
sequence_lengths,
cpu_allocator,
expanded_input_ids,
expanded_position_ids,
expanded_attention_mask));
ORT_RETURN_IF_ERROR(add_to_feeds_func(provider,
{expanded_input_ids, expanded_position_ids, expanded_attention_mask},
feeds,
buffer));
// The remaining inputs are past state.
for (int i = kFirstPastInputIndex; i < num_subgraph_inputs; ++i) {
feeds.push_back(empty_past);
}
// Pass in implicit inputs
for (const auto* entry : implicit_inputs) {
feeds.push_back(*entry);
}
return Status::OK();
}
Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_outputs <= kFirstPresentOutputIndex,
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs).");
ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2,
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2");
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids",
"subgraph input 1 shall be named as position_ids, got: ", subgraph_inputs[1]->Name());
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask",
"subgraph input 2 shall be named as attention_mask, got: ", subgraph_inputs[2]->Name());
ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0",
"subgraph input 3 shall be named as past_0, got: ", subgraph_inputs[3]->Name());
// Past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads).
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape();
ORT_RETURN_IF(past_shape->dim_size() != 5,
"subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size());
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
"subgraph past state dimension 0 shall have length of 2");
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for number of heads");
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
// check subgraph outputs
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
"subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0",
"subgraph input 1 shall be named as present_0, got: ", subgraph_outputs[1]->Name());
// Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size.
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
ORT_RETURN_IF(logits_shape->dim_size() != 3,
"subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
// Save parameters related to the subgraph.
num_heads = static_cast<int>(past_shape->dim(2).dim_value());
head_size = static_cast<int>(past_shape->dim(4).dim_value());
vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
num_layers = static_cast<int>(subgraph_outputs.size()) - 1;
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"subgraph input 0 (input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"subgraph input 1 (position_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"subgraph input 2 (attention_mask) shall have int32 type");
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
"subgraph output 0 (logits) shall be float or float16 data type");
ORT_RETURN_IF(subgraph_inputs[kFirstPastInputIndex]->TypeAsProto()->tensor_type().elem_type() != output_type,
"subgraph input 3 (past_0) shall shall have same data type of logits output");
ORT_RETURN_IF(subgraph_outputs[kFirstPresentOutputIndex]->TypeAsProto()->tensor_type().elem_type() != output_type,
"subgraph output 1 (present_0) shall shall have same data type of logits output");
is_output_float16_ = (output_type == float16_type);
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,42 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/transformers/subgraph_base.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// A class for GPT-2 subgraph inputs and outputs preparation.
class GptSubgraph : public Subgraph {
public:
GptSubgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {}
// Create inputs for first inference of subgraph.
Status CreateInitialFeeds(
const Tensor& input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
const BeamSearchDeviceHelper::CreateGptInputsFunc& create_gpt_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer);
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;
constexpr static int kFirstPastInputIndex = 3;
constexpr static int kFirstPresentOutputIndex = 1;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,154 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/framework_common.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "gsl/gsl"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
/* T5 Decoder Subgraph.
Inputs:
input_ids: int32 (B, 1)
encoder_attention_mask: int32 (B, encode_sequence_length)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
... (for each self attention layer)
past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
... (for each cross attention layer)
Outputs:
logits: (B, 1, vocab_size)
present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
... (for each self attention layer)
Note:
B = batch_size * num_beams
Data type of input or output is float or float16 if not specified.
*/
Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_inputs < 7 || (num_subgraph_inputs - kFirstPastInputIndex) % 4 != 0,
"number of outputs expected to be 3 + 4 * layers, got:", num_subgraph_inputs);
ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - kFirstPresentOutputIndex) % 2 != 0,
"number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs);
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
"decoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
"decoder subgraph input 2 shall be named as encoder_hidden_states, got: ", subgraph_inputs[2]->Name());
// check subgraph outputs
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
"decoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[kFirstPresentOutputIndex]->Shape();
// Save parameters related to the subgraph.
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
num_layers = (static_cast<int>(subgraph_outputs.size()) - kFirstPresentOutputIndex) / 2;
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 0 (input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
for (int i = kFirstPastInputIndex; i < num_subgraph_inputs; i++) {
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
"decoder subgraph past inputs shall have same data type as that of encoder_hidden_states");
}
for (int i = 0; i < num_subgraph_outputs; i++) {
ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
"decoder subgraph output shall have same data type as that of encoder_hidden_states");
}
is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type);
return Status::OK();
}
// Create inputs for decoder from the following data sources:
// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
// encoder fetches: logits,
// encoder_hidden_states,
// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
// decoder_feeds: input_ids,
// encoder_attention_mask,
// encoder_hidden_states,
// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
Status T5DecoderSubgraph::CreateInitialFeeds(
gsl::span<const int32_t> beam_next_tokens,
const std::vector<const OrtValue*>& implicit_inputs,
const std::vector<OrtValue>& encoder_feeds,
const std::vector<OrtValue>& encoder_fetches,
std::vector<OrtValue>& decoder_feeds,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
void* stream) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
// Allocate subgraph inputs from same device as inputs of encoder subgraph.
AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get<Tensor>().Location());
// Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP).
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
int64_t dims[] = {batch_beam_size, 1};
TensorShape input_ids_shape(&dims[0], 2);
OrtValue input_ids;
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
ORT_RETURN_IF_ERROR(device_copy_int32_func(
input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>(),
beam_next_tokens,
stream,
DeviceCopyDirection::hostToDevice));
// The ordering is the same as used in Setup.
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
decoder_feeds.push_back(input_ids);
// The encoder_attention_mask is copied from the second input of encoder.
decoder_feeds.push_back(encoder_feeds[1]);
// The encoder_hidden_states and past states are copied from the second output of encoder.
for (size_t j = 1; j < encoder_fetches.size(); j++) {
decoder_feeds.push_back(encoder_fetches[j]);
}
// Pass through implicit inputs.
for (const auto* entry : implicit_inputs) {
decoder_feeds.push_back(*entry);
}
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,39 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/transformers/subgraph_base.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// A class for T5 decoder subgraph inputs and outputs preparation.
class T5DecoderSubgraph : public Subgraph {
public:
T5DecoderSubgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {}
// Create inputs for first inference of decoder subgraph.
Status CreateInitialFeeds(
gsl::span<const int32_t> beam_next_tokens,
const std::vector<const OrtValue*>& implicit_inputs,
const std::vector<OrtValue>& encoder_feeds,
const std::vector<OrtValue>& encoder_fetches,
std::vector<OrtValue>& decoder_feeds,
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
void* stream);
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;
constexpr static int kFirstPastInputIndex = 3;
constexpr static int kFirstPresentOutputIndex = 1;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,145 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/framework/framework_common.h"
#include "core/framework/session_state.h"
#include "core/framework/tensorprotoutils.h"
#include "core/framework/utils.h"
#include "core/providers/cpu/tensor/utils.h"
#include "gsl/gsl"
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
/* T5 Encoder Subgraph (It also contains decoder initialization where decoder_input_ids are filled with start token ID).
Inputs:
encoder_input_ids: int32 (B, encode_sequence_length)
encoder_attention_mask: int32 (B, encode_sequence_length)
decoder_input_ids: int32 (B, 1)
Outputs:
logits: (B, 1, vocab_size)
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
present_key_self_0: (B, num_heads, 1, head_size)
present_value_self_0: (B, num_heads, 1, head_size)
... (for each self attention layer)
present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
... (for each cross attention layer)
Note:
Here, B = batch_size * num_beams since we expand the inputs.
Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams.
Data type of input or output is float or float16 if not specified.
*/
Status T5EncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) {
ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);
ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - kFirstPresentOutputIndex) % 4 != 0,
"number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs);
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids",
"encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name());
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
"encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids",
"encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name());
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
"encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states",
"encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name());
ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0",
"encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name());
ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0",
"encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name());
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape();
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
// Save parameters related to the subgraph.
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
num_layers = (static_cast<int>(subgraph_outputs.size()) - kFirstPresentOutputIndex) / 4;
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 0 (encoder_input_ids) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
"encoder subgraph input 2 (decoder_input_ids) shall have int32 type");
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
"encoder subgraph output 0 (logits) shall be float or float16 data type");
for (int i = 1; i < num_subgraph_outputs; i++) {
ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type,
"encoder subgraph outputs 1, 2, ... shall have same data type");
}
is_output_float16_ = (output_type == float16_type);
return Status::OK();
}
// Create inputs for first inference of subgraph.
Status T5EncoderSubgraph::CreateInitialFeeds(
const Tensor& encoder_input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
int start_token_id,
std::vector<OrtValue>& feeds,
const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer,
OrtValue& expanded_decoder_input_ids) {
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
// The ordering is the same as used in Setup.
feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
// Allocate subgraph inputs to be same device as encoder_input_ids.
AllocatorPtr cpu_allocator = session_state_->GetAllocator(encoder_input_ids.Location());
// TODO(tianleiwu): expand the outputs instead of inputs to save computation.
OrtValue expanded_encoder_input_ids;
OrtValue expanded_encoder_attention_mask;
ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&encoder_input_ids,
num_beams,
pad_token_id,
start_token_id,
cpu_allocator,
expanded_encoder_input_ids,
expanded_encoder_attention_mask,
expanded_decoder_input_ids));
const IExecutionProvider* provider = GetProvider();
ORT_RETURN_IF_ERROR(add_to_feeds_func(
provider,
{expanded_encoder_input_ids, expanded_encoder_attention_mask, expanded_decoder_input_ids},
feeds,
buffer));
for (const auto* entry : implicit_inputs) {
feeds.push_back(*entry);
}
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,41 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/transformers/subgraph_base.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
// A class for T5 encoder subgraph inputs and outputs preparation.
class T5EncoderSubgraph : public Subgraph {
public:
T5EncoderSubgraph(
const onnxruntime::Node& node_in,
const std::string& attribute_name,
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {}
// Create inputs for first inference of subgraph.
Status CreateInitialFeeds(
const Tensor& encoder_input_ids,
const std::vector<const OrtValue*>& implicit_inputs,
int num_beams,
int pad_token_id,
int start_token_id,
std::vector<OrtValue>& feeds,
const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
IAllocatorUniquePtr<char>& buffer,
OrtValue& expanded_decoder_input_ids);
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
const std::vector<const NodeArg*>& subgraph_outputs) override;
constexpr static int kFirstPresentOutputIndex = 2;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -4,8 +4,8 @@
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cuda_execution_provider.h"
#include "contrib_ops/cuda/transformers/beam_search.h"
#include "beam_search_device_helper.h"
#include "dump_cuda_tensor.h"
#include "contrib_ops/cuda/transformers/beam_search_device_helper.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
namespace onnxruntime {
namespace contrib {
@ -38,16 +38,19 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
SetComputeStream(static_cast<void*>(info.GetExecutionProvider()->GetComputeStream()));
SetDeviceHelpers(BeamSearchCudaDeviceHelper::AddToFeeds,
BeamSearchCudaDeviceHelper::TopK);
SetDeviceHelpers(BeamSearchCudaDeviceHelper::ProcessLogits<float>,
BeamSearchCudaDeviceHelper::InitBeamState<float>,
BeamSearchCudaDeviceHelper::TopK,
BeamSearchCudaDeviceHelper::DeviceCopy<float>,
BeamSearchCudaDeviceHelper::UpdateFeeds<float>);
BeamSearchCudaDeviceHelper::DeviceCopy<int32_t>,
BeamSearchCudaDeviceHelper::ProcessLogits<float>,
BeamSearchCudaDeviceHelper::ProcessLogits<MLFloat16>,
BeamSearchCudaDeviceHelper::InitBeamState<float>,
BeamSearchCudaDeviceHelper::InitBeamState<MLFloat16>);
SetDeviceHelpers(BeamSearchCudaDeviceHelper::ProcessLogits<MLFloat16>,
BeamSearchCudaDeviceHelper::InitBeamState<MLFloat16>,
BeamSearchCudaDeviceHelper::UpdateFeeds<MLFloat16>);
SetDeviceHelpers_Gpt(BeamSearchCudaDeviceHelper::UpdateGptFeeds<float>,
BeamSearchCudaDeviceHelper::UpdateGptFeeds<MLFloat16>);
SetDeviceHelpers_EncoderDecoder(BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<float>,
BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<MLFloat16>);
SetConsoleDumper(&g_cuda_dumper);
}
@ -71,4 +74,4 @@ Status BeamSearch::Compute(OpKernelContext* context) const {
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,16 +1,17 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/math/topk_impl.h"
#include "core/providers/cuda/math/softmax.h"
#include "core/providers/cuda/shared_inc/accumulation_type.h"
#include "core/framework/ort_value.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
#include "beam_search_impl.h"
#include <cuda_runtime.h>
#include "dump_cuda_tensor.h"
#ifdef DEBUG_BEAM_SEARCH
using namespace onnxruntime::contrib::cuda::transformers;
#endif
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
namespace onnxruntime {
namespace concurrency {
@ -52,7 +53,7 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
if (input->IsDataType<float>()) {
return TopKImpl<float>(nullptr, // We limit number of beams in BeamSearchParameters, so that K <= 256 and kernel is not needed
return TopKImpl<float>(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here
reinterpret_cast<cudaStream_t>(stream),
input->Data<float>(),
static_cast<float*>(output_values->MutableDataRaw()),
@ -87,35 +88,53 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
}
Status AddToFeeds(const IExecutionProvider* execution_provider,
OrtValue& input_ids,
OrtValue& position_ids,
OrtValue& attention_mask,
std::initializer_list<OrtValue> inputs,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer) {
// Copy tensors to GPU, then add to feeds
const CUDAExecutionProvider* provider = reinterpret_cast<const CUDAExecutionProvider*>(execution_provider);
const TensorShape& shape = input_ids.Get<Tensor>().Shape();
ORT_ENFORCE(shape.NumDimensions() == 2);
const int64_t elements = shape[0] * shape[1];
size_t total_bytes = 0;
for (auto& input : inputs) {
if (input.IsAllocated()) {
total_bytes += input.Get<Tensor>().Shape().Size() * input.Type()->Size();
}
}
ORT_ENFORCE(total_bytes > 0);
AllocatorPtr pinned_allocator = provider->GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU);
cudaStream_t stream = static_cast<cudaStream_t>(provider->GetComputeStream());
size_t bytes = (sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * elements;
auto pinned_buffer = IAllocator::MakeUniquePtr<void>(pinned_allocator, bytes);
auto pinned_buffer = IAllocator::MakeUniquePtr<void>(pinned_allocator, total_bytes);
char* pinned_data = static_cast<char*>(pinned_buffer.get());
// Copy tensors to one pinned memory buffer (so that we only need copy to GPU once)
memcpy(pinned_data, input_ids.Get<Tensor>().Data<int32_t>(), sizeof(int32_t) * elements);
memcpy(pinned_data + sizeof(int32_t) * elements, position_ids.Get<Tensor>().Data<int32_t>(), sizeof(int32_t) * elements);
memcpy(pinned_data + 2 * sizeof(int32_t) * elements, attention_mask.Get<Tensor>().Data<int32_t>(), sizeof(int32_t) * elements);
char* destination = pinned_data;
for (auto& input : inputs) {
if (input.IsAllocated()) {
const Tensor& tensor = input.Get<Tensor>();
const size_t bytes = input.Type()->Size() * tensor.Shape().Size();
MLDataType dataType = tensor.DataType();
if (dataType == DataTypeImpl::GetType<int32_t>()) {
memcpy(destination, input.Get<Tensor>().Data<int32_t>(), bytes);
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
memcpy(destination, input.Get<Tensor>().Data<int64_t>(), bytes);
} else {
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
"AddToFeeds: An implementation for the input type ",
dataType, " is not supported yet");
}
// Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs.
destination += bytes;
}
}
if (!buffer) {
buffer = provider->GetScratchBuffer<char>(bytes);
buffer = provider->GetScratchBuffer<char>(total_bytes);
}
char* gpu_data = buffer.get();
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, bytes, cudaMemcpyHostToDevice, stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream));
// Create an event to make sure the async copy is finished before reading the data.
onnxruntime::contrib::cuda::AutoDestoryCudaEvent new_event;
@ -124,33 +143,32 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream));
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone));
// TODO: allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call.
OrtValue device_input_ids;
OrtValue device_position_ids;
OrtValue device_attention_mask;
// TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call.
const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info();
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, gpu_data, location, device_input_ids);
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, gpu_data + sizeof(int32_t) * elements, location, device_position_ids);
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, gpu_data + 2 * sizeof(int32_t) * elements, location, device_attention_mask);
for (auto& input : inputs) {
if (input.IsAllocated()) {
const Tensor& tensor = input.Get<Tensor>();
const TensorShape& shape = tensor.Shape();
const size_t bytes = input.Type()->Size() * shape.Size();
MLDataType dataType = tensor.DataType();
feeds.push_back(device_input_ids);
feeds.push_back(device_position_ids);
feeds.push_back(device_attention_mask);
OrtValue device_input;
Tensor::InitOrtValue(dataType, shape, gpu_data, location, device_input);
gpu_data += bytes;
feeds.push_back(device_input);
}
}
return Status::OK();
}
template <typename T>
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream) {
// TODO: we can use another stream to avoid blocking subgraph execution.
// TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution.
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream);
cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream);
@ -162,18 +180,9 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
// copy sequence lengths to GPU
// since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here.
// cudaMemsetAsync(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes(), cuda_stream);
cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), cudaMemcpyHostToDevice, cuda_stream);
memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes());
// Copy input_ids to sequences[0]
gsl::span<int32_t> sequences_0 = cpu_state->sequences_space;
int batch_beam_size = batch_size * num_beams;
for (int i = 0; i < batch_beam_size; i++) {
for (int j = 0; j < sequence_length; j++) {
sequences_0[SafeInt<gsl::index>(i) * max_length + j] = static_cast<int32_t>(input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length + j]);
}
if (!beam_state->next_positions.empty()) { // next_positions is empty for T5
cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
cudaMemcpyHostToDevice, cuda_stream);
}
}
@ -220,12 +229,13 @@ Status ProcessLogits(const OrtValue& logits, //
// When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
gsl::span<T>& next_token_logits = beam_state->next_token_logits;
if (input_length > 1) {
// TODO: use one kernel to replace a loop of memory copy.
// TODO(tianleiwu): use one kernel to replace a loop of memory copy.
const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
cudaMemcpyDeviceToDevice, cuda_stream));
current_logits += input_length * vocab_size;
}
}
@ -253,12 +263,14 @@ Status ProcessLogits(const OrtValue& logits, //
// Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel
BufferUniquePtr sequences_buffer;
int current_sequence_length = sequences->GetSequenceLength();
if (parameters->repetition_penalty != 1.0f || (parameters->no_repeat_ngram_size > 0 && current_sequence_length >= parameters->no_repeat_ngram_size)) {
bool run_ngram = parameters->no_repeat_ngram_size > 0 && current_sequence_length >= parameters->no_repeat_ngram_size;
if (parameters->repetition_penalty != 1.0f || run_ngram) {
size_t bytes = SafeInt<size_t>(sizeof(int32_t)) * batch_beam_size * parameters->max_length;
void* data = allocator->Alloc(bytes);
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
sequences_buffer = std::move(temp_buffer);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes,
cudaMemcpyHostToDevice, cuda_stream));
}
cuda::LaunchLogitsProcessKernel<float>(
@ -277,20 +289,25 @@ Status ProcessLogits(const OrtValue& logits, //
cuda_stream);
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size);
dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size);
#endif
// Add beam score to next token scores. Corresponding python code is like:
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
cuda::LaunchAddProbsKernel(next_token_scores.data(), beam_state->beam_scores.data(), batch_size, num_beams, vocab_size, cuda_stream);
cuda::LaunchAddProbsKernel(next_token_scores.data(), beam_state->beam_scores.data(),
batch_size, num_beams, vocab_size, cuda_stream);
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
#endif
if (output_scores) {
// Append next token scores to the scores output.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state->remaining_scores.data(), next_token_scores.data(), next_token_scores.size_bytes(), cudaMemcpyDeviceToDevice, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state->remaining_scores.data(),
next_token_scores.data(),
next_token_scores.size_bytes(),
cudaMemcpyDeviceToDevice,
cuda_stream));
beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
}
@ -301,7 +318,8 @@ Status ProcessLogits(const OrtValue& logits, //
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType<float>();
OrtValue next_token_scores_value;
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(),
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();
constexpr int axis = 1;
@ -311,7 +329,8 @@ Status ProcessLogits(const OrtValue& logits, //
std::unique_ptr<Tensor> topk_scores;
std::unique_ptr<Tensor> topk_indices;
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices));
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
topk_scores, topk_indices));
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("topk_scores", *(topk_scores.get()));
@ -322,7 +341,8 @@ Status ProcessLogits(const OrtValue& logits, //
// next_indices = (next_tokens / vocab_size).long()
// next_tokens = next_tokens % vocab_size
const int64_t* next_token_indices = topk_indices->Data<int64_t>();
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), batch_size, top_k, vocab_size, cuda_stream);
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
batch_size, top_k, vocab_size, cuda_stream);
const float* data = topk_scores->Data<float>();
@ -332,12 +352,26 @@ Status ProcessLogits(const OrtValue& logits, //
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
#endif
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), data, topk_scores->Shape().Size() * sizeof(float), cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_indices.data(), beam_state->next_indices.data(), beam_state->next_indices.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
data,
topk_scores->Shape().Size() * sizeof(float),
cudaMemcpyDeviceToHost,
cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
beam_state->next_tokens.data(),
beam_state->next_tokens.size_bytes(),
cudaMemcpyDeviceToHost,
cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_indices.data(),
beam_state->next_indices.data(),
beam_state->next_indices.size_bytes(),
cudaMemcpyDeviceToHost,
cuda_stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
gsl::span<const float> next_scores = gsl::make_span(cpu_state->topk_scores.data(), static_cast<typename gsl::span<float>::index_type>(topk_scores->Shape().Size()));
gsl::span<const float> next_scores = gsl::make_span(
cpu_state->topk_scores.data(),
static_cast<typename gsl::span<float>::index_type>(topk_scores->Shape().Size()));
gsl::span<const int32_t> next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size());
gsl::span<const int32_t> next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size());
@ -354,55 +388,98 @@ template <typename T>
Status DeviceCopy(gsl::span<T> target, gsl::span<const T> source, void* stream, int copyDirection) {
assert(copyDirection >= 0 && copyDirection <= 3);
if (stream == nullptr) {
CUDA_RETURN_IF_ERROR(cudaMemcpy(target.data(), source.data(), source.size_bytes(), static_cast<cudaMemcpyKind>(copyDirection)));
CUDA_RETURN_IF_ERROR(cudaMemcpy(target.data(), source.data(), source.size_bytes(),
static_cast<cudaMemcpyKind>(copyDirection)));
} else {
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), static_cast<cudaMemcpyKind>(copyDirection), cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(),
static_cast<cudaMemcpyKind>(copyDirection), cuda_stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
}
return Status::OK();
}
template <typename T>
Status PickPastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
gsl::span<const int32_t>& beam_indices,
AllocatorPtr allocator,
void* stream) {
for (size_t i = 1; i < last_outputs.size(); ++i) {
const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
Status PickGptPastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
gsl::span<const int32_t>& beam_indices,
AllocatorPtr allocator,
void* stream) {
int num_present_tensors = static_cast<int>(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex;
for (int i = 0; i < num_present_tensors; ++i) {
const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i];
// shape is like (2, batch_beam_size, 12, past_seq_len, 64)
const TensorShape& past_shape = present.Get<Tensor>().Shape();
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
// Create a tensor with same shape.
// TODO: allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data.
// TODO(tianleiwu): allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data.
OrtValue past;
auto past_type = DataTypeImpl::GetType<T>();
Tensor::InitOrtValue(past_type, past_shape, allocator, past);
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
for (gsl::index j = 0; j < beam_indices.length(); j++) {
int32_t beam_index = beam_indices[j];
gsl::span<const T> present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam,
block_size_per_beam);
gsl::span<T> past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
gsl::span<T> past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(), cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(), cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
}
next_inputs[i + 2] = past;
next_inputs[transformers::GptSubgraph::kFirstPastInputIndex + i] = past;
}
return Status::OK();
}
// Copy present state to past state for T5 model
template <typename T>
Status PickT5PastState(const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t>& beam_indices,
AllocatorPtr allocator,
void* stream) {
for (int i = 0; i < num_present_tensors; ++i) {
const OrtValue& present = last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
// shape is like (batch_beam_size, 12, past_seq_len, 64)
const TensorShape& past_shape = present.Get<Tensor>().Shape();
auto block_size_per_beam = past_shape[1] * past_shape[2] * past_shape[3];
// Create a tensor with same shape.
// TODO(tianleiwu): allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data.
OrtValue past;
Tensor::InitOrtValue(DataTypeImpl::GetType<T>(), past_shape, allocator, past);
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
for (gsl::index j = 0; j < beam_indices.length(); j++) {
int32_t beam_index = beam_indices[j];
gsl::span<const T> present_beam = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
gsl::span<T> past_beam = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_beam.data(), present_beam.data(), present_beam.size_bytes(),
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
}
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = past;
}
return Status::OK();
}
template <typename T>
Status UpdateFeeds(
Status UpdateGptFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -421,7 +498,8 @@ Status UpdateFeeds(
OrtValue input_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream)));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream)));
next_inputs[0] = input_ids;
// Update position IDs
@ -439,7 +517,8 @@ Status UpdateFeeds(
int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
// Launch kernel to update position_ids and attention_mask for next iteration
cuda::LaunchUpdateKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, reinterpret_cast<cudaStream_t>(stream));
cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length,
reinterpret_cast<cudaStream_t>(stream));
next_inputs[2] = attention_mask;
@ -453,12 +532,13 @@ Status UpdateFeeds(
// Update past state
if (num_beams == 1) {
const int k = transformers::GptSubgraph::kFirstPastInputIndex - transformers::GptSubgraph::kFirstPresentOutputIndex;
// feed present_* output to past_* inputs one by one
for (size_t i = 1; i < last_outputs.size(); ++i) {
next_inputs[i + 2] = last_outputs[i];
for (size_t i = transformers::GptSubgraph::kFirstPresentOutputIndex; i < last_outputs.size(); ++i) {
next_inputs[i + k] = last_outputs[i];
}
} else {
ORT_RETURN_IF_ERROR(PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream));
ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream));
}
// Make sure data is ready before next subgraph execution.
@ -466,15 +546,63 @@ Status UpdateFeeds(
return Status::OK();
}
// Update decoder inputs given decoder outputs of last iteration.
template <typename T>
Status UpdateDecoderFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper) {
// last_outputs: logits, present_key_self_0, present_value_self_0, ...
// next_inputs: input_ids,
// encoder_attention_mask, encoder_hidden_states,
// past_key_self_0, past_value_self_0, ...
// past_key_cross_0, past_value_cross_0, ...
// Only need copy beam next tokens to input_ids, and copy present_*_self_* to past_*_self_*,
// Update input_ids with next tokens.
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
int64_t dims[] = {batch_beam_size, 1};
TensorShape input_ids_shape(&dims[0], 2);
auto element_type = DataTypeImpl::GetType<int32_t>();
OrtValue input_ids;
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream)));
next_inputs[0] = input_ids;
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("input_ids", input_ids);
#else
ORT_UNUSED_PARAMETER(dumper);
#endif
// Update past state
ORT_ENFORCE(last_outputs.size() >= static_cast<size_t>(1 + num_present_tensors));
// TODO(tianleiwu): remove num_beams==1 once GreedySearch operator is available.
if (num_beams == 1) {
// feed present_* output to past_* inputs one by one
for (int i = 0; i < num_present_tensors; ++i) {
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] =
last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
return Status::OK();
}
}
return PickT5PastState<T>(last_outputs, next_inputs, num_present_tensors, beam_indices, allocator, stream);
}
// Explicit template instantiations of functions
template void InitBeamState<float>(transformers::IBeamSearchState<float>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream);
template Status ProcessLogits<float>(const OrtValue& logits,
@ -494,9 +622,15 @@ template Status DeviceCopy<float>(
gsl::span<float> target,
gsl::span<const float> source,
void* stream,
int copyDirectionn);
int copyDirection);
template Status UpdateFeeds<float>(
template Status DeviceCopy<int32_t>(
gsl::span<int32_t> target,
gsl::span<const int32_t> source,
void* stream,
int copyDirection);
template Status UpdateGptFeeds<float>(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -510,13 +644,9 @@ template Status UpdateFeeds<float>(
// Float16
template void InitBeamState<MLFloat16>(transformers::IBeamSearchState<MLFloat16>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream);
template Status ProcessLogits<MLFloat16>(const OrtValue& logits,
@ -532,7 +662,7 @@ template Status ProcessLogits<MLFloat16>(const OrtValue& logits,
void* stream,
const transformers::IConsoleDumper* dumper);
template Status UpdateFeeds<MLFloat16>(
template Status UpdateGptFeeds<MLFloat16>(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -544,6 +674,28 @@ template Status UpdateFeeds<MLFloat16>(
int num_beams,
const transformers::IConsoleDumper* dumper);
template Status UpdateDecoderFeeds<float>(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper);
template Status UpdateDecoderFeeds<MLFloat16>(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper);
} // namespace BeamSearchCudaDeviceHelper
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,3 +1,6 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/fpgeneric.h"
@ -26,21 +29,15 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
std::unique_ptr<Tensor>& output_indices);
Status AddToFeeds(const IExecutionProvider* execution_provider,
OrtValue& input_ids,
OrtValue& position_ids,
OrtValue& attention_mask,
std::initializer_list<OrtValue> inputs,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer);
template <typename T>
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
transformers::IBeamSearchCpuState* cpu_state,
gsl::span<int32_t>& sequence_lengths,
int batch_size,
int num_beams,
gsl::span<const int32_t> input_ids_in_cpu,
int sequence_length,
int max_length,
void* stream);
template <typename T>
@ -64,7 +61,7 @@ Status DeviceCopy(gsl::span<T> target,
int copyDirection);
template <typename T>
Status UpdateFeeds(
Status UpdateGptFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
@ -76,6 +73,23 @@ Status UpdateFeeds(
int num_beams,
const transformers::IConsoleDumper* dumper);
// ---------------------------------------------------------------
// Functions for encoder-decoder model like T5
// ---------------------------------------------------------------
// Update decoder inputs given decoder outputs of last iteration.
template <typename T>
Status UpdateDecoderFeeds(
AllocatorPtr allocator,
void* stream,
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int num_present_tensors,
gsl::span<const int32_t> beam_next_tokens,
gsl::span<const int32_t> beam_indices,
int num_beams,
const transformers::IConsoleDumper* dumper);
} // namespace BeamSearchCudaDeviceHelper
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -1,10 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "beam_search_impl.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "cub/util_type.cuh"
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
namespace onnxruntime {
namespace contrib {
@ -52,7 +52,11 @@ void LaunchNextTokenKernel(const int64_t* next_token_indices,
int total_elements = batch_size * top_k;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
NextTokenKernel<<<gridSize, blockSize, 0, stream>>>(next_token_indices, next_indices, next_tokens, vocab_size, total_elements);
NextTokenKernel<<<gridSize, blockSize, 0, stream>>>(next_token_indices,
next_indices,
next_tokens,
vocab_size,
total_elements);
}
template <typename T>
@ -153,8 +157,19 @@ void LaunchLogitsProcessKernel(
int total_elements = batch_size * num_beams * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(next_token_scores, vocab_mask, prefix_vocab_mask, num_beams, vocab_size, total_elements, demote_token_id,
sequences, max_sequence_length, current_sequence_length, repetition_penalty, no_repeat_ngram_size);
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(
next_token_scores,
vocab_mask,
prefix_vocab_mask,
num_beams,
vocab_size,
total_elements,
demote_token_id,
sequences,
max_sequence_length,
current_sequence_length,
repetition_penalty,
no_repeat_ngram_size);
}
// Instantiation
@ -221,11 +236,11 @@ template void LaunchAddProbsKernel(
cudaStream_t stream);
template <typename T>
__global__ void UpdateInputsKernel(const T* old_mask_data,
T* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length) {
__global__ void UpdateGptInputsKernel(const T* old_mask_data,
T* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < batch_beam_size * current_length) {
// Update attention mask.
@ -240,19 +255,20 @@ __global__ void UpdateInputsKernel(const T* old_mask_data,
}
}
void LaunchUpdateKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream) {
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream) {
assert(current_length > 0);
int total_elements = batch_beam_size * current_length;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
UpdateInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
UpdateGptInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(
old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -2,8 +2,10 @@
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include <cuda_fp16.h>
namespace onnxruntime {
namespace contrib {
namespace cuda {
@ -46,12 +48,12 @@ void LaunchNextTokenKernel(const int64_t* next_token_indices,
int vocab_size,
cudaStream_t stream);
void LaunchUpdateKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream);
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib

View file

@ -1,9 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cuda_runtime_api.h>
#include "core/providers/cuda/cuda_common.h"
#include "dump_cuda_tensor.h"
#include "core/framework/print_tensor_utils.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
namespace onnxruntime {
namespace contrib {
@ -38,11 +39,13 @@ class PinnedHostBuffer {
};
template <typename T>
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1) {
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool is_gpu_tensor) {
// Occasionally, user will need dump CPU tensor in CUDA EP.
// In that case, we copy tensor data as well. It is not needed, but it keeps code simple.
int num_items = dim0 * dim1;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), cudaMemcpyDeviceToHost);
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
@ -56,11 +59,11 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1) {
}
template <typename T>
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) {
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) {
int num_items = dim0 * dim1 * dim2;
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
cudaDeviceSynchronize();
cudaMemcpy(*data, tensor, num_items * sizeof(T), cudaMemcpyDeviceToHost);
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
@ -75,14 +78,15 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di
void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) {
MLDataType dataType = tensor.DataType();
bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU);
if (dataType == DataTypeImpl::GetType<float>()) {
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, dim2);
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, dim2, is_gpu_tensor);
} else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, dim2);
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, dim2, is_gpu_tensor);
} else if (dataType == DataTypeImpl::GetType<int32_t>()) {
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, dim2);
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, dim2, is_gpu_tensor);
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, dim2);
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, dim2, is_gpu_tensor);
} else {
assert(0);
}
@ -90,14 +94,15 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i
void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
MLDataType dataType = tensor.DataType();
bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU);
if (dataType == DataTypeImpl::GetType<float>()) {
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1);
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, is_gpu_tensor);
} else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1);
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, is_gpu_tensor);
} else if (dataType == DataTypeImpl::GetType<int32_t>()) {
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1);
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, is_gpu_tensor);
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1);
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, is_gpu_tensor);
} else {
assert(0);
}
@ -106,12 +111,18 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
void DumpGpuTensor(const char* name, const Tensor& tensor) {
const auto& shape = tensor.Shape();
if (nullptr != name) {
std::cout << std::string(name) << std::endl;
}
std::cout << "Shape:" << shape << std::endl;
std::cout << tensor.Location().ToString() << std::endl;
size_t num_dims = shape.NumDimensions();
if (num_dims >= 3) {
int dim0 = static_cast<int>(shape.SizeToDimension(num_dims - 2));
int dim1 = static_cast<int>(shape[num_dims - 2]);
int dim2 = static_cast<int>(shape[num_dims - 1]);
DumpGpuTensor(name, tensor, dim0, dim1, dim2);
DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2);
return;
}
@ -121,47 +132,47 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) {
num_rows = static_cast<size_t>(shape[0]);
}
size_t row_size = num_items / num_rows;
DumpGpuTensor(name, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
DumpGpuTensor(nullptr, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
}
void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const {
if (is_enabled_)
DumpGpuTensor<float>(name, tensor, dim0, dim1);
DumpGpuTensor<float>(name, tensor, dim0, dim1, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
if (is_enabled_)
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1);
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
if (is_enabled_)
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1);
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const {
if (is_enabled_)
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1);
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const {
if (is_enabled_)
DumpGpuTensor<float>(name, tensor, dim0, dim1, dim2);
DumpGpuTensor<float>(name, tensor, dim0, dim1, dim2, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const {
if (is_enabled_)
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, dim2);
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, dim2, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const {
if (is_enabled_)
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, dim2);
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, dim2, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const {
if (is_enabled_)
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1, dim2);
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1, dim2, true);
}
void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const {
@ -235,4 +246,4 @@ void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const
} // namespace transformers
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include <string>
#include "core/framework/tensorprotoutils.h"
#include "core/framework/ort_value.h"
@ -33,4 +34,4 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons
} // namespace transformers
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
} // namespace onnxruntime

View file

@ -957,10 +957,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
.SetDoc("Beam Search for text generation. Supports GPT-2 decoder.")
.Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
.Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast<int64_t>(-1))
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
.Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
.Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")

View file

@ -12,3 +12,5 @@ import convert_to_onnx
# added for backward compatible
import gpt2_helper
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5"))

View file

@ -9,9 +9,8 @@
import argparse
import os
import random
import sys
from pathlib import Path
from typing import Dict, List, Tuple, Union
from typing import Dict, Optional, Tuple
import numpy as np
from onnx import ModelProto, TensorProto, numpy_helper
@ -27,7 +26,7 @@ def fake_input_ids_data(
input_ids (TensorProto): graph input of the input_ids input tensor
batch_size (int): batch size
sequence_length (int): sequence length
dictionary_size (int): vacaburary size of dictionary
dictionary_size (int): vocabulary size of dictionary
Returns:
np.ndarray: the input tensor created
@ -115,28 +114,28 @@ def fake_input_mask_data(
return data
def output_test_data(dir: str, inputs: np.ndarray):
def output_test_data(directory: str, inputs: Dict[str, np.ndarray]):
"""Output input tensors of test data to a directory
Args:
dir (str): path of a directory
inputs (numpy.ndarray): numpy array
directory (str): path of a directory
inputs (Dict[str, np.ndarray]): map from input name to value
"""
if not os.path.exists(dir):
if not os.path.exists(directory):
try:
os.mkdir(dir)
os.mkdir(directory)
except OSError:
print("Creation of the directory %s failed" % dir)
print("Creation of the directory %s failed" % directory)
else:
print("Successfully created the directory %s " % dir)
print("Successfully created the directory %s " % directory)
else:
print("Warning: directory %s existed. Files will be overwritten." % dir)
print("Warning: directory %s existed. Files will be overwritten." % directory)
index = 0
for name, data in inputs.items():
tensor = numpy_helper.from_array(data, name)
with open(os.path.join(dir, "input_{}.pb".format(index)), "wb") as f:
f.write(tensor.SerializeToString())
with open(os.path.join(directory, "input_{}.pb".format(index)), "wb") as file:
file.write(tensor.SerializeToString())
index += 1
@ -158,7 +157,7 @@ def fake_test_data(
batch_size (int): batch size
sequence_length (int): sequence length
test_cases (int): number of test cases
dictionary_size (int): vocaburary size of dictionary for input_ids
dictionary_size (int): vocabulary size of dictionary for input_ids
verbose (bool): print more information or not
random_seed (int): random seed
input_ids (TensorProto): graph input of input IDs
@ -167,7 +166,8 @@ def fake_test_data(
random_mask_length (bool): whether mask random number of words at the end
Returns:
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
with input name as key and a tensor as value
"""
assert input_ids is not None
@ -202,7 +202,7 @@ def generate_test_data(
input_mask: TensorProto,
random_mask_length: bool,
):
"""Create given number of minput data for testing
"""Create given number of input data for testing
Args:
batch_size (int): batch size
@ -216,7 +216,8 @@ def generate_test_data(
random_mask_length (bool): whether mask random number of words at the end
Returns:
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
with input name as key and a tensor as value
"""
dictionary_size = 10000
all_inputs = fake_test_data(
@ -251,12 +252,13 @@ def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
def find_bert_inputs(
onnx_model: OnnxModel,
input_ids_name: str = None,
segment_ids_name: str = None,
input_mask_name: str = None,
) -> Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]:
input_ids_name: Optional[str] = None,
segment_ids_name: Optional[str] = None,
input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
"""Find graph inputs for BERT model.
First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming.
First, we will deduce inputs from EmbedLayerNormalization node.
If not found, we will guess the meaning of graph inputs based on naming.
Args:
onnx_model (OnnxModel): onnx model object
@ -266,10 +268,12 @@ def find_bert_inputs(
Raises:
ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
ValueError: Exptected graph input number does not match with specifeid input_ids_name, segment_ids_name and input_mask_name
ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
and input_mask_name
Returns:
Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: input tensors of input_ids, segment_ids and input_mask
Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
segment_ids and input_mask
"""
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
@ -340,12 +344,13 @@ def find_bert_inputs(
def get_bert_inputs(
onnx_file: str,
input_ids_name: str = None,
segment_ids_name: str = None,
input_mask_name: str = None,
):
input_ids_name: Optional[str] = None,
segment_ids_name: Optional[str] = None,
input_mask_name: Optional[str] = None,
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
"""Find graph inputs for BERT model.
First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming.
First, we will deduce inputs from EmbedLayerNormalization node.
If not found, we will guess the meaning of graph inputs based on naming.
Args:
onnx_file (str): onnx model path
@ -354,11 +359,12 @@ def get_bert_inputs(
input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
Returns:
Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: input tensors of input_ids, segment_ids and input_mask
Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
segment_ids and input_mask
"""
model = ModelProto()
with open(onnx_file, "rb") as f:
model.ParseFromString(f.read())
with open(onnx_file, "rb") as file:
model.ParseFromString(file.read())
onnx_model = OnnxModel(model)
return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
@ -447,9 +453,9 @@ def create_and_save_test_data(
test_cases: int,
seed: int,
verbose: bool,
input_ids_name: str,
segment_ids_name: str,
input_mask_name: str,
input_ids_name: Optional[str],
segment_ids_name: Optional[str],
input_mask_name: Optional[str],
only_input_tensors: bool,
):
"""Create test data for a model, and save test data to a directory.
@ -482,24 +488,24 @@ def create_and_save_test_data(
)
for i, inputs in enumerate(all_inputs):
dir = os.path.join(output_dir, "test_data_set_" + str(i))
output_test_data(dir, inputs)
directory = os.path.join(output_dir, "test_data_set_" + str(i))
output_test_data(directory, inputs)
if only_input_tensors:
return
import onnxruntime
sess = onnxruntime.InferenceSession(model)
output_names = [output.name for output in sess.get_outputs()]
session = onnxruntime.InferenceSession(model)
output_names = [output.name for output in session.get_outputs()]
for i, inputs in enumerate(all_inputs):
dir = os.path.join(output_dir, "test_data_set_" + str(i))
result = sess.run(output_names, inputs)
directory = os.path.join(output_dir, "test_data_set_" + str(i))
result = session.run(output_names, inputs)
for i, output_name in enumerate(output_names):
tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_names[i])
with open(os.path.join(dir, "output_{}.pb".format(i)), "wb") as f:
f.write(tensor_result.SerializeToString())
tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_name)
with open(os.path.join(directory, "output_{}.pb".format(i)), "wb") as file:
file.write(tensor_result.SerializeToString())
def main():

View file

@ -6,11 +6,20 @@
This converts GPT2 or T5 model to onnx with beam search operator.
Example 1: convert gpt2 model with beam search:
python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
Example 2: convert T5 model with beam search:
python ./models/t5/convert_to_onnx.py -m t5-small -s
python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx
python convert_beam_search.py -m gpt2 --decoder_onnx ./onnx_models/gpt2_past_fp32.onnx \
--output ./onnx_models/gpt2_beam_search.onnx --output_sequences_scores
Example 2: convert T5 model with beam search in two steps:
cd ./models/t5
python convert_to_onnx.py -m t5-small
cd ../..
python convert_beam_search.py -m t5-small --model_type t5 \
--decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \
--encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \
--output ./models/t5/onnx_models/t5_small_beam_search.onnx
Example 3: convert T5 model with beam search. All in one step:
python convert_beam_search.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx
"""
import argparse
@ -19,27 +28,36 @@ import os
import sys
import time
from pathlib import Path
from typing import List, Union
from typing import Any, Dict, List, Optional, Union
import numpy as np
import onnx
import torch
from benchmark_helper import Precision
from onnx import helper
from onnx import onnx_pb as onnx_proto
from packaging import version
from transformers import GPT2Config, T5Config
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, T5Config, T5ForConditionalGeneration, T5Tokenizer
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2"))
from convert_to_onnx import main as convert_gpt2_to_onnx
from gpt2_helper import PRETRAINED_GPT2_MODELS
from gpt2_helper import PRETRAINED_GPT2_MODELS # noqa: E402
from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx # noqa: E402
config: Union[GPT2Config, T5Config] = None
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5"))
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models # noqa: E402
logger = logging.getLogger("")
def parse_arguments(argv=None):
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
"""Parse arguments
Args:
argv (Optional[List[str]], optional): _description_. Defaults to None.
Returns:
argparse.Namespace: Parsed arguments.
"""
parser = argparse.ArgumentParser()
parser.add_argument(
@ -69,9 +87,10 @@ def parse_arguments(argv=None):
parser.add_argument(
"--decoder_onnx",
required=True,
required=False,
type=str,
help="Output directory for decoder onnx model, or model path ends with .onnx",
default="",
help="Path of onnx model for decoder. Required for gpt2 model type.",
)
parser.add_argument(
@ -79,14 +98,14 @@ def parse_arguments(argv=None):
required=False,
type=str,
default="",
help="path of ONNX model for encoder and decoder initialization. Required for t5 model type.",
help="Path of ONNX model for encoder and decoder initialization. For t5 model type.",
)
parser.add_argument(
"--output",
required=False,
type=str,
help="Output directory for beam search model, or model path ends with .onnx",
help="Output path for onnx model with beam search.",
)
parser.add_argument(
@ -113,6 +132,14 @@ def parse_arguments(argv=None):
)
parser.set_defaults(disable_parity=False)
parser.add_argument(
"--verbose",
required=False,
action="store_true",
help="Print more information",
)
parser.set_defaults(verbose=False)
parser.add_argument(
"--torch_performance",
required=False,
@ -185,7 +212,7 @@ def parse_arguments(argv=None):
type=float,
required=False,
default=1,
help="Positive. >1 to penalize and <1 to encorage short sentence.",
help="Positive. >1 to penalize and <1 to encourage short sentence.",
)
beam_search_group.add_argument(
@ -193,7 +220,7 @@ def parse_arguments(argv=None):
type=float,
required=False,
default=1,
help="Positive. >1 to penalize and <1 to encorage.",
help="Positive. >1 to penalize and <1 to encourage.",
)
beam_search_group.add_argument(
@ -217,10 +244,14 @@ def parse_arguments(argv=None):
return args
def gpt2_to_onnx(args):
def gpt2_to_onnx(args: argparse.Namespace):
"""Convert GPT-2 model to onnx
Args:
args (argparse.Namespace): arguments parsed from command line
"""
model_name = args.model_name_or_path
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...")
arguments = [
"--model_name_or_path",
model_name,
@ -233,7 +264,7 @@ def gpt2_to_onnx(args):
"1",
"--test_cases",
"10",
"--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, postion_ids and attention_mask
"--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, position_ids and attention_mask
]
if args.use_gpu:
arguments.append("--use_gpu")
@ -242,39 +273,78 @@ def gpt2_to_onnx(args):
if args.precision == Precision.FLOAT16:
assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
# TODO: Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
# TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
# Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
# Currently logits and past state shall be same data type.
arguments.extend(["--op_block_list", "Add", "LayerNormalization", "FastGelu"])
convert_gpt2_to_onnx(arguments)
if args.verbose:
print(f"arguments for convert_to_onnx:{arguments}")
convert_gpt2_to_onnx(argv=arguments)
def shape_inference(decoder_onnx_path):
if version.parse(onnx.__version__) >= version.parse("1.11.0"):
logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.")
def t5_to_onnx(args: argparse.Namespace):
"""Convert T5 model to onnx
Args:
args (argparse.Namespace): arguments parsed from command line
"""
paths = export_t5_onnx_models(
args.model_name_or_path,
args.cache_dir,
Path(args.output).parent,
use_gpu=args.use_gpu,
use_external_data_format=args.use_external_data_format,
optimize_onnx=False,
precision=args.precision,
verbose=False,
use_decoder_start_token=False,
merge_encoder_and_decoder_init=True,
overwrite=True,
disable_auto_mixed_precision=False,
use_int32_inputs=True,
)
args.encoder_decoder_init_onnx = paths[0]
args.decoder_onnx = paths[1]
def shape_inference(onnx_path: str):
"""Shape inference on an onnx file, which will be overwritten.
Args:
onnx_path (str): Path of onnx model
"""
# Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False)
out = SymbolicShapeInference.infer_shapes(onnx.load(onnx_path), auto_merge=True, guess_output_rank=False)
if out:
# TODO: Use external format if input has extra data.
onnx.save(out, decoder_onnx_path)
# TODO(tianleiwu): Use external format if input has extra data.
onnx.save(out, onnx_path)
else:
print("Failed to run symbolic shape inference on the model.")
def create_ort_session(model_path, use_gpu):
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
from onnxruntime import __version__ as ort_version
from onnxruntime import get_available_providers
def create_ort_session(model_path: str, use_gpu: bool) -> InferenceSession:
"""Create OnnxRuntime session.
Args:
model_path (str): onnx model path
use_gpu (bool): use GPU or not
Raises:
RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified.
Returns:
onnxruntime.InferenceSession: The created session.
"""
sess_options = SessionOptions()
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
if use_gpu:
if "CUDAExecutionProvider" not in get_available_providers():
raise RuntimeError("CUDAExecutionProvider is not avaiable for --use_gpu!")
raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!")
else:
print("use CUDAExecutionProvider")
@ -282,11 +352,26 @@ def create_ort_session(model_path, use_gpu):
return ort_session
def verify_gpt2_subgraph(graph, precision):
def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision):
"""Verify GPT-2 subgraph
Args:
graph (onnx.GraphProto): onnx graph of GPT-2
precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
Raises:
ValueError: Number of inputs not expected.
ValueError: Input name is not expected.
ValueError: Input data type is not expected.
ValueError: Number of outputs not expected.
ValueError: Output name is not expected.
ValueError: Output data type is not expected.
"""
is_float16 = Precision.FLOAT16 == precision
input_count = len(graph.input)
layer_count = input_count - 3
assert layer_count >= 1
expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
if len(graph.input) != len(expected_inputs):
@ -300,10 +385,9 @@ def verify_gpt2_subgraph(graph, precision):
if i >= 3:
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
if graph.input[i].type.tensor_type.elem_type != expected_type:
raise ValueError(
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.input[i].type.tensor_type.elem_type}"
)
input_type = graph.input[i].type.tensor_type.elem_type
if input_type != expected_type:
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
print("Verifying GPT-2 graph inputs: name and data type are good.")
expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)]
@ -315,49 +399,197 @@ def verify_gpt2_subgraph(graph, precision):
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
if graph.output[i].type.tensor_type.elem_type != expected_type:
raise ValueError(
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
)
output_type = graph.output[i].type.tensor_type.elem_type
if output_type != expected_type:
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}")
print("Verifying GPT-2 graph outputs: name and data type are good.")
# TODO: verify shapes of inputs and outputs.
# TODO(tianleiwu): verify shapes of inputs and outputs.
return
def verify_t5_decoder_subgraph(graph, precision):
# TODO: implement it
pass
def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
"""Verify T5 decoder subgraph
Args:
graph (onnx.GraphProto): onnx graph of T5 decoder
precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
Raises:
ValueError: Number of inputs not expected.
ValueError: Input name is not expected.
ValueError: Input data type is not expected.
ValueError: Number of outputs not expected.
ValueError: Output name is not expected.
ValueError: Output data type is not expected.
"""
is_float16 = Precision.FLOAT16 == precision
float_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
input_count = len(graph.input)
layer_count = (input_count - 3) // 4
assert layer_count >= 1
# Expect inputs:
# input_ids: int32 (B, 1)
# encoder_attention_mask: int32 (B, encode_sequence_length)
# encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
# past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
# past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
# ... (for each self attention layer)
# past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
# past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
# ... (for each cross attention layer)
expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"]
for i in range(layer_count):
expected_inputs.append(f"past_key_self_{i}")
expected_inputs.append(f"past_value_self_{i}")
for i in range(layer_count):
expected_inputs.append(f"past_key_cross_{i}")
expected_inputs.append(f"past_value_cross_{i}")
if len(graph.input) != len(expected_inputs):
raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
for i, expected_input in enumerate(expected_inputs):
if graph.input[i].name != expected_input:
raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
expected_type = onnx_proto.TensorProto.INT32 if i < 2 else float_type
input_type = graph.input[i].type.tensor_type.elem_type
if input_type != expected_type:
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
# Expect outputs:
# logits: (B, 1, vocab_size)
# present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
# present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
# ... (for each self attention layer)
expected_outputs = ["logits"]
for i in range(layer_count):
expected_outputs.append(f"present_key_self_{i}")
expected_outputs.append(f"present_value_self_{i}")
if len(graph.output) != len(expected_outputs):
raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
for i, expected_output in enumerate(expected_outputs):
if graph.output[i].name != expected_output:
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
output_type = graph.output[i].type.tensor_type.elem_type
if output_type != float_type:
raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}")
def verify_t5_encoder_decoder_init_subgraph(graph, precision):
# TODO: implement it
pass
def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision):
"""Verify T5 decoder subgraph
Args:
graph (onnx.GraphProto): onnx graph of T5 decoder
precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
Raises:
ValueError: Number of inputs not expected.
ValueError: Input name is not expected.
ValueError: Input data type is not expected.
ValueError: Number of outputs not expected.
ValueError: Output name is not expected.
ValueError: Output data type is not expected.
"""
is_float16 = Precision.FLOAT16 == precision
layer_count = (len(graph.output) - 2) // 4
assert layer_count >= 1
# Expect 3 inputs:
# encoder_input_ids: int32 (B, encode_sequence_length)
# encoder_attention_mask: int32 (B, encode_sequence_length)
# decoder_input_ids: int32 (B, 1)
expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"]
if len(graph.input) != len(expected_inputs):
raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
for i, expected_input in enumerate(expected_inputs):
if graph.input[i].name != expected_input:
raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
expected_type = onnx_proto.TensorProto.INT32
input_type = graph.input[i].type.tensor_type.elem_type
if input_type != expected_type:
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
# Expected outputs:
# logits: (B, 1, vocab_size)
# encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
# present_key_self_0: (B, num_heads, 1, head_size)
# present_value_self_0: (B, num_heads, 1, head_size)
# ... (for each self attention layer)
# present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
# present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
# ... (for each cross attention layer)
expected_outputs = ["logits", "encoder_hidden_states"]
for i in range(layer_count):
expected_outputs.append(f"present_key_self_{i}")
expected_outputs.append(f"present_value_self_{i}")
for i in range(layer_count):
expected_outputs.append(f"present_key_cross_{i}")
expected_outputs.append(f"present_value_cross_{i}")
if len(graph.output) != len(expected_outputs):
raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
for i, expected_output in enumerate(expected_outputs):
if graph.output[i].name != expected_output:
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
output_type = graph.output[i].type.tensor_type.elem_type
if output_type != expected_type:
raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}")
print("T5 encoder graph verified: name and data type of inputs and outputs are good.")
def convert_model(args):
if os.path.exists(args.decoder_onnx):
print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
else:
assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2"
gpt2_to_onnx(args)
def convert_model(args: argparse.Namespace):
"""Convert model according to command line arguments.
# TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
Args:
args (argparse.Namespace): arguments parsed from command line
"""
is_gpt2: bool = args.model_type == "gpt2"
if is_gpt2:
if os.path.exists(args.decoder_onnx):
print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
else:
print(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...")
gpt2_to_onnx(args)
else: # t5
if args.decoder_onnx and args.encoder_decoder_init_onnx:
print(
f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}"
)
else:
print(f"Convert T5 model {args.model_name_or_path} to onnx ...")
t5_to_onnx(args)
# TODO(tianleiwu): fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
enable_shape_inference = args.model_type == "gpt2"
if enable_shape_inference:
print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
shape_inference(args.decoder_onnx)
global config
if args.model_type == "gpt2":
if is_gpt2:
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
else:
config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
print(config)
if args.verbose:
print(config)
eos_token_id = config.eos_token_id
pad_token_id = config.eos_token_id
pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id
vocab_size = config.vocab_size
# if vocab_size is given in parameters use that.
@ -394,7 +626,7 @@ def convert_model(args):
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
outputs.append("scores")
node = helper.make_node(
node = onnx.helper.make_node(
"BeamSearch",
inputs=inputs,
outputs=outputs,
@ -403,12 +635,12 @@ def convert_model(args):
node.domain = "com.microsoft"
node.attribute.extend(
[
helper.make_attribute("eos_token_id", eos_token_id),
helper.make_attribute("pad_token_id", pad_token_id),
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
helper.make_attribute("decoder", model.graph),
onnx.helper.make_attribute("eos_token_id", eos_token_id),
onnx.helper.make_attribute("pad_token_id", pad_token_id),
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
onnx.helper.make_attribute("decoder", model.graph),
]
)
@ -421,22 +653,25 @@ def convert_model(args):
verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision)
node.attribute.extend(
[
helper.make_attribute("encoder_decoder_init", init_model.graph),
onnx.helper.make_attribute("encoder", init_model.graph),
onnx.helper.make_attribute(
"decoder_start_token_id", config.decoder_start_token_id if len(init_model.graph.input) == 3 else -1
),
]
)
from onnx import TensorProto
# graph inputs
input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
temperature = onnx.helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
graph_inputs = [
input_ids,
@ -451,23 +686,23 @@ def convert_model(args):
]
if args.prefix_vocab_mask:
prefix_vocab_mask = helper.make_tensor_value_info(
prefix_vocab_mask = onnx.helper.make_tensor_value_info(
"prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size]
)
graph_inputs.append(prefix_vocab_mask)
# graph outputs
sequences = helper.make_tensor_value_info(
sequences = onnx.helper.make_tensor_value_info(
"sequences",
TensorProto.INT32,
["batch_size", "num_return_sequences", "max_length"],
)
sequences_scores = helper.make_tensor_value_info(
sequences_scores = onnx.helper.make_tensor_value_info(
"sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
)
scores = helper.make_tensor_value_info(
scores = onnx.helper.make_tensor_value_info(
"scores",
TensorProto.FLOAT,
["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
@ -483,7 +718,7 @@ def convert_model(args):
if args.output_token_scores:
graph_outputs.append(scores)
new_graph = helper.make_graph(
new_graph = onnx.helper.make_graph(
[node],
f"{args.model_type}-beam-search",
graph_inputs,
@ -492,18 +727,44 @@ def convert_model(args):
)
# Create the model
new_model = helper.make_model(
new_model = onnx.helper.make_model(
new_graph,
producer_name="onnxruntime.transformers",
opset_imports=model.opset_import,
)
# TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory.
onnx.save(new_model, args.output)
def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, pad_token_id, bad_words_ids):
def test_torch_performance(
args: argparse.Namespace,
model: Union[GPT2LMHeadModel, T5ForConditionalGeneration],
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
eos_token_id: int,
pad_token_id: int,
bad_words_ids: List[List[int]],
) -> Dict[str, Any]:
"""Test PyTorch performance of text generation.
Args:
args (argparse.Namespace): arguments parsed from command line
model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model
input_ids (torch.Tensor): input_ids
attention_mask (torch.Tensor): Attention mask
eos_token_id (int): EOS token ID
pad_token_id (int): Padding token ID
bad_words_ids (List[List[int]]): Words shall not be generated.
Raises:
RuntimeError: PyTorch with CUDA is not available for --use_gpu
Returns:
Dict[str, Any]: A dictionary with string with metric name, and value can be integer or string.
"""
if args.use_gpu and not torch.cuda.is_available():
logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
return None
raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.")
if args.precision == Precision.FLOAT16:
model.half()
@ -543,23 +804,27 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id,
return get_latency_result(torch_latency, batch_size)
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
if args.model_type != "gpt2":
print(
f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime"
)
return True
def test_gpt_model(args: argparse.Namespace, use_vocab_mask: bool = False, sentences: Optional[List[str]] = None):
"""Test GPT-2 model
Args:
args (argparse.Namespace): arguments parsed from command line
use_vocab_mask (bool, optional): use vocabulary mask. Defaults to False.
sentences (Optional[List[str]], optional): input text. Defaults to None.
Returns:
Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
"""
assert args.model_type == "gpt2"
if args.temperature != 1.0:
# TODO: implement temperature in BeamSearch operator.
# TODO(tianleiwu): implement temperature in BeamSearch operator.
print("Skipping parity test as temperature is not implemented in BeamSearch operator")
return True
return None
if args.prefix_vocab_mask:
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
return True
from transformers import GPT2LMHeadModel, GPT2Tokenizer
return None
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer.padding_side = "left"
@ -589,14 +854,200 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
if use_vocab_mask:
print("bad_words_ids", bad_words_ids)
else:
bad_words_ids = None
bad_words_ids = []
global config
config = model.config
eos_token_id = config.eos_token_id
pad_token_id = config.eos_token_id
vocab_size = config.vocab_size
torch_decoded_sequences = []
beam_outputs = None
if not args.disable_parity:
print("-" * 50)
print("Test PyTorch model and beam search with huggingface transformers...")
beam_outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_length=args.max_length,
min_length=args.min_length,
num_beams=args.num_beams,
early_stopping=args.early_stopping,
no_repeat_ngram_size=args.no_repeat_ngram_size,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
num_return_sequences=args.num_return_sequences,
temperature=args.temperature,
length_penalty=args.length_penalty,
repetition_penalty=args.repetition_penalty,
bad_words_ids=bad_words_ids,
return_dict_in_generate=True,
output_scores=args.output_sequences_scores or args.output_token_scores,
)
print("input_ids", input_ids)
print("huggingface transformers outputs:")
print("sequences", beam_outputs.sequences)
if args.output_sequences_scores:
print("sequences_scores", beam_outputs.sequences_scores)
if args.output_token_scores:
print("scores", beam_outputs.scores)
for i, sequence in enumerate(beam_outputs.sequences):
decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
torch_decoded_sequences.append(decoded_sequence)
print(f"{i}: {decoded_sequence}")
print("-" * 50)
print("Test ONNX model and bream search with onnxruntime...")
ort_session = create_ort_session(args.output, args.use_gpu)
vocab_mask = np.ones((vocab_size), dtype=np.int32)
if use_vocab_mask:
for bad_word_id in bad_words_ids:
vocab_mask[bad_word_id] = 0
inputs = {
"input_ids": input_ids.cpu().numpy().astype(np.int32),
"max_length": np.array([args.max_length], dtype=np.int32),
"min_length": np.array([args.min_length], dtype=np.int32),
"num_beams": np.array([args.num_beams], dtype=np.int32),
"num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
"temperature": np.array([args.temperature], dtype=np.float32),
"length_penalty": np.array([args.length_penalty], dtype=np.float32),
"repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
"vocab_mask": vocab_mask,
}
print("inputs", inputs)
result = ort_session.run(None, inputs)
test_data_dir = Path(args.output).parent.as_posix()
print("test_data_dir", test_data_dir)
from bert_test_data import output_test_data
all_inputs = [inputs]
for i, inputs in enumerate(all_inputs):
dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
output_test_data(dir, inputs)
# Test performance
latency = []
for _ in range(args.total_runs):
start = time.time()
_ = ort_session.run(None, inputs)
latency.append(time.time() - start)
batch_size = input_ids.shape[0]
from benchmark_helper import get_latency_result
output = get_latency_result(latency, batch_size)
print("ORT outputs:")
sequences = result[0]
print("sequences", sequences)
if args.output_sequences_scores:
print("sequences_scores", result[1])
if args.output_token_scores:
print("scores", result[2])
(batch_size, num_sequences, max_length) = sequences.shape
ort_decoded_sequences = []
for i in range(batch_size):
for j in range(num_sequences):
decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
ort_decoded_sequences.append(decoded_sequence)
print(f"batch {i} sequence {j}: {decoded_sequence}")
if beam_outputs:
torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
ort_sequences = torch.LongTensor(sequences)
print("-" * 50)
print("Torch Sequences:")
print(torch_sequences)
print(torch_decoded_sequences)
print("-" * 50)
print("ORT Sequences:")
print(ort_sequences)
print(ort_decoded_sequences)
print("-" * 50)
# Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
is_same = torch_decoded_sequences == ort_decoded_sequences
print("Torch and ORT result is ", "same" if is_same else "different")
output["parity"] = is_same
if args.torch_performance:
torch_latency_output = test_torch_performance(
args,
model,
input_ids,
attention_mask,
eos_token_id,
pad_token_id,
bad_words_ids,
)
print("Torch Latency", torch_latency_output)
print("ORT", output)
return output
def test_t5_model(args: argparse.Namespace, use_vocab_mask: bool = False, sentences: Optional[List[str]] = None):
"""Test T5 model
Args:
args (argparse.Namespace): arguments parsed from command line
use_vocab_mask (bool, optional): use vocabulary mask. Defaults to False.
sentences (Optional[List[str]], optional): input text. Defaults to None.
Returns:
Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
"""
assert args.model_type == "t5"
if args.temperature != 1.0:
# TODO(tianleiwu): implement temperature in BeamSearch operator.
print("Skipping parity test as temperature is not implemented in BeamSearch operator")
return None
if args.prefix_vocab_mask:
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
return None
tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
tokenizer.padding_side = "left"
model = T5ForConditionalGeneration.from_pretrained(
args.model_name_or_path,
cache_dir=args.cache_dir,
)
# Use different length sentences to test batching
if sentences is None:
sentences = [
"translate English to French: The product is released",
"summarize: research continues to show that pets bring real health benefits to their owners."
+ "Having a dog around can lead to lower levels of stress for both adults and kids.",
# "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. "
# + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.",
]
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
input_ids = inputs["input_ids"]
attention_mask = inputs["attention_mask"]
bad_words = "walk in park"
bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS)
bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
if use_vocab_mask:
print("bad_words_ids", bad_words_ids)
else:
bad_words_ids = []
config = model.config
eos_token_id = config.eos_token_id
pad_token_id = config.pad_token_id
vocab_size = config.vocab_size
print(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}")
torch_decoded_sequences = []
if not args.disable_parity:
print("-" * 50)
@ -619,6 +1070,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
return_dict_in_generate=True,
output_scores=args.output_sequences_scores or args.output_token_scores,
)
print("input_ids", input_ids)
print("huggingface transformers outputs:")
print("sequences", beam_outputs.sequences)
@ -724,17 +1176,44 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
return output
def main(argv=None, sentences=None):
def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None):
"""Main entry function
Args:
argv (Optional[List[str]], optional): _description_. Defaults to None.
sentences (Optional[List[str]], optional): input text. Defaults to None.
Raises:
ValueError: --decoder_onnx is not specified for GPT2 model
ValueError: Path does not exist: --encoder_decoder_init_onnx
ValueError: Path does not exist: --decoder_onnx
ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5
Returns:
Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
"""
args = parse_arguments(argv)
if args.model_type == "gpt2":
if not args.decoder_onnx:
raise ValueError("--decoder_onnx shall be specified for gpt2 model")
elif args.model_type == "t5":
if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx):
raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}")
if args.decoder_onnx and not os.path.exists(args.decoder_onnx):
raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}")
if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or (
args.decoder_onnx and not args.encoder_decoder_init_onnx
):
raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx")
convert_model(args)
if args.model_type == "t5":
assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool"
if os.path.exists(args.output):
print(f"skip conversion since path existed: {args.output}")
return test_t5_model(args, use_vocab_mask=True, sentences=sentences)
else:
convert_model(args)
return test_model(args, use_vocab_mask=True, sentences=sentences)
return test_gpt_model(args, use_vocab_mask=True, sentences=sentences)
if __name__ == "__main__":

View file

@ -1,5 +1,4 @@
# -------------------------------------------------------------------------
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.

View file

@ -14,7 +14,7 @@ import torch
from t5_helper import PRETRAINED_T5_MODELS, T5Helper
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger # noqa: E402
logger = logging.getLogger("")
@ -80,7 +80,7 @@ def parse_arguments():
"--use_decoder_start_token",
required=False,
action="store_true",
help="Use config.decoder_start_token_id in decoding. Otherwise, add an extra graph input for decoder_input_ids.",
help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
)
parser.set_defaults(use_decoder_start_token=False)
@ -101,6 +101,22 @@ def parse_arguments():
)
parser.set_defaults(disable_auto_mixed_precision=False)
parser.add_argument(
"--separate_encoder_and_decoder_init",
required=False,
action="store_true",
help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
)
parser.set_defaults(separate_encoder_and_decoder_init=False)
parser.add_argument(
"--use_int64_inputs",
required=False,
action="store_true",
help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
)
parser.set_defaults(use_int64_inputs=False)
args = parser.parse_args()
return args
@ -115,10 +131,11 @@ def export_onnx_models(
optimize_onnx,
precision,
verbose,
use_decoder_start_token: bool = True,
use_decoder_start_token: bool = False,
merge_encoder_and_decoder_init: bool = True,
overwrite: bool = False,
disable_auto_mixed_precision: bool = False,
use_int32_inputs: bool = True,
):
device = torch.device("cuda:0" if use_gpu else "cpu")
@ -126,7 +143,7 @@ def export_onnx_models(
config = models["decoder"].config
if (not use_external_data_format) and (config.num_layers > 24):
logger.info(f"Try use_external_data_format when model size > 2GB")
logger.info("Try use_external_data_format when model size > 2GB")
output_paths = []
for name, model in models.items():
@ -151,6 +168,7 @@ def export_onnx_models(
verbose,
use_external_data_format,
use_decoder_input_ids=not use_decoder_start_token,
use_int32_inputs=use_int32_inputs,
)
else:
logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
@ -185,10 +203,12 @@ def export_onnx_models(
use_gpu=use_gpu,
provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"],
)
max_diff = T5Helper.verify_onnx(model, ort_session, device)
with torch.no_grad():
max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
if max_diff > 1e-4:
logger.warn(f"PyTorch and OnnxRuntime results are NOT close")
logger.warning("PyTorch and OnnxRuntime results are NOT close")
output_paths.append(output_path)
@ -212,24 +232,23 @@ def main():
assert args.use_gpu, "fp16 requires --use_gpu"
if args.optimize_onnx:
logger.warn(f"Graph optimization for T5 is not implemented yet.")
logger.warning("Graph optimization for T5 is not implemented yet.")
with torch.no_grad():
merge_encoder_and_decoder_init = True # Merge encoder and decoder initialization into one model is recommended.
output_paths = export_onnx_models(
args.model_name_or_path,
cache_dir,
output_dir,
args.use_gpu,
args.use_external_data_format,
args.optimize_onnx,
args.precision,
args.verbose,
args.use_decoder_start_token,
merge_encoder_and_decoder_init,
args.overwrite,
args.disable_auto_mixed_precision,
)
output_paths = export_onnx_models(
args.model_name_or_path,
cache_dir,
output_dir,
args.use_gpu,
args.use_external_data_format,
args.optimize_onnx,
args.precision,
args.verbose,
args.use_decoder_start_token,
not args.separate_encoder_and_decoder_init,
args.overwrite,
args.disable_auto_mixed_precision,
not args.use_int64_inputs,
)
logger.info(f"Done! Outputs: {output_paths}")

View file

@ -135,6 +135,7 @@ class T5DecoderInputs:
past_decode_sequence_length: int,
device: torch.device,
float16: bool = False,
use_int32_inputs: bool = False,
): # -> T5DecoderInputs:
"""Create dummy inputs for T5Decoder.
@ -145,6 +146,7 @@ class T5DecoderInputs:
past_decode_sequence_length (int): past sequence length of input_ids for decoder
device (torch.device): device of output tensors
float16 (bool): whether the model uses float32 or float16 in input
use_int32_inputs(bool): whether use int32 instead of int64 for some inputs
Returns:
T5DecoderInputs: dummy inputs for decoder
@ -159,11 +161,17 @@ class T5DecoderInputs:
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.int64,
dtype=(torch.int32 if use_int32_inputs else torch.int64),
device=device,
)
encoder_inputs = T5EncoderInputs.create_dummy(batch_size, encode_sequence_length, vocab_size, device)
encoder_inputs = T5EncoderInputs.create_dummy(
batch_size,
encode_sequence_length,
vocab_size,
device,
use_int32_inputs=use_int32_inputs,
)
float_type = torch.float16 if float16 else torch.float32
encoder_hidden_state = torch.rand(
@ -211,7 +219,7 @@ class T5DecoderInputs:
def to_fp32(self):
encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32)
past = [p.to(dtype=torch.float32) for p in self.past_key_values]
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
return T5DecoderInputs(
self.decoder_input_ids.clone(),
self.encoder_attention_mask.clone(),
@ -228,6 +236,7 @@ class T5DecoderHelper:
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
@ -237,6 +246,7 @@ class T5DecoderHelper:
onnx_model_path (str): onnx path
verbose (bool, optional): print verbose information. Defaults to True.
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
use_int32_inputs (bool, optional): use int32 inputs
"""
assert isinstance(decoder, (T5Decoder, T5DecoderInit))
@ -246,10 +256,9 @@ class T5DecoderHelper:
encode_sequence_length=3,
past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
with torch.no_grad():
outputs = decoder(*input_list)
past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
@ -351,7 +360,8 @@ class T5DecoderHelper:
model: Union[T5Decoder, T5DecoderInit],
ort_session: InferenceSession,
device: torch.device,
max_cases=4,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)"
@ -373,6 +383,7 @@ class T5DecoderHelper:
past_decode_sequence_length,
device=device,
float16=float16,
use_int32_inputs=use_int32_inputs,
)
# We use fp32 PyTroch model as baseline even when ONNX model is fp16

View file

@ -42,28 +42,30 @@ class T5EncoderInputs:
@staticmethod
def create_dummy(
batch_size: int, sequence_length: int, vocab_size: int, device: torch.device
batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
): # -> T5EncoderInputs
"""Create dummy inputs for T5 encoder.
Args:
batch_size (int): batch size
sequence_length (int): sequence length
vocab_size (int): vocaburary size
vocab_size (int): vocabulary size
device (torch.device): device of output tensors
Returns:
T5EncoderInputs: dummy inputs for encoder
"""
dtype = torch.int32 if use_int32_inputs else torch.int64
input_ids = torch.randint(
low=0,
high=vocab_size - 1,
size=(batch_size, sequence_length),
dtype=torch.int64,
dtype=dtype,
device=device,
)
attention_mask = torch.ones([batch_size, sequence_length], dtype=torch.int64, device=device)
attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
if sequence_length >= 2:
for i in range(batch_size):
padding_position = random.randint(0, sequence_length - 1)
@ -83,6 +85,7 @@ class T5EncoderHelper:
onnx_model_path: str,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export encoder to ONNX
@ -95,12 +98,13 @@ class T5EncoderHelper:
"""
config = encoder.config
encoder_inputs = T5EncoderInputs.create_dummy(
batch_size=2, sequence_length=4, vocab_size=config.vocab_size, device=device
batch_size=2,
sequence_length=4,
vocab_size=config.vocab_size,
device=device,
use_int32_inputs=use_int32_inputs,
)
with torch.no_grad():
outputs = encoder(encoder_inputs.input_ids, encoder_inputs.attention_mask)
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
encoder,
@ -131,13 +135,16 @@ class T5EncoderHelper:
return ort_session.run(None, ort_inputs)
@staticmethod
def verify_onnx(model: T5Encoder, ort_session: InferenceSession, device: torch.device):
def verify_onnx(
model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
inputs = T5EncoderInputs.create_dummy(
batch_size=4,
sequence_length=11,
vocab_size=model.config.vocab_size,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
torch_outputs = model(*input_list)

View file

@ -7,10 +7,12 @@
import logging
import os
import sys
import tempfile
from pathlib import Path
from typing import List
from typing import List, Optional
import numpy
import onnx
import torch
from past_helper import PastKeyValuesHelper
from t5_decoder import T5DecoderInit
@ -34,7 +36,7 @@ class T5EncoderDecoderInit(torch.nn.Module):
decoder: torch.nn.Module,
lm_head: torch.nn.Module,
config: T5Config,
decoder_start_token_id: int = None,
decoder_start_token_id: Optional[int] = None,
):
super().__init__()
self.config = config
@ -67,15 +69,19 @@ class T5EncoderDecoderInitInputs:
encode_sequence_length: int,
use_decoder_input_ids: int,
device: torch.device,
use_int32_inputs: bool = False,
): # -> T5EncoderDecoderInitInputs:
encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
batch_size, encode_sequence_length, config.vocab_size, device
batch_size,
encode_sequence_length,
config.vocab_size,
device,
use_int32_inputs=use_int32_inputs,
)
decoder_input_ids = None
if use_decoder_input_ids:
decoder_input_ids = (
torch.ones((batch_size, 1), dtype=torch.long, device=device) * config.decoder_start_token_id
)
dtype = torch.int32 if use_int32_inputs else torch.int64
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
@ -95,6 +101,7 @@ class T5EncoderDecoderInitHelper:
use_decoder_input_ids: bool = True,
verbose: bool = True,
use_external_data_format: bool = False,
use_int32_inputs: bool = False,
):
"""Export decoder to ONNX
@ -113,9 +120,9 @@ class T5EncoderDecoderInitHelper:
encode_sequence_length=3,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
input_list = inputs.to_list()
outputs = model(*input_list)
present_names = PastKeyValuesHelper.get_past_names(model.config.num_layers, present=True)
@ -135,7 +142,8 @@ class T5EncoderDecoderInitHelper:
input_names = ["encoder_input_ids", "encoder_attention_mask"]
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2'. Use more friendly string here.
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference.
# We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
sequence_length = "1"
num_heads = str(model.config.num_heads)
hidden_size = str(model.config.d_model)
@ -149,12 +157,18 @@ class T5EncoderDecoderInitHelper:
1: "encode_sequence_length",
2: hidden_size,
},
"logits": {0: "batch_size", 1: sequence_length},
"logits": {
0: "batch_size",
1: sequence_length,
},
}
if use_decoder_input_ids:
input_names.append("decoder_input_ids")
dynamic_axes["decoder_input_ids"] = {0: "batch_size", 1: sequence_length}
dynamic_axes["decoder_input_ids"] = {
0: "batch_size",
1: sequence_length,
}
for name in present_names:
if "cross" in name:
@ -173,20 +187,47 @@ class T5EncoderDecoderInitHelper:
3: head_size,
}
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
args=tuple(input_list),
f=onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
with tempfile.TemporaryDirectory() as tmp_dir_name:
temp_onnx_model_path = os.path.join(tmp_dir_name, "model.onnx")
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
torch_onnx_export(
model,
args=tuple(input_list),
f=temp_onnx_model_path,
export_params=True,
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
opset_version=12,
do_constant_folding=True,
use_external_data_format=use_external_data_format,
verbose=verbose,
)
# Workaround as mentioned earlier: change numeric dim_param to dim_value
model = onnx.load(temp_onnx_model_path)
for tensor in model.graph.output:
for dim_proto in tensor.type.tensor_type.shape.dim:
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
sequence_length,
num_heads,
hidden_size,
head_size,
]:
dim_value = int(dim_proto.dim_param)
dim_proto.Clear()
dim_proto.dim_value = dim_value
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
onnx.save_model(
model,
onnx_model_path,
save_as_external_data=use_external_data_format,
all_tensors_to_one_file=True,
location=onnx_model_path + ".data",
size_threshold=4096,
convert_attribute=False,
)
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
@ -208,7 +249,8 @@ class T5EncoderDecoderInitHelper:
model: T5EncoderDecoderInit,
ort_session: InferenceSession,
device: torch.device,
max_cases=4,
use_int32_inputs: bool,
max_cases: int = 4,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
ort_inputs = ort_session.get_inputs()
@ -223,6 +265,7 @@ class T5EncoderDecoderInitHelper:
encode_sequence_length,
use_decoder_input_ids=use_decoder_input_ids,
device=device,
use_int32_inputs=use_int32_inputs,
)
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)

View file

@ -26,7 +26,7 @@ from optimizer import optimize_model
logger = logging.getLogger(__name__)
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3B", "t5-11B"]
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
class T5Helper:
@ -110,9 +110,17 @@ class T5Helper:
verbose: bool = True,
use_external_data_format: bool = False,
use_decoder_input_ids: bool = True,
use_int32_inputs: bool = False,
):
if isinstance(model, T5Encoder):
T5EncoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format)
T5EncoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
use_int32_inputs,
)
elif isinstance(model, T5EncoderDecoderInit):
T5EncoderDecoderInitHelper.export_onnx(
model,
@ -121,9 +129,17 @@ class T5Helper:
use_decoder_input_ids,
verbose,
use_external_data_format,
use_int32_inputs,
)
else:
T5DecoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format)
T5DecoderHelper.export_onnx(
model,
device,
onnx_model_path,
verbose,
use_external_data_format,
use_int32_inputs,
)
@staticmethod
def auto_mixed_precision(
@ -234,11 +250,13 @@ class T5Helper:
model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
ort_session: InferenceSession,
device: torch.device,
use_int32_inputs: bool,
):
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
if isinstance(model, T5Encoder):
return T5EncoderHelper.verify_onnx(model, ort_session, device)
elif isinstance(model, T5EncoderDecoderInit):
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device)
else:
return T5DecoderHelper.verify_onnx(model, ort_session, device)
return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
if isinstance(model, T5EncoderDecoderInit):
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)

View file

@ -7,6 +7,7 @@ py3nvml
packaging
transformers >= 4.0
scipy
sentencepiece
# please follow https://pytorch.org/ to install PyTorch for your OS
torch >= 1.8

View file

@ -10,63 +10,189 @@ import os
import unittest
import pytest
import torch
from parity_utilities import find_transformers_source
if find_transformers_source():
from onnxruntime import get_available_providers
if find_transformers_source() and find_transformers_source(["models", "t5"]):
from benchmark_helper import Precision
from convert_beam_search import main as run
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
else:
from onnxruntime.transformers.benchmark_helper import Precision
from onnxruntime.transformers.convert_beam_search import main as run
from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
class TestBeamSearch(unittest.TestCase):
class TestBeamSearchGpt(unittest.TestCase):
"""Test BeamSearch for GPT-2 model"""
def setUp(self):
# TODO: use a smaller model and enable tests in CI pipeline
self.model_name = "gpt2"
self.gpt2_onnx_path = os.path.join(".", "onnx_models", "gpt2_past_fp32_shape.onnx")
self.beam_search_onnx_path = os.path.join(".", "onnx_models", "gpt2_beam_search.onnx")
self.cpu_params = f"-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0"
self.default_arguments = [
f"-m {self.model_name}",
f"--decoder_onnx {self.gpt2_onnx_path}",
f"--output {self.beam_search_onnx_path}",
"--output_sequences_score",
"--repetition_penalty 2.0",
]
self.sentences = [
"The product is released",
"I enjoy walking in the park",
"Test best way to invest",
]
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
self.remove_onnx_files()
def run_beam_search(self, arguments: str, sentences=None):
return run(arguments.split(), sentences=sentences)
def tearDown(self):
self.remove_onnx_files()
def remove_onnx_files(self):
if os.path.exists(self.gpt2_onnx_path):
os.remove(self.gpt2_onnx_path)
if os.path.exists(self.beam_search_onnx_path):
os.remove(self.beam_search_onnx_path)
def run_beam_search(self, extra_arguments: str, sentences=None):
arguments = " ".join(self.default_arguments + [extra_arguments]).split()
# Test CPU
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
# Test GPU
if self.enable_cuda:
if "--use_gpu" not in arguments:
arguments.append("--use_gpu")
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
os.remove(self.beam_search_onnx_path)
@pytest.mark.slow
def test_cpu(self):
result = self.run_beam_search(
self.cpu_params + " --num_return_sequences 2",
sentences=["The product is released"],
)
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
def test_return_sequences(self):
for return_sequences in [1, 2]:
self.run_beam_search(f"--num_return_sequences {return_sequences}")
@pytest.mark.slow
def test_early_stopping(self):
result = self.run_beam_search(self.cpu_params + " --early_stopping")
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
self.run_beam_search("--early_stopping")
@pytest.mark.slow
def test_temperature(self):
result = self.run_beam_search(self.cpu_params + " --temperature 0.5")
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
# @pytest.mark.slow
# def test_temperature(self):
# self.run_beam_search("--temperature 0.5")
@pytest.mark.slow
def test_length_penalty(self):
result = self.run_beam_search(self.cpu_params + " --length_penalty 0.5")
os.remove(self.gpt2_onnx_path)
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
for length_penalty in [0.5, 2.0]:
self.run_beam_search(f"--length_penalty {length_penalty}")
@pytest.mark.slow
def test_no_repeat_ngram(self):
for ngram_size in [1, 2]:
result = self.run_beam_search(self.cpu_params + f" --no_repeat_ngram_size {ngram_size}")
os.remove(self.gpt2_onnx_path)
self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}")
class TestBeamSearchT5(unittest.TestCase):
"""Test BeamSearch for T5 model"""
def setUp(self):
self.model_name = "t5-small"
self.decoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_decoder.onnx")
self.encoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_encoder_decoder_init.onnx")
self.beam_search_onnx_path = os.path.join(".", "onnx_models", "t5_small_beam_search.onnx")
self.default_arguments = [
f"-m {self.model_name}",
"--model_type t5",
f"--decoder_onnx {self.decoder_onnx_path}",
f"--encoder_decoder_init_onnx {self.encoder_onnx_path}",
f"--output {self.beam_search_onnx_path}",
"--output_sequences_score",
"--repetition_penalty 2.0",
]
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
export_t5_onnx_models(
self.model_name,
os.path.join(".", "cache_models"),
os.path.join(".", "onnx_models"),
use_gpu=False,
use_external_data_format=False,
optimize_onnx=False,
precision=Precision.FLOAT32,
verbose=False,
use_decoder_start_token=False,
merge_encoder_and_decoder_init=True,
overwrite=True,
disable_auto_mixed_precision=False,
use_int32_inputs=True,
)
self.sentences = [
"translate English to French: The product is released",
"summarize: research continues to show that pets bring real health benefits to their owners."
+ "Having a dog around can lead to lower levels of stress for both adults and kids.",
]
if os.path.exists(self.beam_search_onnx_path):
os.remove(self.beam_search_onnx_path)
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
def tearDown(self):
self.remove_onnx_files()
def remove_onnx_files(self):
if os.path.exists(self.beam_search_onnx_path):
os.remove(self.beam_search_onnx_path)
if os.path.exists(self.decoder_onnx_path):
os.remove(self.decoder_onnx_path)
if os.path.exists(self.encoder_onnx_path):
os.remove(self.encoder_onnx_path)
def run_beam_search(self, extra_arguments: str, sentences=None):
arguments = " ".join(self.default_arguments + [extra_arguments]).split()
# Test CPU
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
# Test GPU
if self.enable_cuda:
if "--use_gpu" not in arguments:
arguments.append("--use_gpu")
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
os.remove(self.beam_search_onnx_path)
@pytest.mark.slow
def test_return_sequences(self):
for return_sequences in [1, 2]:
self.run_beam_search(f"--num_return_sequences {return_sequences}")
@pytest.mark.slow
def test_early_stopping(self):
self.run_beam_search("--early_stopping")
# @pytest.mark.slow
# def test_temperature(self):
# self.run_beam_search("--temperature 0.5")
@pytest.mark.slow
def test_length_penalty(self):
for length_penalty in [0.5, 2.0]:
self.run_beam_search(f"--length_penalty {length_penalty}")
@pytest.mark.slow
def test_no_repeat_ngram(self):
for ngram_size in [1, 2]:
self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}")
if __name__ == "__main__":