diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 18dc84a8d1..a6f3f845d6 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -355,10 +355,12 @@ This version of the operator has been available since version 1 of the 'com.micr
- decoder : graph (required)
- Decoder subgraph to execute in a loop.
+- decoder_start_token_id : int
+- The id of the token that indicates decoding starts.
- early_stopping : int
- early stop or not
-- encoder_decoder_init : graph
-- subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+- encoder : graph
+- The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
- eos_token_id : int (required)
- The id of the end-of-sequence token
- model_type : int
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index d82f1a7094..8683317834 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -9,6 +9,7 @@
#pragma warning(disable : 4996)
#endif
+#include
#include
#include
#include "core/common/safeint.h"
@@ -26,11 +27,13 @@
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
#include "gsl/gsl"
-#include "beam_search.h"
-#include "logits_processor.h"
-#include "sequences.h"
-#include "dump_tensor.h"
-#include "beam_search_scorer.h"
+#include "contrib_ops/cpu/transformers/beam_search.h"
+#include "contrib_ops/cpu/transformers/logits_processor.h"
+#include "contrib_ops/cpu/transformers/sequences.h"
+#include "contrib_ops/cpu/transformers/dump_tensor.h"
+#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
+#include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h"
+#include "contrib_ops/cpu/transformers/beam_search_impl_t5.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
@@ -53,590 +56,156 @@ REGISTER_KERNEL_TYPED(float)
namespace transformers {
-template
-gsl::span AllocateBuffer(AllocatorPtr allocator,
- BufferUniquePtr& buffer,
- size_t elements,
- bool fill = false,
- T fill_value = T{}) {
- size_t bytes = SafeInt(sizeof(T)) * elements;
- void* data = allocator->Alloc(bytes);
- BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
- buffer = std::move(temp_buffer);
- T* first = reinterpret_cast(buffer.get());
- auto span = gsl::make_span(first, elements);
-
- if (fill) {
- std::fill_n(first, elements, fill_value);
- }
-
- return span;
-}
-
-template
-struct BeamSearchState : public IBeamSearchState {
- void Init(AllocatorPtr allocator,
- int batch_size,
- int num_beams,
- int vocab_size,
- int sequence_length,
- int max_length,
- bool output_scores) {
- size_t batch_beam_size = SafeInt(batch_size) * num_beams;
-
- size_t next_token_size = SafeInt(batch_beam_size) * vocab_size;
- this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size);
- this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size);
-
- this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size);
-
- this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size);
-
- this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size);
-
- this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size);
-
- if (output_scores) {
- size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size;
- this->scores = AllocateBuffer(allocator, scores_buffer_, elements);
- this->remaining_scores = this->scores;
- }
- }
-
- private:
- BufferUniquePtr next_token_logits_buffer_;
- BufferUniquePtr next_token_scores_buffer_;
- BufferUniquePtr next_tokens_buffer_;
- BufferUniquePtr next_indices_buffer_;
- BufferUniquePtr next_positions_buffer_;
- BufferUniquePtr beam_scores_buffer_;
- BufferUniquePtr scores_buffer_;
-};
-
-struct BeamSearchCpuState : public IBeamSearchCpuState {
- Sequences sequences;
-
- void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, bool is_cuda) {
- this->sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size);
- this->sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, SafeInt(2) * batch_beam_size * max_length);
-
- if (is_cuda) {
- // buffers used by CUDA operator but not by CPU operator.
- this->topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * batch_beam_size);
- this->topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * batch_beam_size);
- this->topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * batch_beam_size);
- this->final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size);
- }
- }
-
- private:
- BufferUniquePtr final_beam_scores_buffer_;
- BufferUniquePtr sequence_lengths_buffer_;
- BufferUniquePtr topk_scores_buffer_;
- BufferUniquePtr topk_tokens_buffer_;
- BufferUniquePtr topk_indices_buffer_;
- BufferUniquePtr sequences_space_buffer_;
-};
-
-template
-class BeamSearchImpl {
- public:
- BeamSearchImpl(OpKernelContextInternal& context,
- const SessionState& session_state,
- GptSubgraph& gpt_subgraph,
- concurrency::ThreadPool* thread_pool,
- void* cuda_stream,
- IConsoleDumper* cuda_dumper,
- BeamSearchParameters& params,
- const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
- const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
- const BeamSearchDeviceHelper::TopkFunc& topk_func,
- const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func,
- const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func,
- const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func,
- const BeamSearchDeviceHelper::UpdateFeedsFunc& update_feeds_func)
- : context_(context),
- session_state_(session_state),
- gpt_subgraph_(gpt_subgraph),
- thread_pool_(thread_pool),
- implicit_inputs_(context_.GetImplicitInputs()),
- cuda_stream_(cuda_stream),
- cuda_dumper_(cuda_dumper),
- parameters_(¶ms),
- cpu_allocator_(nullptr),
- temp_space_allocator_(nullptr),
- create_inputs_func_(create_inputs_func),
- add_to_feeds_func_(add_to_feeds_func),
- topk_func_(topk_func),
- process_logits_func_(process_logits_func),
- init_beam_state_func_(init_beam_state_func),
- device_copy_func_(device_copy_func),
- update_feeds_func_(update_feeds_func) {
- parameters_->ParseFromInputs(&context);
-
- cpu_allocator_ = session_state.GetExecutionProviders()
- .Get(onnxruntime::kCpuExecutionProvider)
- ->GetAllocator(0, OrtMemTypeDefault);
- }
-
- // Initialize by validating all the inputs, and allocating the output tensors.
- Status Initialize();
-
- // Execute beam search in iterations util stopping criteria is reached.
- // In each iteration, GPT subgraph is called, and next token for each sequence is generated.
- Status Execute(const FeedsFetchesManager& feeds_fetches_manager);
-
- private:
- bool IsCuda() const { return cuda_stream_ != nullptr; }
-
- // Validate inputs.
- Status CheckInputs(const OpKernelContextInternal& context);
-
- // Prepare the inputs for first inference of subgraph
- Status CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, IAllocatorUniquePtr& buffer);
-
- // Update the input for next iteration.
- Status UpdateFeeds(
- const std::vector& last_outputs,
- std::vector& next_inputs,
- int current_length,
- OrtValue& position_ids,
- gsl::span beam_next_tokens,
- gsl::span beam_indices);
-
- // Process logits and append next tokens to sequences.
- Status GenerateNextToken(const OrtValue& logits,
- gsl::span& beam_next_tokens,
- gsl::span& beam_indices,
- BeamSearchState& beam_state,
- BeamSearchCpuState& cpu_state,
- int counter);
-
- // Calculate scores from logits, then apply filtering and select next token for each beam.
- Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
- BeamSearchState& beam_state,
- BeamSearchCpuState& cpu_state,
- AllocatorPtr& allocator,
- int counter);
-
- const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); }
-
- OpKernelContextInternal& context_;
-
- const SessionState& session_state_;
-
- GptSubgraph& gpt_subgraph_;
-
- concurrency::ThreadPool* thread_pool_;
-
- const std::vector& implicit_inputs_;
-
- void* cuda_stream_;
-
- IConsoleDumper* cuda_dumper_;
- CpuTensorConsoleDumper cpu_dumper_;
-
- BeamSearchParameters* parameters_;
-
- LogitsProcessorList logits_processors_;
-
- std::unique_ptr beam_scorer_;
-
- AllocatorPtr cpu_allocator_;
- AllocatorPtr temp_space_allocator_;
-
- // Device specific functions
- BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
- BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
- BeamSearchDeviceHelper::TopkFunc topk_func_;
- BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_;
- BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_;
- BeamSearchDeviceHelper::DeviceCopyFunc device_copy_func_;
- BeamSearchDeviceHelper::UpdateFeedsFunc update_feeds_func_;
-};
-
void BeamSearch::Init(const OpKernelInfo& info) {
- // Make sure the decoder attribute was present even though we don't need it here.
+ parameters_.ParseFromAttributes(info);
+
+ // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5).
+ ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
+ parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
+
ONNX_NAMESPACE::GraphProto proto;
+ if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
+ ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK());
+ }
+
+ // Make sure the decoder attribute was present even though we don't need it here.
ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK());
ORT_IGNORE_RETURN_VALUE(proto);
-
- parameters_.ParseFromAttributes(info);
}
Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
- ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
- // TODO: handle another subgraph with attribute name "encoder_decode_init"
- if (attribute_name == "decoder") {
- const auto& node = Node();
- gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer());
- ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
- feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
- parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
- gpt_subgraph_->num_heads,
- gpt_subgraph_->head_size,
- gpt_subgraph_->num_layers);
+ const auto& node = Node();
+ if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
+ if (attribute_name == "decoder") {
+ ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+ gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
+ decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
+ parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
+ gpt_subgraph_->num_heads,
+ gpt_subgraph_->head_size,
+ gpt_subgraph_->num_layers);
+ }
+ } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
+ if (attribute_name == "encoder") {
+ ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
+ "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+ t5_encoder_subgraph_ = std::make_unique(node,
+ attribute_name,
+ subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
+ encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
+
+ if (parameters_.decoder_start_token_id < 0) {
+ ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
+ "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
+ } else {
+ ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
+ "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
+ }
+ } else if (attribute_name == "decoder") {
+ ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
+ "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+ t5_decoder_subgraph_ = std::make_unique(node,
+ attribute_name,
+ subgraph_session_state.GetGraphViewer());
+ ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state));
+ decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager();
+ parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size,
+ t5_decoder_subgraph_->num_heads,
+ t5_decoder_subgraph_->head_size,
+ t5_decoder_subgraph_->num_layers);
+ }
}
+
return Status::OK();
}
Status BeamSearch::Compute(OpKernelContext* ctx) const {
- if (parameters_.model_type != 0) {
- // TODO: support encoder decoder model like T5
- return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented");
- }
-
auto* ctx_internal = static_cast(ctx);
- auto* session_state = ctx_internal->SubgraphSessionState("decoder");
- ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
- ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
+
+ auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder");
+ ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
+ ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
- BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later
+ // Make a copy of parameters since we will update it based on inputs later
+ BeamSearchParameters parameters = parameters_;
+
+ if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
+ if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
+ BeamSearchGpt impl{
+ *ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
+ BeamSearchCpuDeviceHelper::CreateGptInputs,
+ add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
+ topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
+ process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits,
+ init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState,
+ device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy,
+ device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy,
+ update_gpt_feeds_func_ ? update_gpt_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateGptFeeds};
+ ORT_RETURN_IF_ERROR(impl.Initialize());
+
+ return impl.Execute(*decoder_feeds_fetches_manager_);
+ } else { // Output float16
+ BeamSearchGpt impl{
+ *ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
+ BeamSearchCpuDeviceHelper::CreateGptInputs,
+ add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
+ topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
+ process_logits_fp16_func_,
+ init_beam_state_fp16_func_,
+ device_copy_func_,
+ device_copy_int32_func_,
+ update_gpt_feeds_fp16_func_};
+ ORT_RETURN_IF_ERROR(impl.Initialize());
+
+ return impl.Execute(*decoder_feeds_fetches_manager_);
+ }
+ }
+
+ auto* encoder_session_state = ctx_internal->SubgraphSessionState("encoder");
+ ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute.");
+ ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
// Subgraph has constraint that the output is either float or float16
- if (!gpt_subgraph_->IsOutputFloat16()) {
- BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
- create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs,
- add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
- topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
- process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits,
- init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState,
- device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy,
- update_feeds_func_ ? update_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateFeeds};
+ if (!t5_decoder_subgraph_->IsOutputFloat16()) {
+ BeamSearchT5 impl{
+ *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+ *t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
+ add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
+ topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
+ process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits,
+ init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState,
+ device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy,
+ device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy,
+ create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
+ update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds};
ORT_RETURN_IF_ERROR(impl.Initialize());
- return impl.Execute(*feeds_fetches_manager_);
+ return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
- BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
- create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs,
- add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
- topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
- process_logits_fp16_func_,
- init_beam_state_fp16_func_,
- device_copy_func_,
- update_feeds_fp16_func_};
+ BeamSearchT5 impl{
+ *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
+ *t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
+ add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
+ topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
+ process_logits_fp16_func_,
+ init_beam_state_fp16_func_,
+ device_copy_func_,
+ device_copy_int32_func_,
+ create_encoder_inputs_func_,
+ update_decoder_feeds_fp16_func_};
+
ORT_RETURN_IF_ERROR(impl.Initialize());
- return impl.Execute(*feeds_fetches_manager_);
+ return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
}
}
-template
-Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) {
- // Input shapes:
- // input_ids : (batch_size, sequence_length)
- // vocab_mask : (vocab_size) or nullptr
-
- const Tensor* input_ids = context.Input(0);
- const auto& dims = input_ids->Shape().GetDims();
- if (dims.size() != 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ",
- dims.size());
- }
-
- const Tensor* vocab_mask = context.Input(8);
- if (vocab_mask != nullptr) { // vocab_mask is optional
- const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
- if (vocab_mask_dims.size() != 1) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ",
- vocab_mask_dims.size());
- }
-
- // There is dependency on vocab_size parameter, which shall be set before calling this function.
- if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ",
- vocab_mask_dims[0]);
- }
-
- // store vocab mask in parameters.
- parameters_->vocab_mask = vocab_mask->DataAsSpan();
- }
-
- const Tensor* prefix_vocab_mask = context.Input(9);
- if (prefix_vocab_mask != nullptr) {
- // prefix_vocab_mask is optional
- const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
- if (vocab_mask_dims.size() != 2) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ",
- vocab_mask_dims.size());
- }
-
- // prefix_vocab_mask first dimension should be same as the first dimension of input_ids
- if (static_cast(vocab_mask_dims[0]) != static_cast(dims[0])) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size");
- }
-
- // There is dependency on vocab_size parameter, which shall be set before calling this function.
- if (static_cast(vocab_mask_dims[1]) != parameters_->vocab_size) {
- return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ",
- vocab_mask_dims[0]);
- }
-
- // store prefix vocab mask in parameters.
- parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan();
- }
-
- return Status::OK();
-}
-
-template
-Status BeamSearchImpl::Initialize() {
- ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_));
-
-#define CHECK_SCALAR_INPUT(name, index, required) \
- auto* name##_tensor = context_.Input(index); \
- if (name##_tensor) { \
- if (!name##_tensor->Shape().IsScalar()) { \
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
- name##_tensor->Shape()); \
- } \
- } else if (required) { \
- return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
- }
-
- CHECK_SCALAR_INPUT(min_length, 1, false);
-
- CHECK_SCALAR_INPUT(max_length, 2, true);
-
- CHECK_SCALAR_INPUT(num_beams, 3, true);
-
- CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
-
- CHECK_SCALAR_INPUT(temperature, 5, true);
-
- CHECK_SCALAR_INPUT(length_penalty, 6, true);
-
- ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'.");
-
- ORT_RETURN_IF_ERROR(CheckInputs(context_));
-
- // This flag will be updated later when the scores output exists.
- parameters_->output_scores = false;
-
- if (!IsCuda()) {
- // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead.
- // Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready.
- logits_processors_.Init(*parameters_);
- }
-
- return Status::OK();
-}
-
-template
-Status BeamSearchImpl::CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, IAllocatorUniquePtr& buffer) {
- const OrtValue* input_ids_value = context_.GetInputOrtValue(0);
- const Tensor& input_ids = input_ids_value->Get();
- return gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, sequence_lengths, expanded_input_ids, feeds, create_inputs_func_, add_to_feeds_func_, buffer);
-}
-
-template
-Status BeamSearchImpl::ProcessLogits(
- const OrtValue& logits,
- BeamSearchState& beam_state,
- BeamSearchCpuState& cpu_state,
- AllocatorPtr& allocator,
- int counter) {
- return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator,
- thread_pool_, &logits_processors_, beam_scorer_.get(),
- parameters_, counter, cuda_stream_, GetConsoleDumper());
-}
-
-template
-Status BeamSearchImpl::GenerateNextToken(
- const OrtValue& logits,
- gsl::span& beam_next_tokens,
- gsl::span& beam_indices,
- BeamSearchState& beam_state,
- BeamSearchCpuState& cpu_state,
- int counter) {
- // Process logits to get next token scores
- ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter));
-
- gsl::span& beam_scores = beam_scorer_->GetNextScores();
- // It is optional to clone beam_scores. Change it to use same buffer also works for CPU:
- // beam_state.beam_scores = beam_scores
- // Here we make a copy to reduce the coupling with little cost (the buffer size is small).
- ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores, beam_scores, cuda_stream_, DeviceCopyDirection::hostToDevice));
-
- beam_next_tokens = beam_scorer_->GetNextTokens();
- beam_indices = beam_scorer_->GetNextIndices();
-
-#ifdef DEBUG_BEAM_SEARCH
- cpu_dumper_.Print("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
- cpu_dumper_.Print("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
- cpu_dumper_.Print("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
-#endif
-
- cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
-
-#ifdef DEBUG_BEAM_SEARCH
- cpu_state.sequences.PrintSequences(&cpu_dumper_);
-#endif
- return Status::OK();
-}
-
-template
-Status BeamSearchImpl::UpdateFeeds(
- const std::vector& last_outputs,
- std::vector& next_inputs,
- int current_length,
- OrtValue& position_ids,
- gsl::span beam_next_tokens,
- gsl::span beam_indices) {
- return update_feeds_func_(temp_space_allocator_, cuda_stream_, last_outputs, next_inputs, current_length, position_ids,
- beam_next_tokens, beam_indices, parameters_->num_beams, GetConsoleDumper());
-}
-
-template
-Status BeamSearchImpl::Execute(const FeedsFetchesManager& feeds_fetches_manager) {
- auto status = Status::OK();
- int64_t sequences_dims[] = {parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length};
- TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
- Tensor* output_sequences = context_.Output(0, sequences_shape);
-
- int64_t sequences_scores_dims[] = {parameters_->batch_size, parameters_->num_return_sequences};
- TensorShape sequences_scores_shape(&sequences_scores_dims[0], sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0]));
- Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape);
-
- int64_t scores_dims[] = {
- static_cast(parameters_->max_length) - static_cast(parameters_->sequence_length),
- parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size};
- TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
- Tensor* output_scores = context_.Output(2, scores_shape);
-
- // Update the flag to indicate whether scores exists in output
- parameters_->output_scores = (output_scores != nullptr);
-
- std::vector feeds;
- // TODO: allocate fetches. use ping-pong buffers for past state.
- std::vector fetches;
-
- // Initialize resources
- onnxruntime::OrtStlAllocator hypothesis_score_allocator(cpu_allocator_);
- onnxruntime::OrtStlAllocator beam_hyps_allocator(cpu_allocator_);
- beam_scorer_ = std::make_unique(static_cast(parameters_->batch_size),
- static_cast(parameters_->num_beams),
- static_cast(parameters_->max_length),
- parameters_->length_penalty,
- parameters_->early_stopping,
- static_cast(parameters_->num_return_sequences),
- parameters_->pad_token_id,
- parameters_->eos_token_id,
- hypothesis_score_allocator,
- beam_hyps_allocator);
- beam_scorer_->Initialize(cpu_allocator_, parameters_->sequence_length);
-
- BeamSearchCpuState cpu_state;
- cpu_state.Init(cpu_allocator_, static_cast(parameters_->BatchBeamSize()), parameters_->max_length, IsCuda());
-
- // buffer in GPU for input_ids, position_ids and attention_mask
- // size_t buffer_bytes = SafeInt(sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * parameters_->batch_size * parameters_->num_beams * parameters_->sequence_length;
- // IAllocatorUniquePtr buffer = gpt_subgraph_.GetProvider()->GetScratchBuffer(buffer_bytes);
- IAllocatorUniquePtr buffer;
- OrtValue expanded_input_ids_in_cpu;
- ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
-
- BeamSearchState beam_state;
- beam_state.Init(temp_space_allocator_,
- parameters_->batch_size,
- parameters_->num_beams,
- parameters_->vocab_size,
- parameters_->sequence_length,
- parameters_->max_length,
- parameters_->output_scores);
-
- cpu_state.sequences.Init(cpu_state.sequences_space,
- parameters_->BatchBeamSize(),
- parameters_->sequence_length,
- parameters_->max_length);
-
- gsl::span input_ids = expanded_input_ids_in_cpu.Get().DataAsSpan();
- init_beam_state_func_(&beam_state,
- &cpu_state,
- cpu_state.sequence_lengths,
- parameters_->batch_size,
- parameters_->num_beams,
- input_ids,
- parameters_->sequence_length,
- parameters_->max_length,
- cuda_stream_);
-
-#ifdef DEBUG_BEAM_SEARCH
- const IConsoleDumper* dumper = GetConsoleDumper();
- dumper->Print("input_ids", feeds[0]);
- dumper->Print("position_ids", feeds[1]);
- dumper->Print("attention_mask", feeds[2]);
-#endif
-
- // position ids for all iterations except the first. It uses memory buffer owned by next_positions.
- OrtValue position_ids;
- int64_t dims[] = {parameters_->BatchBeamSize(), 1};
- TensorShape shape(&dims[0], 2);
- Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, beam_state.next_positions.data(), temp_space_allocator_->Info(), position_ids);
-
- int current_length = parameters_->sequence_length;
- int iteration_counter = 0;
- while (current_length < parameters_->max_length) {
- iteration_counter++;
-#ifdef DEBUG_BEAM_SEARCH
- auto cur_len = std::to_string(current_length);
- dumper->Print("***CurrentLength", cur_len, true);
-#endif
-
- status = utils::ExecuteSubgraph(session_state_, feeds_fetches_manager, feeds, fetches, {},
- ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger());
-
- ORT_RETURN_IF_ERROR(status);
-
- const OrtValue& logits = fetches[0];
- gsl::span beam_next_tokens;
- gsl::span beam_indices;
- ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, cpu_state, iteration_counter));
-
- // When all batches are finished, stop earlier to avoid wasting computation.
- if (beam_scorer_->IsDone()) {
- break;
- }
-
- // Increase sequence length after a new token is generated.
- ++current_length;
-
- // Prepare inputs for next round of subgraph call.
- if (current_length < parameters_->max_length) {
- ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
- position_ids,
- beam_next_tokens.as_span(),
- beam_indices.as_span()));
- }
- fetches.clear();
- }
-
- gsl::span final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
- if (IsCuda()) {
- ORT_RETURN_IF_ERROR(device_copy_func_(cpu_state.final_beam_scores, final_beam_scores, nullptr, DeviceCopyDirection::deviceToHost));
- final_beam_scores = gsl::make_span(cpu_state.final_beam_scores.data(), cpu_state.final_beam_scores.size());
- }
-
- beam_scorer_->Finalize(&(cpu_state.sequences),
- final_beam_scores,
- output_sequences,
- output_sequences_scores);
-
- // Output per token scores
- if (output_scores != nullptr) {
- gsl::span target = output_scores->MutableDataAsSpan();
- gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size());
- assert(target.length() == source.length());
- ORT_RETURN_IF_ERROR(device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
- }
-
- return status;
-}
-
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
index 217dc080cc..3a0de82010 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h
@@ -2,12 +2,16 @@
// Licensed under the MIT License.
#pragma once
+
+#include
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"
-#include "beam_search_parameters.h"
-#include "gpt_subgraph.h"
-#include "beam_search_device_helper.h"
+#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
+#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
+#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
+#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
+#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
namespace onnxruntime {
class FeedsFetchesManager;
@@ -20,7 +24,11 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
class BeamSearch : public IControlFlowKernel {
public:
BeamSearch(const OpKernelInfo& info)
- : IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
+ : IControlFlowKernel(info),
+ encoder_feeds_fetches_manager_(nullptr),
+ decoder_feeds_fetches_manager_(nullptr),
+ cuda_stream_(nullptr),
+ dumper_(nullptr) {
Init(info);
}
@@ -36,54 +44,76 @@ class BeamSearch : public IControlFlowKernel {
void SetComputeStream(void* stream) { cuda_stream_ = stream; }
void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; }
+ // device helpers that is same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
- // const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
- const BeamSearchDeviceHelper::TopkFunc& topk_func) {
- // create_inputs_func_ = create_inputs_func;
+ const BeamSearchDeviceHelper::TopkFunc& topk_func,
+ const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func,
+ const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func,
+ const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func,
+ const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_fp16_func,
+ const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func,
+ const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_fp16_func) {
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
- }
-
- // Type dependent helpers: float
- void SetDeviceHelpers(
- const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func,
- const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func,
- const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func,
- const BeamSearchDeviceHelper::UpdateFeedsFunc& update_feeds_func) {
- process_logits_func_ = process_logits_func;
- init_beam_state_func_ = init_beam_state_func;
device_copy_func_ = device_copy_func;
- update_feeds_func_ = update_feeds_func;
+ device_copy_int32_func_ = device_copy_int32_func;
+ process_logits_func_ = process_logits_func;
+ process_logits_fp16_func_ = process_logits_fp16_func;
+ init_beam_state_func_ = init_beam_state_func;
+ init_beam_state_fp16_func_ = init_beam_state_fp16_func;
}
- // Type dependent helpers: MLFloat16
- void SetDeviceHelpers(
- const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func,
- const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func,
- const BeamSearchDeviceHelper::UpdateFeedsFunc& update_feeds_func) {
- process_logits_fp16_func_ = process_logits_func;
- init_beam_state_fp16_func_ = init_beam_state_func;
- update_feeds_fp16_func_ = update_feeds_func;
+ void SetDeviceHelpers_Gpt(
+ const BeamSearchDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_func,
+ const BeamSearchDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_fp16_func) {
+ update_gpt_feeds_func_ = update_gpt_feeds_func;
+ update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
+ }
+
+ // device helpers for encoder-decoder model like T5
+ void SetDeviceHelpers_EncoderDecoder(
+ const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func,
+ const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_fp16_func) {
+ update_decoder_feeds_func_ = update_decoder_feeds_func;
+ update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
}
private:
// Device specific functions
- BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
BeamSearchDeviceHelper::TopkFunc topk_func_;
- BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_;
- BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_;
BeamSearchDeviceHelper::DeviceCopyFunc device_copy_func_;
- BeamSearchDeviceHelper::UpdateFeedsFunc update_feeds_func_;
+ BeamSearchDeviceHelper::DeviceCopyFunc device_copy_int32_func_;
+ BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_;
BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_fp16_func_;
- BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_fp16_func_;
- BeamSearchDeviceHelper::UpdateFeedsFunc update_feeds_fp16_func_;
+ BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_;
+ BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_fp16_func_;
+
+ //------------------------------------------------------------
+ // Device specific functions for GPT
+ //------------------------------------------------------------
+ BeamSearchDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_func_;
+ BeamSearchDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_fp16_func_;
+
+ //------------------------------------------------------------
+ // Device specific functions for encoder-decoder model like T5
+ //------------------------------------------------------------
+ BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
+
+ BeamSearchDeviceHelper::UpdateDecoderFeedsFunc update_decoder_feeds_func_;
+ BeamSearchDeviceHelper::UpdateDecoderFeedsFunc update_decoder_feeds_fp16_func_;
+
+ //------------------------------------------------------------
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
+ //------------------------------------------------------------
std::unique_ptr gpt_subgraph_;
- FeedsFetchesManager* feeds_fetches_manager_;
+ std::unique_ptr t5_encoder_subgraph_;
+ std::unique_ptr t5_decoder_subgraph_;
+ FeedsFetchesManager* encoder_feeds_fetches_manager_;
+ FeedsFetchesManager* decoder_feeds_fetches_manager_;
void* cuda_stream_;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
index 2397623a34..55fefe7460 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc
@@ -1,10 +1,17 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include
#include "core/providers/cpu/math/top_k.h"
#include "core/providers/cpu/math/softmax_shared.h"
#include "core/common/safeint.h"
#include "gsl/gsl"
-#include "sequences.h"
-#include "beam_search_scorer.h"
-#include "beam_search_device_helper.h"
+#include "contrib_ops/cpu/transformers/sequences.h"
+#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
+#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
+#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
+#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
namespace onnxruntime {
namespace contrib {
@@ -25,11 +32,10 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
input->DataType(), " is not supported yet");
}
-OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator) {
- // Input shape (batch_size, sequence_length)
+template
+void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
+ // Input shape (batch_size, sequence_length). The input is required with data type T.
// Output shape (batch_size * num_beams, sequence_length)
- if (num_beams == 1)
- return input;
const TensorShape& input_shape = input.Get().Shape();
const int64_t& batch_size = input_shape[0];
@@ -38,31 +44,28 @@ OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocat
int64_t dims[] = {batch_size * num_beams, sequence_length};
TensorShape expanded_shape(&dims[0], 2);
- OrtValue expanded;
MLDataType element_type = input.Get().DataType();
- ORT_ENFORCE(element_type == DataTypeImpl::GetType(), "input_ids, position_ids and attention_mask is required to be int32 data type");
+ ORT_ENFORCE(element_type == DataTypeImpl::GetType());
Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
- const int32_t* input_data = input.Get().Data();
- int32_t* expanded_data = expanded.GetMutable()->MutableData();
- int32_t* target = expanded_data;
+ const T* input_data = input.Get().Data();
+ T* expanded_data = expanded.GetMutable()->MutableData();
+ T* target = expanded_data;
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < num_beams; j++) {
- memcpy(target, input_data + i * sequence_length, sizeof(int32_t) * sequence_length);
+ memcpy(target, input_data + i * sequence_length, sizeof(T) * sequence_length);
target += sequence_length;
}
}
-
- return expanded;
}
-Status CreateInputs(
+Status CreateGptInputs(
const Tensor* original_input_ids,
int num_beams,
int pad_token_id,
gsl::span& sequence_lengths,
- AllocatorPtr alloactor,
+ AllocatorPtr allocator,
OrtValue& expanded_input_ids,
OrtValue& expanded_position_ids,
OrtValue& expanded_attention_mask) {
@@ -74,21 +77,22 @@ Status CreateInputs(
// Allocate position_ids and attention_mask based on shape of input_ids
auto element_type = DataTypeImpl::GetType();
- const OrtMemoryInfo& location = alloactor->Info();
+ const OrtMemoryInfo& location = allocator->Info();
// Use original input_ids. This requires the input_ids for subgraph is also int32.
// Current shape is (batch_size, sequence_length)
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
OrtValue input_ids;
- Tensor::InitOrtValue(element_type, input_ids_shape, const_cast(original_input_ids)->MutableData(), location, input_ids);
+ Tensor::InitOrtValue(element_type, input_ids_shape,
+ const_cast(original_input_ids)->MutableData(), location, input_ids);
OrtValue position_ids;
- Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, position_ids);
+ Tensor::InitOrtValue(element_type, input_ids_shape, allocator, position_ids);
OrtValue attention_mask;
auto mask_type = DataTypeImpl::GetType();
- Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask);
+ Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, attention_mask);
// Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
@@ -115,43 +119,45 @@ Status CreateInputs(
}
}
- // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask
- // TODO: Try expand outputs after first subgraph call instead. That may get better performance, but more complex to implement.
- expanded_input_ids = ExpandInputs(input_ids, num_beams, alloactor);
- expanded_position_ids = ExpandInputs(position_ids, num_beams, alloactor);
- expanded_attention_mask = ExpandInputs(attention_mask, num_beams, alloactor);
+ // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
+ // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
+ ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids);
+ ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids);
+ ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask);
return Status::OK();
}
Status AddToFeeds(const IExecutionProvider* /*execution_provider*/,
- OrtValue& input_ids,
- OrtValue& position_ids,
- OrtValue& attention_mask,
+ std::initializer_list inputs,
std::vector& feeds,
IAllocatorUniquePtr& /*buffer*/) {
- feeds.push_back(input_ids);
- feeds.push_back(position_ids);
- feeds.push_back(attention_mask);
+ for (auto& input : inputs) {
+ if (input.IsAllocated()) {
+ feeds.push_back(input);
+ }
+ }
+
return Status::OK();
}
template
void InitBeamState(transformers::IBeamSearchState* beam_state,
- transformers::IBeamSearchCpuState* cpu_state,
gsl::span& sequence_lengths,
int batch_size,
int num_beams,
- gsl::span input_ids_in_cpu,
- int sequence_length,
- int max_length,
void* /*stream*/) {
memset(beam_state->beam_scores.data(), 0, beam_state->beam_scores.size_bytes());
memset(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes());
memset(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes());
memset(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes());
memset(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes());
- memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
+
+ // T5 does not need position, so next_positions is empty for T5.
+ if (!beam_state->next_positions.empty()) {
+ memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
+ gsl::copy(sequence_lengths, beam_state->next_positions);
+ }
// Initialize score of first beam of each group with 0 and the rest with -1e9.
// This ensures that the beams in the same group don't produce same tokens every time.
@@ -161,19 +167,6 @@ void InitBeamState(transformers::IBeamSearchState* beam_state,
beam_scores[SafeInt(i) * num_beams + j] = -1e9;
}
}
-
- gsl::copy(sequence_lengths, beam_state->next_positions);
-
- memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes());
-
- // Copy input_ids to sequences[0].
- gsl::span sequences_0 = cpu_state->sequences_space;
- int batch_beam_size = batch_size * num_beams;
- for (int i = 0; i < batch_beam_size; i++) {
- for (int j = 0; j < sequence_length; j++) {
- sequences_0[SafeInt(i) * max_length + j] = static_cast(input_ids_in_cpu[SafeInt(i) * sequence_length + j]);
- }
- }
}
template
@@ -216,7 +209,8 @@ Status ProcessLogits(const OrtValue& logits, //
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
for (int i = 0; i < batch_beam_size; i++) {
gsl::span source(current_logits, vocab_size);
- gsl::span target = next_token_logits.subspan(SafeInt(i) * vocab_size, static_cast(vocab_size));
+ gsl::span target = next_token_logits.subspan(SafeInt(i) * vocab_size,
+ static_cast(vocab_size));
gsl::copy(source, target);
current_logits += input_length * vocab_size;
}
@@ -224,7 +218,9 @@ Status ProcessLogits(const OrtValue& logits, //
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("logits", logits);
- dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+ if (input_length > 1) {
+ dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
+ }
#endif
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
@@ -244,12 +240,12 @@ Status ProcessLogits(const OrtValue& logits, //
logits_processors->Process(sequences, next_token_scores, step);
#ifdef DEBUG_BEAM_SEARCH
- dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size);
+ dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size);
#endif
// Add beam score to next token scores. Corresponding python code is like:
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
- // TODO: use thread pool to parrellel
+ // TODO(tianleiwu): use thread pool to parallel
int offset = 0;
int batch_beam_index = 0;
for (int i = 0; i < batch_size; i++) {
@@ -261,7 +257,7 @@ Status ProcessLogits(const OrtValue& logits, //
}
#ifdef DEBUG_BEAM_SEARCH
- dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
+ dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
#endif
if (output_scores) {
@@ -277,7 +273,8 @@ Status ProcessLogits(const OrtValue& logits, //
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
auto element_type = DataTypeImpl::GetType();
OrtValue next_token_scores_value;
- Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
+ Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(),
+ next_token_scores_value);
const Tensor& input = next_token_scores_value.Get();
constexpr int axis = 1;
@@ -287,7 +284,8 @@ Status ProcessLogits(const OrtValue& logits, //
std::unique_ptr topk_scores;
std::unique_ptr topk_indices;
- ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices));
+ ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
+ topk_scores, topk_indices));
#ifdef DEBUG_BEAM_SEARCH
dumper->Print("topk_scores", *(topk_scores.get()));
@@ -331,32 +329,34 @@ Status DeviceCopy(gsl::span target, gsl::span source, void* /*stream
return Status::OK();
}
+// Copy present state to past state for GPT model
template
-void PickPastState(const std::vector& last_outputs,
- std::vector& next_inputs,
- gsl::span& beam_indices,
- AllocatorPtr allocator,
- void* /*stream*/) {
+void PickGptPastState(const std::vector& last_outputs,
+ std::vector& next_inputs,
+ gsl::span& beam_indices,
+ AllocatorPtr allocator) {
+ int num_present_tensors = static_cast(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex;
+ for (int i = 0; i < num_present_tensors; ++i) {
+ const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i];
- for (size_t i = 1; i < last_outputs.size(); ++i) {
- const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
+ // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
const TensorShape& past_shape = present.Get().Shape();
+ auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
+ auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
// Create a tensor with same shape.
- // TODO: allocate one buffer for all layers
+ // TODO(tianleiwu): allocate one buffer for all layers
OrtValue past;
auto past_type = DataTypeImpl::GetType();
Tensor::InitOrtValue(past_type, past_shape, allocator, past);
- auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
- auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
-
gsl::span past_span = gsl::make_span(past.GetMutable()->MutableData(), past_shape.Size());
gsl::span present_span = gsl::make_span(present.Get().Data(), past_shape.Size());
for (gsl::index j = 0; j < beam_indices.length(); j++) {
int32_t beam_index = beam_indices[j];
gsl::span present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
- gsl::span