diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 18dc84a8d1..a6f3f845d6 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -355,10 +355,12 @@ This version of the operator has been available since version 1 of the 'com.micr
decoder : graph (required)
Decoder subgraph to execute in a loop.
+
decoder_start_token_id : int
+
The id of the token that indicates decoding starts.
early_stopping : int
early stop or not
-
encoder_decoder_init : graph
-
subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+
encoder : graph
+
The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
eos_token_id : int (required)
The id of the end-of-sequence token
model_type : int
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index d82f1a7094..8683317834 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -9,6 +9,7 @@ #pragma warning(disable : 4996) #endif +#include #include #include #include "core/common/safeint.h" @@ -26,11 +27,13 @@ #include "core/framework/allocator.h" #include "core/framework/ort_value.h" #include "gsl/gsl" -#include "beam_search.h" -#include "logits_processor.h" -#include "sequences.h" -#include "dump_tensor.h" -#include "beam_search_scorer.h" +#include "contrib_ops/cpu/transformers/beam_search.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/sequences.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" +#include "contrib_ops/cpu/transformers/beam_search_scorer.h" +#include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h" +#include "contrib_ops/cpu/transformers/beam_search_impl_t5.h" using namespace ONNX_NAMESPACE; using namespace onnxruntime::common; @@ -53,590 +56,156 @@ REGISTER_KERNEL_TYPED(float) namespace transformers { -template -gsl::span AllocateBuffer(AllocatorPtr allocator, - BufferUniquePtr& buffer, - size_t elements, - bool fill = false, - T fill_value = T{}) { - size_t bytes = SafeInt(sizeof(T)) * elements; - void* data = allocator->Alloc(bytes); - BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); - buffer = std::move(temp_buffer); - T* first = reinterpret_cast(buffer.get()); - auto span = gsl::make_span(first, elements); - - if (fill) { - std::fill_n(first, elements, fill_value); - } - - return span; -} - -template -struct BeamSearchState : public IBeamSearchState { - void Init(AllocatorPtr allocator, - int batch_size, - int num_beams, - int vocab_size, - int sequence_length, - int max_length, - bool output_scores) { - size_t batch_beam_size = SafeInt(batch_size) * num_beams; - - size_t next_token_size = SafeInt(batch_beam_size) * vocab_size; - this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size); - this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); - - this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size); - - this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size); - - this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size); - - this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size); - - if (output_scores) { - size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size; - this->scores = AllocateBuffer(allocator, scores_buffer_, elements); - this->remaining_scores = this->scores; - } - } - - private: - BufferUniquePtr next_token_logits_buffer_; - BufferUniquePtr next_token_scores_buffer_; - BufferUniquePtr next_tokens_buffer_; - BufferUniquePtr next_indices_buffer_; - BufferUniquePtr next_positions_buffer_; - BufferUniquePtr beam_scores_buffer_; - BufferUniquePtr scores_buffer_; -}; - -struct BeamSearchCpuState : public IBeamSearchCpuState { - Sequences sequences; - - void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, bool is_cuda) { - this->sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size); - this->sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, SafeInt(2) * batch_beam_size * max_length); - - if (is_cuda) { - // buffers used by CUDA operator but not by CPU operator. - this->topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * batch_beam_size); - this->topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * batch_beam_size); - this->topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * batch_beam_size); - this->final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size); - } - } - - private: - BufferUniquePtr final_beam_scores_buffer_; - BufferUniquePtr sequence_lengths_buffer_; - BufferUniquePtr topk_scores_buffer_; - BufferUniquePtr topk_tokens_buffer_; - BufferUniquePtr topk_indices_buffer_; - BufferUniquePtr sequences_space_buffer_; -}; - -template -class BeamSearchImpl { - public: - BeamSearchImpl(OpKernelContextInternal& context, - const SessionState& session_state, - GptSubgraph& gpt_subgraph, - concurrency::ThreadPool* thread_pool, - void* cuda_stream, - IConsoleDumper* cuda_dumper, - BeamSearchParameters& params, - const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func, - const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, - const BeamSearchDeviceHelper::TopkFunc& topk_func, - const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, - const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func, - const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func, - const BeamSearchDeviceHelper::UpdateFeedsFunc& update_feeds_func) - : context_(context), - session_state_(session_state), - gpt_subgraph_(gpt_subgraph), - thread_pool_(thread_pool), - implicit_inputs_(context_.GetImplicitInputs()), - cuda_stream_(cuda_stream), - cuda_dumper_(cuda_dumper), - parameters_(¶ms), - cpu_allocator_(nullptr), - temp_space_allocator_(nullptr), - create_inputs_func_(create_inputs_func), - add_to_feeds_func_(add_to_feeds_func), - topk_func_(topk_func), - process_logits_func_(process_logits_func), - init_beam_state_func_(init_beam_state_func), - device_copy_func_(device_copy_func), - update_feeds_func_(update_feeds_func) { - parameters_->ParseFromInputs(&context); - - cpu_allocator_ = session_state.GetExecutionProviders() - .Get(onnxruntime::kCpuExecutionProvider) - ->GetAllocator(0, OrtMemTypeDefault); - } - - // Initialize by validating all the inputs, and allocating the output tensors. - Status Initialize(); - - // Execute beam search in iterations util stopping criteria is reached. - // In each iteration, GPT subgraph is called, and next token for each sequence is generated. - Status Execute(const FeedsFetchesManager& feeds_fetches_manager); - - private: - bool IsCuda() const { return cuda_stream_ != nullptr; } - - // Validate inputs. - Status CheckInputs(const OpKernelContextInternal& context); - - // Prepare the inputs for first inference of subgraph - Status CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, IAllocatorUniquePtr& buffer); - - // Update the input for next iteration. - Status UpdateFeeds( - const std::vector& last_outputs, - std::vector& next_inputs, - int current_length, - OrtValue& position_ids, - gsl::span beam_next_tokens, - gsl::span beam_indices); - - // Process logits and append next tokens to sequences. - Status GenerateNextToken(const OrtValue& logits, - gsl::span& beam_next_tokens, - gsl::span& beam_indices, - BeamSearchState& beam_state, - BeamSearchCpuState& cpu_state, - int counter); - - // Calculate scores from logits, then apply filtering and select next token for each beam. - Status ProcessLogits(const OrtValue& logits, // logits output of subgraph - BeamSearchState& beam_state, - BeamSearchCpuState& cpu_state, - AllocatorPtr& allocator, - int counter); - - const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); } - - OpKernelContextInternal& context_; - - const SessionState& session_state_; - - GptSubgraph& gpt_subgraph_; - - concurrency::ThreadPool* thread_pool_; - - const std::vector& implicit_inputs_; - - void* cuda_stream_; - - IConsoleDumper* cuda_dumper_; - CpuTensorConsoleDumper cpu_dumper_; - - BeamSearchParameters* parameters_; - - LogitsProcessorList logits_processors_; - - std::unique_ptr beam_scorer_; - - AllocatorPtr cpu_allocator_; - AllocatorPtr temp_space_allocator_; - - // Device specific functions - BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_; - BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_; - BeamSearchDeviceHelper::TopkFunc topk_func_; - BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_; - BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_; - BeamSearchDeviceHelper::DeviceCopyFunc device_copy_func_; - BeamSearchDeviceHelper::UpdateFeedsFunc update_feeds_func_; -}; - void BeamSearch::Init(const OpKernelInfo& info) { - // Make sure the decoder attribute was present even though we don't need it here. + parameters_.ParseFromAttributes(info); + + // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5). + ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt || + parameters_.model_type == IBeamSearchParameters::kModelTypeT5); + ONNX_NAMESPACE::GraphProto proto; + if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) { + ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); + } + + // Make sure the decoder attribute was present even though we don't need it here. ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK()); ORT_IGNORE_RETURN_VALUE(proto); - - parameters_.ParseFromAttributes(info); } Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { - ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); - // TODO: handle another subgraph with attribute name "encoder_decode_init" - if (attribute_name == "decoder") { - const auto& node = Node(); - gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); - ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); - feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); - parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, - gpt_subgraph_->num_heads, - gpt_subgraph_->head_size, - gpt_subgraph_->num_layers); + const auto& node = Node(); + if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (attribute_name == "decoder") { + ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + gpt_subgraph_ = std::make_unique(node, attribute_name, subgraph_session_state.GetGraphViewer()); + ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state)); + decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); + parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size, + gpt_subgraph_->num_heads, + gpt_subgraph_->head_size, + gpt_subgraph_->num_layers); + } + } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { + if (attribute_name == "encoder") { + ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, + "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + t5_encoder_subgraph_ = std::make_unique(node, + attribute_name, + subgraph_session_state.GetGraphViewer()); + ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state)); + encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager(); + + if (parameters_.decoder_start_token_id < 0) { + ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2, + "Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty"); + } else { + ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3, + "Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available"); + } + } else if (attribute_name == "decoder") { + ORT_ENFORCE(t5_decoder_subgraph_ == nullptr, + "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + t5_decoder_subgraph_ = std::make_unique(node, + attribute_name, + subgraph_session_state.GetGraphViewer()); + ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state)); + decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager(); + parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size, + t5_decoder_subgraph_->num_heads, + t5_decoder_subgraph_->head_size, + t5_decoder_subgraph_->num_layers); + } } + return Status::OK(); } Status BeamSearch::Compute(OpKernelContext* ctx) const { - if (parameters_.model_type != 0) { - // TODO: support encoder decoder model like T5 - return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented"); - } - auto* ctx_internal = static_cast(ctx); - auto* session_state = ctx_internal->SubgraphSessionState("decoder"); - ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute."); - ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + + auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder"); + ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); + ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); - BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later + // Make a copy of parameters since we will update it based on inputs later + BeamSearchParameters parameters = parameters_; + + if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32 + BeamSearchGpt impl{ + *ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, + BeamSearchCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK, + process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits, + init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState, + device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy, + device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy, + update_gpt_feeds_func_ ? update_gpt_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateGptFeeds}; + ORT_RETURN_IF_ERROR(impl.Initialize()); + + return impl.Execute(*decoder_feeds_fetches_manager_); + } else { // Output float16 + BeamSearchGpt impl{ + *ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, + BeamSearchCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK, + process_logits_fp16_func_, + init_beam_state_fp16_func_, + device_copy_func_, + device_copy_int32_func_, + update_gpt_feeds_fp16_func_}; + ORT_RETURN_IF_ERROR(impl.Initialize()); + + return impl.Execute(*decoder_feeds_fetches_manager_); + } + } + + auto* encoder_session_state = ctx_internal->SubgraphSessionState("encoder"); + ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute."); + ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); // Subgraph has constraint that the output is either float or float16 - if (!gpt_subgraph_->IsOutputFloat16()) { - BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, - create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs, - add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds, - topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK, - process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits, - init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState, - device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy, - update_feeds_func_ ? update_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateFeeds}; + if (!t5_decoder_subgraph_->IsOutputFloat16()) { + BeamSearchT5 impl{ + *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_, + *t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, + add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK, + process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits, + init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState, + device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy, + device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy, + create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs, + update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds}; ORT_RETURN_IF_ERROR(impl.Initialize()); - return impl.Execute(*feeds_fetches_manager_); + return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); } else { - BeamSearchImpl impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, - create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs, - add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds, - topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK, - process_logits_fp16_func_, - init_beam_state_fp16_func_, - device_copy_func_, - update_feeds_fp16_func_}; + BeamSearchT5 impl{ + *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_, + *t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters, + add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK, + process_logits_fp16_func_, + init_beam_state_fp16_func_, + device_copy_func_, + device_copy_int32_func_, + create_encoder_inputs_func_, + update_decoder_feeds_fp16_func_}; + ORT_RETURN_IF_ERROR(impl.Initialize()); - return impl.Execute(*feeds_fetches_manager_); + return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); } } -template -Status BeamSearchImpl::CheckInputs(const OpKernelContextInternal& context) { - // Input shapes: - // input_ids : (batch_size, sequence_length) - // vocab_mask : (vocab_size) or nullptr - - const Tensor* input_ids = context.Input(0); - const auto& dims = input_ids->Shape().GetDims(); - if (dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ", - dims.size()); - } - - const Tensor* vocab_mask = context.Input(8); - if (vocab_mask != nullptr) { // vocab_mask is optional - const auto& vocab_mask_dims = vocab_mask->Shape().GetDims(); - if (vocab_mask_dims.size() != 1) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ", - vocab_mask_dims.size()); - } - - // There is dependency on vocab_size parameter, which shall be set before calling this function. - if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ", - vocab_mask_dims[0]); - } - - // store vocab mask in parameters. - parameters_->vocab_mask = vocab_mask->DataAsSpan(); - } - - const Tensor* prefix_vocab_mask = context.Input(9); - if (prefix_vocab_mask != nullptr) { - // prefix_vocab_mask is optional - const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims(); - if (vocab_mask_dims.size() != 2) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ", - vocab_mask_dims.size()); - } - - // prefix_vocab_mask first dimension should be same as the first dimension of input_ids - if (static_cast(vocab_mask_dims[0]) != static_cast(dims[0])) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size"); - } - - // There is dependency on vocab_size parameter, which shall be set before calling this function. - if (static_cast(vocab_mask_dims[1]) != parameters_->vocab_size) { - return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ", - vocab_mask_dims[0]); - } - - // store prefix vocab mask in parameters. - parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan(); - } - - return Status::OK(); -} - -template -Status BeamSearchImpl::Initialize() { - ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_)); - -#define CHECK_SCALAR_INPUT(name, index, required) \ - auto* name##_tensor = context_.Input(index); \ - if (name##_tensor) { \ - if (!name##_tensor->Shape().IsScalar()) { \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \ - name##_tensor->Shape()); \ - } \ - } else if (required) { \ - return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \ - } - - CHECK_SCALAR_INPUT(min_length, 1, false); - - CHECK_SCALAR_INPUT(max_length, 2, true); - - CHECK_SCALAR_INPUT(num_beams, 3, true); - - CHECK_SCALAR_INPUT(num_return_sequences, 4, true); - - CHECK_SCALAR_INPUT(temperature, 5, true); - - CHECK_SCALAR_INPUT(length_penalty, 6, true); - - ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'."); - - ORT_RETURN_IF_ERROR(CheckInputs(context_)); - - // This flag will be updated later when the scores output exists. - parameters_->output_scores = false; - - if (!IsCuda()) { - // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead. - // Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready. - logits_processors_.Init(*parameters_); - } - - return Status::OK(); -} - -template -Status BeamSearchImpl::CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, IAllocatorUniquePtr& buffer) { - const OrtValue* input_ids_value = context_.GetInputOrtValue(0); - const Tensor& input_ids = input_ids_value->Get(); - return gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, sequence_lengths, expanded_input_ids, feeds, create_inputs_func_, add_to_feeds_func_, buffer); -} - -template -Status BeamSearchImpl::ProcessLogits( - const OrtValue& logits, - BeamSearchState& beam_state, - BeamSearchCpuState& cpu_state, - AllocatorPtr& allocator, - int counter) { - return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator, - thread_pool_, &logits_processors_, beam_scorer_.get(), - parameters_, counter, cuda_stream_, GetConsoleDumper()); -} - -template -Status BeamSearchImpl::GenerateNextToken( - const OrtValue& logits, - gsl::span& beam_next_tokens, - gsl::span& beam_indices, - BeamSearchState& beam_state, - BeamSearchCpuState& cpu_state, - int counter) { - // Process logits to get next token scores - ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter)); - - gsl::span& beam_scores = beam_scorer_->GetNextScores(); - // It is optional to clone beam_scores. Change it to use same buffer also works for CPU: - // beam_state.beam_scores = beam_scores - // Here we make a copy to reduce the coupling with little cost (the buffer size is small). - ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores, beam_scores, cuda_stream_, DeviceCopyDirection::hostToDevice)); - - beam_next_tokens = beam_scorer_->GetNextTokens(); - beam_indices = beam_scorer_->GetNextIndices(); - -#ifdef DEBUG_BEAM_SEARCH - cpu_dumper_.Print("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams); - cpu_dumper_.Print("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams); - cpu_dumper_.Print("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams); -#endif - - cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens); - -#ifdef DEBUG_BEAM_SEARCH - cpu_state.sequences.PrintSequences(&cpu_dumper_); -#endif - return Status::OK(); -} - -template -Status BeamSearchImpl::UpdateFeeds( - const std::vector& last_outputs, - std::vector& next_inputs, - int current_length, - OrtValue& position_ids, - gsl::span beam_next_tokens, - gsl::span beam_indices) { - return update_feeds_func_(temp_space_allocator_, cuda_stream_, last_outputs, next_inputs, current_length, position_ids, - beam_next_tokens, beam_indices, parameters_->num_beams, GetConsoleDumper()); -} - -template -Status BeamSearchImpl::Execute(const FeedsFetchesManager& feeds_fetches_manager) { - auto status = Status::OK(); - int64_t sequences_dims[] = {parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length}; - TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0])); - Tensor* output_sequences = context_.Output(0, sequences_shape); - - int64_t sequences_scores_dims[] = {parameters_->batch_size, parameters_->num_return_sequences}; - TensorShape sequences_scores_shape(&sequences_scores_dims[0], sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0])); - Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape); - - int64_t scores_dims[] = { - static_cast(parameters_->max_length) - static_cast(parameters_->sequence_length), - parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size}; - TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); - Tensor* output_scores = context_.Output(2, scores_shape); - - // Update the flag to indicate whether scores exists in output - parameters_->output_scores = (output_scores != nullptr); - - std::vector feeds; - // TODO: allocate fetches. use ping-pong buffers for past state. - std::vector fetches; - - // Initialize resources - onnxruntime::OrtStlAllocator hypothesis_score_allocator(cpu_allocator_); - onnxruntime::OrtStlAllocator beam_hyps_allocator(cpu_allocator_); - beam_scorer_ = std::make_unique(static_cast(parameters_->batch_size), - static_cast(parameters_->num_beams), - static_cast(parameters_->max_length), - parameters_->length_penalty, - parameters_->early_stopping, - static_cast(parameters_->num_return_sequences), - parameters_->pad_token_id, - parameters_->eos_token_id, - hypothesis_score_allocator, - beam_hyps_allocator); - beam_scorer_->Initialize(cpu_allocator_, parameters_->sequence_length); - - BeamSearchCpuState cpu_state; - cpu_state.Init(cpu_allocator_, static_cast(parameters_->BatchBeamSize()), parameters_->max_length, IsCuda()); - - // buffer in GPU for input_ids, position_ids and attention_mask - // size_t buffer_bytes = SafeInt(sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * parameters_->batch_size * parameters_->num_beams * parameters_->sequence_length; - // IAllocatorUniquePtr buffer = gpt_subgraph_.GetProvider()->GetScratchBuffer(buffer_bytes); - IAllocatorUniquePtr buffer; - OrtValue expanded_input_ids_in_cpu; - ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer)); - - BeamSearchState beam_state; - beam_state.Init(temp_space_allocator_, - parameters_->batch_size, - parameters_->num_beams, - parameters_->vocab_size, - parameters_->sequence_length, - parameters_->max_length, - parameters_->output_scores); - - cpu_state.sequences.Init(cpu_state.sequences_space, - parameters_->BatchBeamSize(), - parameters_->sequence_length, - parameters_->max_length); - - gsl::span input_ids = expanded_input_ids_in_cpu.Get().DataAsSpan(); - init_beam_state_func_(&beam_state, - &cpu_state, - cpu_state.sequence_lengths, - parameters_->batch_size, - parameters_->num_beams, - input_ids, - parameters_->sequence_length, - parameters_->max_length, - cuda_stream_); - -#ifdef DEBUG_BEAM_SEARCH - const IConsoleDumper* dumper = GetConsoleDumper(); - dumper->Print("input_ids", feeds[0]); - dumper->Print("position_ids", feeds[1]); - dumper->Print("attention_mask", feeds[2]); -#endif - - // position ids for all iterations except the first. It uses memory buffer owned by next_positions. - OrtValue position_ids; - int64_t dims[] = {parameters_->BatchBeamSize(), 1}; - TensorShape shape(&dims[0], 2); - Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, beam_state.next_positions.data(), temp_space_allocator_->Info(), position_ids); - - int current_length = parameters_->sequence_length; - int iteration_counter = 0; - while (current_length < parameters_->max_length) { - iteration_counter++; -#ifdef DEBUG_BEAM_SEARCH - auto cur_len = std::to_string(current_length); - dumper->Print("***CurrentLength", cur_len, true); -#endif - - status = utils::ExecuteSubgraph(session_state_, feeds_fetches_manager, feeds, fetches, {}, - ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger()); - - ORT_RETURN_IF_ERROR(status); - - const OrtValue& logits = fetches[0]; - gsl::span beam_next_tokens; - gsl::span beam_indices; - ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, cpu_state, iteration_counter)); - - // When all batches are finished, stop earlier to avoid wasting computation. - if (beam_scorer_->IsDone()) { - break; - } - - // Increase sequence length after a new token is generated. - ++current_length; - - // Prepare inputs for next round of subgraph call. - if (current_length < parameters_->max_length) { - ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, - position_ids, - beam_next_tokens.as_span(), - beam_indices.as_span())); - } - fetches.clear(); - } - - gsl::span final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size()); - if (IsCuda()) { - ORT_RETURN_IF_ERROR(device_copy_func_(cpu_state.final_beam_scores, final_beam_scores, nullptr, DeviceCopyDirection::deviceToHost)); - final_beam_scores = gsl::make_span(cpu_state.final_beam_scores.data(), cpu_state.final_beam_scores.size()); - } - - beam_scorer_->Finalize(&(cpu_state.sequences), - final_beam_scores, - output_sequences, - output_sequences_scores); - - // Output per token scores - if (output_scores != nullptr) { - gsl::span target = output_scores->MutableDataAsSpan(); - gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size()); - assert(target.length() == source.length()); - ORT_RETURN_IF_ERROR(device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); - } - - return status; -} - } // namespace transformers } // namespace contrib } // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 217dc080cc..3a0de82010 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -2,12 +2,16 @@ // Licensed under the MIT License. #pragma once + +#include #include "core/common/common.h" #include "core/framework/op_kernel.h" #include "core/providers/cpu/controlflow/utils.h" -#include "beam_search_parameters.h" -#include "gpt_subgraph.h" -#include "beam_search_device_helper.h" +#include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" +#include "contrib_ops/cpu/transformers/beam_search_device_helper.h" namespace onnxruntime { class FeedsFetchesManager; @@ -20,7 +24,11 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel class BeamSearch : public IControlFlowKernel { public: BeamSearch(const OpKernelInfo& info) - : IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) { + : IControlFlowKernel(info), + encoder_feeds_fetches_manager_(nullptr), + decoder_feeds_fetches_manager_(nullptr), + cuda_stream_(nullptr), + dumper_(nullptr) { Init(info); } @@ -36,54 +44,76 @@ class BeamSearch : public IControlFlowKernel { void SetComputeStream(void* stream) { cuda_stream_ = stream; } void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; } + // device helpers that is same for both GPT and encoder-decoder models. void SetDeviceHelpers( - // const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func, const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, - const BeamSearchDeviceHelper::TopkFunc& topk_func) { - // create_inputs_func_ = create_inputs_func; + const BeamSearchDeviceHelper::TopkFunc& topk_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func, + const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, + const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_fp16_func, + const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func, + const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_fp16_func) { add_to_feeds_func_ = add_to_feeds_func; topk_func_ = topk_func; - } - - // Type dependent helpers: float - void SetDeviceHelpers( - const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, - const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func, - const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func, - const BeamSearchDeviceHelper::UpdateFeedsFunc& update_feeds_func) { - process_logits_func_ = process_logits_func; - init_beam_state_func_ = init_beam_state_func; device_copy_func_ = device_copy_func; - update_feeds_func_ = update_feeds_func; + device_copy_int32_func_ = device_copy_int32_func; + process_logits_func_ = process_logits_func; + process_logits_fp16_func_ = process_logits_fp16_func; + init_beam_state_func_ = init_beam_state_func; + init_beam_state_fp16_func_ = init_beam_state_fp16_func; } - // Type dependent helpers: MLFloat16 - void SetDeviceHelpers( - const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, - const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func, - const BeamSearchDeviceHelper::UpdateFeedsFunc& update_feeds_func) { - process_logits_fp16_func_ = process_logits_func; - init_beam_state_fp16_func_ = init_beam_state_func; - update_feeds_fp16_func_ = update_feeds_func; + void SetDeviceHelpers_Gpt( + const BeamSearchDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_func, + const BeamSearchDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_fp16_func) { + update_gpt_feeds_func_ = update_gpt_feeds_func; + update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; + } + + // device helpers for encoder-decoder model like T5 + void SetDeviceHelpers_EncoderDecoder( + const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func, + const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_fp16_func) { + update_decoder_feeds_func_ = update_decoder_feeds_func; + update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func; } private: // Device specific functions - BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_; BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_; BeamSearchDeviceHelper::TopkFunc topk_func_; - BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_; - BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_; BeamSearchDeviceHelper::DeviceCopyFunc device_copy_func_; - BeamSearchDeviceHelper::UpdateFeedsFunc update_feeds_func_; + BeamSearchDeviceHelper::DeviceCopyFunc device_copy_int32_func_; + BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_; BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_fp16_func_; - BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_fp16_func_; - BeamSearchDeviceHelper::UpdateFeedsFunc update_feeds_fp16_func_; + BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_; + BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_fp16_func_; + + //------------------------------------------------------------ + // Device specific functions for GPT + //------------------------------------------------------------ + BeamSearchDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_func_; + BeamSearchDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_fp16_func_; + + //------------------------------------------------------------ + // Device specific functions for encoder-decoder model like T5 + //------------------------------------------------------------ + BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_; + + BeamSearchDeviceHelper::UpdateDecoderFeedsFunc update_decoder_feeds_func_; + BeamSearchDeviceHelper::UpdateDecoderFeedsFunc update_decoder_feeds_fp16_func_; + + //------------------------------------------------------------ // Subgraph and FeedsFetchesManager re-used for each subgraph execution. + //------------------------------------------------------------ std::unique_ptr gpt_subgraph_; - FeedsFetchesManager* feeds_fetches_manager_; + std::unique_ptr t5_encoder_subgraph_; + std::unique_ptr t5_decoder_subgraph_; + FeedsFetchesManager* encoder_feeds_fetches_manager_; + FeedsFetchesManager* decoder_feeds_fetches_manager_; void* cuda_stream_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc index 2397623a34..55fefe7460 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.cc @@ -1,10 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include #include "core/providers/cpu/math/top_k.h" #include "core/providers/cpu/math/softmax_shared.h" #include "core/common/safeint.h" #include "gsl/gsl" -#include "sequences.h" -#include "beam_search_scorer.h" -#include "beam_search_device_helper.h" +#include "contrib_ops/cpu/transformers/sequences.h" +#include "contrib_ops/cpu/transformers/beam_search_scorer.h" +#include "contrib_ops/cpu/transformers/beam_search_device_helper.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" namespace onnxruntime { namespace contrib { @@ -25,11 +32,10 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, input->DataType(), " is not supported yet"); } -OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator) { - // Input shape (batch_size, sequence_length) +template +void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) { + // Input shape (batch_size, sequence_length). The input is required with data type T. // Output shape (batch_size * num_beams, sequence_length) - if (num_beams == 1) - return input; const TensorShape& input_shape = input.Get().Shape(); const int64_t& batch_size = input_shape[0]; @@ -38,31 +44,28 @@ OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocat int64_t dims[] = {batch_size * num_beams, sequence_length}; TensorShape expanded_shape(&dims[0], 2); - OrtValue expanded; MLDataType element_type = input.Get().DataType(); - ORT_ENFORCE(element_type == DataTypeImpl::GetType(), "input_ids, position_ids and attention_mask is required to be int32 data type"); + ORT_ENFORCE(element_type == DataTypeImpl::GetType()); Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded); - const int32_t* input_data = input.Get().Data(); - int32_t* expanded_data = expanded.GetMutable()->MutableData(); - int32_t* target = expanded_data; + const T* input_data = input.Get().Data(); + T* expanded_data = expanded.GetMutable()->MutableData(); + T* target = expanded_data; for (int i = 0; i < batch_size; i++) { for (int j = 0; j < num_beams; j++) { - memcpy(target, input_data + i * sequence_length, sizeof(int32_t) * sequence_length); + memcpy(target, input_data + i * sequence_length, sizeof(T) * sequence_length); target += sequence_length; } } - - return expanded; } -Status CreateInputs( +Status CreateGptInputs( const Tensor* original_input_ids, int num_beams, int pad_token_id, gsl::span& sequence_lengths, - AllocatorPtr alloactor, + AllocatorPtr allocator, OrtValue& expanded_input_ids, OrtValue& expanded_position_ids, OrtValue& expanded_attention_mask) { @@ -74,21 +77,22 @@ Status CreateInputs( // Allocate position_ids and attention_mask based on shape of input_ids auto element_type = DataTypeImpl::GetType(); - const OrtMemoryInfo& location = alloactor->Info(); + const OrtMemoryInfo& location = allocator->Info(); // Use original input_ids. This requires the input_ids for subgraph is also int32. // Current shape is (batch_size, sequence_length) // Note that we will expand it to (batch_size * num_beams, sequence_length) later. // To avoid cloning input_ids, we use const_cast here since this function does not change its content. OrtValue input_ids; - Tensor::InitOrtValue(element_type, input_ids_shape, const_cast(original_input_ids)->MutableData(), location, input_ids); + Tensor::InitOrtValue(element_type, input_ids_shape, + const_cast(original_input_ids)->MutableData(), location, input_ids); OrtValue position_ids; - Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, position_ids); + Tensor::InitOrtValue(element_type, input_ids_shape, allocator, position_ids); OrtValue attention_mask; auto mask_type = DataTypeImpl::GetType(); - Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask); + Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, attention_mask); // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. // Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens @@ -115,43 +119,45 @@ Status CreateInputs( } } - // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask - // TODO: Try expand outputs after first subgraph call instead. That may get better performance, but more complex to implement. - expanded_input_ids = ExpandInputs(input_ids, num_beams, alloactor); - expanded_position_ids = ExpandInputs(position_ids, num_beams, alloactor); - expanded_attention_mask = ExpandInputs(attention_mask, num_beams, alloactor); + // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) + // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance. + ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); + ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids); + ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask); return Status::OK(); } Status AddToFeeds(const IExecutionProvider* /*execution_provider*/, - OrtValue& input_ids, - OrtValue& position_ids, - OrtValue& attention_mask, + std::initializer_list inputs, std::vector& feeds, IAllocatorUniquePtr& /*buffer*/) { - feeds.push_back(input_ids); - feeds.push_back(position_ids); - feeds.push_back(attention_mask); + for (auto& input : inputs) { + if (input.IsAllocated()) { + feeds.push_back(input); + } + } + return Status::OK(); } template void InitBeamState(transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* /*stream*/) { memset(beam_state->beam_scores.data(), 0, beam_state->beam_scores.size_bytes()); memset(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes()); memset(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes()); memset(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes()); memset(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes()); - memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes()); + + // T5 does not need position, so next_positions is empty for T5. + if (!beam_state->next_positions.empty()) { + memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes()); + gsl::copy(sequence_lengths, beam_state->next_positions); + } // Initialize score of first beam of each group with 0 and the rest with -1e9. // This ensures that the beams in the same group don't produce same tokens every time. @@ -161,19 +167,6 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, beam_scores[SafeInt(i) * num_beams + j] = -1e9; } } - - gsl::copy(sequence_lengths, beam_state->next_positions); - - memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes()); - - // Copy input_ids to sequences[0]. - gsl::span sequences_0 = cpu_state->sequences_space; - int batch_beam_size = batch_size * num_beams; - for (int i = 0; i < batch_beam_size; i++) { - for (int j = 0; j < sequence_length; j++) { - sequences_0[SafeInt(i) * max_length + j] = static_cast(input_ids_in_cpu[SafeInt(i) * sequence_length + j]); - } - } } template @@ -216,7 +209,8 @@ Status ProcessLogits(const OrtValue& logits, // const T* current_logits = logits_data + (input_length - 1) * vocab_size; for (int i = 0; i < batch_beam_size; i++) { gsl::span source(current_logits, vocab_size); - gsl::span target = next_token_logits.subspan(SafeInt(i) * vocab_size, static_cast(vocab_size)); + gsl::span target = next_token_logits.subspan(SafeInt(i) * vocab_size, + static_cast(vocab_size)); gsl::copy(source, target); current_logits += input_length * vocab_size; } @@ -224,7 +218,9 @@ Status ProcessLogits(const OrtValue& logits, // #ifdef DEBUG_BEAM_SEARCH dumper->Print("logits", logits); - dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size); + if (input_length > 1) { + dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size); + } #endif // Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1) @@ -244,12 +240,12 @@ Status ProcessLogits(const OrtValue& logits, // logits_processors->Process(sequences, next_token_scores, step); #ifdef DEBUG_BEAM_SEARCH - dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size); + dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif // Add beam score to next token scores. Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) - // TODO: use thread pool to parrellel + // TODO(tianleiwu): use thread pool to parallel int offset = 0; int batch_beam_index = 0; for (int i = 0; i < batch_size; i++) { @@ -261,7 +257,7 @@ Status ProcessLogits(const OrtValue& logits, // } #ifdef DEBUG_BEAM_SEARCH - dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size); + dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif if (output_scores) { @@ -277,7 +273,8 @@ Status ProcessLogits(const OrtValue& logits, // TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); auto element_type = DataTypeImpl::GetType(); OrtValue next_token_scores_value; - Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value); + Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), + next_token_scores_value); const Tensor& input = next_token_scores_value.Get(); constexpr int axis = 1; @@ -287,7 +284,8 @@ Status ProcessLogits(const OrtValue& logits, // std::unique_ptr topk_scores; std::unique_ptr topk_indices; - ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices)); + ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, + topk_scores, topk_indices)); #ifdef DEBUG_BEAM_SEARCH dumper->Print("topk_scores", *(topk_scores.get())); @@ -331,32 +329,34 @@ Status DeviceCopy(gsl::span target, gsl::span source, void* /*stream return Status::OK(); } +// Copy present state to past state for GPT model template -void PickPastState(const std::vector& last_outputs, - std::vector& next_inputs, - gsl::span& beam_indices, - AllocatorPtr allocator, - void* /*stream*/) { +void PickGptPastState(const std::vector& last_outputs, + std::vector& next_inputs, + gsl::span& beam_indices, + AllocatorPtr allocator) { + int num_present_tensors = static_cast(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex; + for (int i = 0; i < num_present_tensors; ++i) { + const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i]; - for (size_t i = 1; i < last_outputs.size(); ++i) { - const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64) + // shape is like (2, batch_beam_size, 12, past_seq_len, 64) const TensorShape& past_shape = present.Get().Shape(); + auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; + auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; // Create a tensor with same shape. - // TODO: allocate one buffer for all layers + // TODO(tianleiwu): allocate one buffer for all layers OrtValue past; auto past_type = DataTypeImpl::GetType(); Tensor::InitOrtValue(past_type, past_shape, allocator, past); - auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; - auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; - gsl::span past_span = gsl::make_span(past.GetMutable()->MutableData(), past_shape.Size()); gsl::span present_span = gsl::make_span(present.Get().Data(), past_shape.Size()); for (gsl::index j = 0; j < beam_indices.length(); j++) { int32_t beam_index = beam_indices[j]; gsl::span present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); + gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, + block_size_per_beam); gsl::span past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); gsl::span past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); @@ -364,12 +364,12 @@ void PickPastState(const std::vector& last_outputs, gsl::copy(present_value, past_value); } - next_inputs[i + 2] = past; + next_inputs[transformers::GptSubgraph::kFirstPastInputIndex + i] = past; } } template -Status UpdateFeeds( +Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -382,6 +382,7 @@ Status UpdateFeeds( const transformers::IConsoleDumper* dumper) { // last_outputs: logits, present_0, present_1, ... // next_inputs: input_ids, position_id, attention_mask, past_0, past_1 + ORT_UNUSED_PARAMETER(stream); // The following updates inputs for subgraph @@ -391,7 +392,7 @@ Status UpdateFeeds( TensorShape input_ids_shape(&dims[0], 2); auto int32_type = DataTypeImpl::GetType(); OrtValue input_ids; - // TODO: Reuse buffer for input_ids to reduce memory allocation. + // TODO(tianleiwu): Reuse buffer for input_ids to reduce memory allocation. Tensor::InitOrtValue(int32_type, input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); for (int i = 0; i < batch_beam_size; i++) { @@ -433,25 +434,187 @@ Status UpdateFeeds( // Update past state if (num_beams == 1) { // feed present_* output to past_* inputs one by one - for (size_t i = 1; i < last_outputs.size(); ++i) { - next_inputs[i + 2] = last_outputs[i]; + const int k = transformers::GptSubgraph::kFirstPastInputIndex - transformers::GptSubgraph::kFirstPresentOutputIndex; + for (size_t i = transformers::GptSubgraph::kFirstPresentOutputIndex; i < last_outputs.size(); ++i) { + next_inputs[i + k] = last_outputs[i]; } } else { - PickPastState(last_outputs, next_inputs, beam_indices, allocator, stream); + PickGptPastState(last_outputs, next_inputs, beam_indices, allocator); } return Status::OK(); } +// --------------------------------------------------------------- +// The following functions are for encoder-decoder model like T5 +// --------------------------------------------------------------- +Status CreateEncoderInputs( + const Tensor* original_encoder_input_ids, + int num_beams, + int pad_token_id, + int start_token_id, + AllocatorPtr allocator, + OrtValue& expanded_encoder_input_ids, + OrtValue& expanded_encoder_attention_mask, + OrtValue& expanded_decoder_input_ids) { + const TensorShape& input_ids_shape = original_encoder_input_ids->Shape(); + ORT_ENFORCE(input_ids_shape.NumDimensions() == 2); + const int64_t& batch_size = input_ids_shape[0]; + const int64_t& sequence_length = input_ids_shape[1]; + + // Allocate attention_mask based on shape of input_ids + auto element_type = DataTypeImpl::GetType(); + + // Use original encoder_input_ids. This requires the input_ids for subgraph is also int32. + // Current shape is (batch_size, sequence_length) + // Note that we will expand it to (batch_size * num_beams, sequence_length) later. + // To avoid cloning input_ids, we use const_cast here since this function does not change its content. + OrtValue encoder_input_ids; + Tensor::InitOrtValue(element_type, + input_ids_shape, + const_cast(original_encoder_input_ids)->MutableData(), + allocator->Info(), + encoder_input_ids); + + OrtValue encoder_attention_mask; + auto mask_type = DataTypeImpl::GetType(); + Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, encoder_attention_mask); + + // Set attention mask to be 0 for pad tokens, and 1 for all other tokens. + int32_t* mask_data = encoder_attention_mask.GetMutable()->MutableData(); + const int32_t* word_id = original_encoder_input_ids->Data(); + int32_t* mask = mask_data; + for (int i = 0; i < batch_size; i++) { + int32_t abs_position = 0; + for (int j = 0; j < sequence_length; j++, word_id++, mask++) { + // T5Tokenizer might add one EOS pad token at the end. + // That EOS token shall have attention mask 1 even when EOS token is same as pad token. + // Here we only set attention mask to be 0 for left padding only, so as to be parity with huggingface. + if (*word_id == pad_token_id && abs_position == 0) { + *mask = 0; + } else { + *mask = 1; + abs_position++; + } + } + } + + // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) + // for encoder_input_ids and encoder_attention_mask + // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance. + ExpandInputs(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids); + ExpandInputs(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask); + + // decoder_input_ids is optional. + if (start_token_id >= 0) { + // Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID + int64_t dims[] = {batch_size * num_beams, 1}; + TensorShape decoder_input_ids_shape(&dims[0], 2); + Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids); + int32_t* data = expanded_decoder_input_ids.GetMutable()->MutableData(); + for (int i = 0; i < batch_size * num_beams; i++, data++) { + *data = start_token_id; + } + } + + return Status::OK(); +} + +// Copy present state to past state for T5 model +template +void PickT5PastState(const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span& beam_indices, + AllocatorPtr allocator) { + for (int i = 0; i < num_present_tensors; ++i) { + const OrtValue& present = last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i]; + + // shape is like (batch_beam_size, 12, past_seq_len, 64) + const TensorShape& past_shape = present.Get().Shape(); + auto block_size_per_beam = past_shape[1] * past_shape[2] * past_shape[3]; + + // Create a tensor with same shape. + // TODO(tianleiwu): allocate one buffer for all layers + OrtValue past; + Tensor::InitOrtValue(DataTypeImpl::GetType(), past_shape, allocator, past); + + gsl::span past_span = gsl::make_span(past.GetMutable()->MutableData(), past_shape.Size()); + gsl::span present_span = gsl::make_span(present.Get().Data(), past_shape.Size()); + for (gsl::index j = 0; j < beam_indices.length(); j++) { + int32_t beam_index = beam_indices[j]; + gsl::span present_beam = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); + gsl::span past_beam = past_span.subspan(j * block_size_per_beam, block_size_per_beam); + gsl::copy(present_beam, past_beam); + } + + next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = past; + } +} + +// Update decoder inputs given decoder outputs of last iteration. +template +Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper) { + ORT_UNUSED_PARAMETER(stream); + + // last_outputs: logits, present_key_self_0, present_value_self_0, ... + // next_inputs: input_ids, + // encoder_attention_mask, encoder_hidden_states, + // past_key_self_0, past_value_self_0, ... + // past_key_cross_0, past_value_cross_0, ... + // Only need copy beam next tokens to input_ids, and copy present_*_self_* to past_*_self_*, + + // Update input_ids with next tokens. + int batch_beam_size = static_cast(beam_next_tokens.length()); + int64_t dims[] = {batch_beam_size, 1}; + TensorShape input_ids_shape(&dims[0], 2); + + // TODO(tianleiwu): Reuse buffer for input_ids to reduce memory allocation. + OrtValue input_ids; + Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); + + gsl::copy(beam_next_tokens, input_ids.GetMutable()->MutableDataAsSpan()); + + next_inputs[0] = input_ids; + +#ifdef DEBUG_BEAM_SEARCH + dumper->Print("input_ids", input_ids); +#else + ORT_UNUSED_PARAMETER(dumper); +#endif + + // Update past state + ORT_ENFORCE(last_outputs.size() >= static_cast(1 + num_present_tensors)); + // TODO(tianleiwu): remove num_beams==1 once GreedySearch operator is available. + if (num_beams == 1) { + // feed present_* output to past_* inputs one by one + for (int i = 0; i < num_present_tensors; ++i) { + next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = + last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i]; + } + } else { + PickT5PastState(last_outputs, next_inputs, num_present_tensors, beam_indices, allocator); + } + return Status::OK(); +} + +//------------------------------------------------ // Explicit template instantiations of functions +//------------------------------------------------ + template void InitBeamState( transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream); template Status ProcessLogits( @@ -472,9 +635,15 @@ template Status DeviceCopy( gsl::span target, gsl::span source, void* stream, - int copyDirectionn); + int copyDirection); -template Status UpdateFeeds( +template Status DeviceCopy( + gsl::span target, + gsl::span source, + void* stream, + int copyDirection); + +template Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -486,6 +655,19 @@ template Status UpdateFeeds( int num_beams, const transformers::IConsoleDumper* dumper); +template Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper); + +template void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded); + } // namespace BeamSearchCpuDeviceHelper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h index 2434ed85df..b58a95d4d0 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_device_helper.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #ifndef SHARED_PROVIDER @@ -6,9 +9,10 @@ #include "core/framework/allocator.h" #endif +#include #include "gsl/gsl" -#include "logits_processor.h" -#include "beam_search_shared.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/beam_search_shared.h" namespace onnxruntime { class IExecutionProvider; @@ -31,40 +35,34 @@ namespace BeamSearchDeviceHelper { using TopkFunc = std::function& output_values, std::unique_ptr& output_indices)>; -// Create subgraph inputs: input_ids, position_ids and attention_mask -using CreateInputsFunc = std::function& sequence_lengths, - AllocatorPtr alloactor, + AllocatorPtr allocator, OrtValue& expanded_input_ids, OrtValue& expanded_position_ids, OrtValue& expanded_attention_mask)>; using AddToFeedsFunc = std::function inputs, std::vector& feeds, IAllocatorUniquePtr& buffer)>; template using InitBeamStateFunc = std::function* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream)>; template @@ -89,8 +87,9 @@ using DeviceCopyFunc = std::function; +// Update subgraph inputs given outputs of last iteration (for GPT-2). template -using UpdateFeedsFunc = std::function& last_outputs, @@ -102,6 +101,29 @@ using UpdateFeedsFunc = std::function; +// Create encoder inputs (for encoder-decoder model like T5). +using CreateEncoderInputsFunc = std::function; + +// Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5). +template +using UpdateDecoderFeedsFunc = std::function& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper)>; } // namespace BeamSearchDeviceHelper // These are CPU specific device helper implementations @@ -114,33 +136,17 @@ Status TopK( std::unique_ptr& output_values, std::unique_ptr& output_indices); -Status CreateInputs( - const Tensor* original_input_ids, - int num_beams, - int pad_token_id, - gsl::span& sequence_lengths, - AllocatorPtr alloactor, - OrtValue& expanded_input_ids, - OrtValue& expanded_position_ids, - OrtValue& expanded_attention_mask); - Status AddToFeeds( const IExecutionProvider* execution_provider, - OrtValue& input_ids, - OrtValue& position_ids, - OrtValue& attention_mask, + std::initializer_list inputs, std::vector& feeds, IAllocatorUniquePtr& buffer); template void InitBeamState(transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream); template @@ -163,8 +169,22 @@ Status DeviceCopy(gsl::span target, void* stream, int copyDirectionn); +// --------------------------------------------------------------- +// Functions for GPT model only +// --------------------------------------------------------------- + +Status CreateGptInputs( + const Tensor* original_input_ids, + int num_beams, + int pad_token_id, + gsl::span& sequence_lengths, + AllocatorPtr allocator, + OrtValue& expanded_input_ids, + OrtValue& expanded_position_ids, + OrtValue& expanded_attention_mask); + template -Status UpdateFeeds( +Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -176,6 +196,38 @@ Status UpdateFeeds( int num_beams, const transformers::IConsoleDumper* dumper); +// --------------------------------------------------------------- +// Functions for encoder-decoder model like T5 +// --------------------------------------------------------------- +Status CreateEncoderInputs( + const Tensor* original_encoder_input_ids, + int num_beams, + int pad_token_id, + int start_token_id, + AllocatorPtr allocator, + OrtValue& expanded_encoder_input_ids, + OrtValue& expanded_encoder_attention_mask, + OrtValue& expanded_decoder_input_ids); + +// Update decoder inputs given decoder outputs of last iteration. +template +Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper); + +// --------------------------------------------------------------- +// Utility Functions +// --------------------------------------------------------------- +template +void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded); + } // namespace BeamSearchCpuDeviceHelper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h new file mode 100644 index 0000000000..7c6e4739d9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -0,0 +1,364 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include + +namespace onnxruntime { +namespace contrib { + +namespace transformers { + +template +gsl::span AllocateBuffer(AllocatorPtr allocator, + BufferUniquePtr& buffer, + size_t elements, + bool fill = false, + T fill_value = T{}) { + size_t bytes = SafeInt(sizeof(T)) * elements; + void* data = allocator->Alloc(bytes); + BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); + buffer = std::move(temp_buffer); + T* first = reinterpret_cast(buffer.get()); + auto span = gsl::make_span(first, elements); + + if (fill) { + std::fill_n(first, elements, fill_value); + } + + return span; +} + +template +struct BeamSearchState : public IBeamSearchState { + void Init(AllocatorPtr allocator, + int batch_size, + int num_beams, + int vocab_size, + int sequence_length, + int max_length, + bool output_scores, + bool use_position) { + size_t batch_beam_size = SafeInt(batch_size) * num_beams; + + size_t next_token_size = SafeInt(batch_beam_size) * vocab_size; + this->next_token_logits = AllocateBuffer(allocator, next_token_logits_buffer_, next_token_size); + this->next_token_scores = AllocateBuffer(allocator, next_token_scores_buffer_, next_token_size); + + this->next_tokens = AllocateBuffer(allocator, next_tokens_buffer_, SafeInt(2) * batch_beam_size); + + this->next_indices = AllocateBuffer(allocator, next_indices_buffer_, SafeInt(2) * batch_beam_size); + + if (use_position) { + this->next_positions = AllocateBuffer(allocator, next_positions_buffer_, batch_beam_size); + } + + this->beam_scores = AllocateBuffer(allocator, beam_scores_buffer_, batch_beam_size); + + if (output_scores) { + size_t elements = SafeInt(max_length - sequence_length) * batch_size * num_beams * vocab_size; + this->scores = AllocateBuffer(allocator, scores_buffer_, elements); + this->remaining_scores = this->scores; + } + } + + private: + BufferUniquePtr next_token_logits_buffer_; + BufferUniquePtr next_token_scores_buffer_; + BufferUniquePtr next_tokens_buffer_; + BufferUniquePtr next_indices_buffer_; + BufferUniquePtr next_positions_buffer_; + BufferUniquePtr beam_scores_buffer_; + BufferUniquePtr scores_buffer_; +}; + +struct BeamSearchCpuState : public IBeamSearchCpuState { + Sequences sequences; + + void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, int sequence_length, bool is_cuda) { + this->sequence_lengths = AllocateBuffer(allocator, sequence_lengths_buffer_, batch_beam_size); + + size_t sequences_bytes = SafeInt(2) * batch_beam_size * max_length; + this->sequences_space = AllocateBuffer(allocator, sequences_space_buffer_, sequences_bytes); + memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes()); + + if (is_cuda) { + // buffers used by CUDA operator but not by CPU operator. + this->topk_scores = AllocateBuffer(allocator, topk_scores_buffer_, 2 * batch_beam_size); + this->topk_tokens = AllocateBuffer(allocator, topk_tokens_buffer_, 2 * batch_beam_size); + this->topk_indices = AllocateBuffer(allocator, topk_indices_buffer_, 2 * batch_beam_size); + this->final_beam_scores = AllocateBuffer(allocator, final_beam_scores_buffer_, batch_beam_size); + } + + this->sequences.Init(this->sequences_space, static_cast(batch_beam_size), sequence_length, max_length); + } + + // Copy input_ids to sequences[0] + void SetSequence(gsl::span input_ids_in_cpu, + size_t batch_beam_size, + int max_length, + int sequence_length) { + gsl::span sequences_0 = sequences_space; + for (size_t i = 0; i < batch_beam_size; i++) { + for (int j = 0; j < sequence_length; j++) { + const size_t index = SafeInt(i) * max_length + j; + sequences_0[index] = input_ids_in_cpu[SafeInt(i) * sequence_length + j]; + } + } + } + + private: + BufferUniquePtr final_beam_scores_buffer_; + BufferUniquePtr sequence_lengths_buffer_; + BufferUniquePtr topk_scores_buffer_; + BufferUniquePtr topk_tokens_buffer_; + BufferUniquePtr topk_indices_buffer_; + BufferUniquePtr sequences_space_buffer_; +}; + +// Base class of beam search implementation that is common for both GPT-2 and T5. +template +class BeamSearchBase { + public: + BeamSearchBase(OpKernelContextInternal& context, + const SessionState& decoder_session_state, + concurrency::ThreadPool* thread_pool, + void* cuda_stream, + IConsoleDumper* cuda_dumper, + BeamSearchParameters& params, + const BeamSearchDeviceHelper::TopkFunc& topk_func, + const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func) + : context_(context), + decoder_session_state_(decoder_session_state), + thread_pool_(thread_pool), + implicit_inputs_(context_.GetImplicitInputs()), + cuda_stream_(cuda_stream), + cuda_dumper_(cuda_dumper), + parameters_(¶ms), + cpu_allocator_(nullptr), + temp_space_allocator_(nullptr), + topk_func_(topk_func), + process_logits_func_(process_logits_func), + device_copy_func_(device_copy_func), + device_copy_int32_func_(device_copy_int32_func) { + parameters_->ParseFromInputs(&context); + + cpu_allocator_ = decoder_session_state.GetExecutionProviders() + .Get(onnxruntime::kCpuExecutionProvider) + ->GetAllocator(0, OrtMemTypeDefault); + } + + // Initialize by validating all the inputs, and allocating the output tensors. + Status Initialize(); + + // Validate inputs. + Status CheckInputs(const OpKernelContextInternal& context); + + protected: + // Process logits and append next tokens to sequences. + Status GenerateNextToken(const OrtValue& logits, + gsl::span& beam_next_tokens, + gsl::span& beam_indices, + BeamSearchState& beam_state, + BeamSearchCpuState& cpu_state, + int counter); + + // Calculate scores from logits, then apply filtering and select next token for each beam. + Status ProcessLogits(const OrtValue& logits, // logits output of subgraph + BeamSearchState& beam_state, + BeamSearchCpuState& cpu_state, + AllocatorPtr& allocator, + int counter); + + bool IsCuda() const { return cuda_stream_ != nullptr; } + + const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); } + + OpKernelContextInternal& context_; + + const SessionState& decoder_session_state_; + + concurrency::ThreadPool* thread_pool_; + + const std::vector& implicit_inputs_; + + void* cuda_stream_; + + IConsoleDumper* cuda_dumper_; + CpuTensorConsoleDumper cpu_dumper_; + + BeamSearchParameters* parameters_; + + LogitsProcessorList logits_processors_; + + std::unique_ptr beam_scorer_; + + AllocatorPtr cpu_allocator_; + AllocatorPtr temp_space_allocator_; + + // Device specific functions + BeamSearchDeviceHelper::TopkFunc topk_func_; + BeamSearchDeviceHelper::ProcessLogitsFunc process_logits_func_; + BeamSearchDeviceHelper::DeviceCopyFunc device_copy_func_; + BeamSearchDeviceHelper::DeviceCopyFunc device_copy_int32_func_; +}; + +template +Status BeamSearchBase::CheckInputs(const OpKernelContextInternal& context) { + // Input shapes: + // input_ids : (batch_size, sequence_length) + // vocab_mask : (vocab_size) or nullptr + + const Tensor* input_ids = context.Input(0); + const auto& dims = input_ids->Shape().GetDims(); + if (dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'input_ids' is expected to have 2 dimensions, got ", dims.size()); + } + + const Tensor* vocab_mask = context.Input(8); + if (vocab_mask != nullptr) { // vocab_mask is optional + const auto& vocab_mask_dims = vocab_mask->Shape().GetDims(); + if (vocab_mask_dims.size() != 1) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'vocab_mask' is expected to have 1 dimension, got ", vocab_mask_dims.size()); + } + + // There is dependency on vocab_size parameter, which shall be set before calling this function. + if (static_cast(vocab_mask_dims[0]) != parameters_->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'vocab_mask' shape does not match with vocab_size, got ", vocab_mask_dims[0]); + } + + // store vocab mask in parameters. + parameters_->vocab_mask = vocab_mask->DataAsSpan(); + } + + const Tensor* prefix_vocab_mask = context.Input(9); + if (prefix_vocab_mask != nullptr) { + // prefix_vocab_mask is optional + const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims(); + if (vocab_mask_dims.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'prefix_vocab_mask' is expected to be 2 dimensions, got ", vocab_mask_dims.size()); + } + + // prefix_vocab_mask first dimension should be same as the first dimension of input_ids + if (static_cast(vocab_mask_dims[0]) != static_cast(dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input_ids and prefix_vocab_mask must have the same batch_size"); + } + + // There is dependency on vocab_size parameter, which shall be set before calling this function. + if (static_cast(vocab_mask_dims[1]) != parameters_->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'prefix_vocab_mask' shape[1] shall be vocab_size, got ", vocab_mask_dims[1]); + } + + // store prefix vocab mask in parameters. + parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan(); + } + + return Status::OK(); +} + +template +Status BeamSearchBase::Initialize() { + ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_)); + +#define CHECK_SCALAR_INPUT(name, index, required) \ + auto* name##_tensor = context_.Input(index); \ + if (name##_tensor) { \ + if (!name##_tensor->Shape().IsScalar()) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \ + name##_tensor->Shape()); \ + } \ + } else if (required) { \ + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \ + } + + CHECK_SCALAR_INPUT(min_length, 1, false); + + CHECK_SCALAR_INPUT(max_length, 2, true); + + CHECK_SCALAR_INPUT(num_beams, 3, true); + + CHECK_SCALAR_INPUT(num_return_sequences, 4, true); + + CHECK_SCALAR_INPUT(temperature, 5, true); + + CHECK_SCALAR_INPUT(length_penalty, 6, true); + + ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, + "'num_return_sequences' has to be smaller or equal to 'num_beams'."); + + ORT_RETURN_IF_ERROR(CheckInputs(context_)); + + // This flag will be updated later when the scores output exists. + parameters_->output_scores = false; + + if (!IsCuda()) { + // Logits processor is used in CPU only. In CUDA, cuda kernels are used instead. + // Initialize processors after CheckInputs so that parameters_->vocab_mask is ready. + logits_processors_.Init(*parameters_); + } + + return Status::OK(); +} + +template +Status BeamSearchBase::ProcessLogits( + const OrtValue& logits, + BeamSearchState& beam_state, + BeamSearchCpuState& cpu_state, + AllocatorPtr& allocator, + int counter) { + return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator, + thread_pool_, &logits_processors_, beam_scorer_.get(), + parameters_, counter, cuda_stream_, GetConsoleDumper()); +} + +template +Status BeamSearchBase::GenerateNextToken( + const OrtValue& logits, + gsl::span& beam_next_tokens, + gsl::span& beam_indices, + BeamSearchState& beam_state, + BeamSearchCpuState& cpu_state, + int counter) { + // Process logits to get next token scores + ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter)); + + gsl::span& beam_scores = beam_scorer_->GetNextScores(); + // It is optional to clone beam_scores. Change it to use same buffer also works for CPU: + // beam_state.beam_scores = beam_scores + // Here we make a copy to reduce the coupling with little cost (the buffer size is small). + ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores, + beam_scores, + cuda_stream_, + DeviceCopyDirection::hostToDevice)); + + beam_next_tokens = beam_scorer_->GetNextTokens(); + beam_indices = beam_scorer_->GetNextIndices(); + +#ifdef DEBUG_BEAM_SEARCH + cpu_dumper_.Print("beam_scores from scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams); + cpu_dumper_.Print("beam_next_tokens", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams); + cpu_dumper_.Print("beam_indices", beam_indices.data(), parameters_->batch_size, parameters_->num_beams); +#endif + + cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens); + +#ifdef DEBUG_BEAM_SEARCH + cpu_state.sequences.PrintSequences(&cpu_dumper_); +#endif + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h new file mode 100644 index 0000000000..bd5d3fcc3a --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/beam_search_impl_base.h" + +namespace onnxruntime { +namespace contrib { + +namespace transformers { + +// Beam search implementation for GPT-2 model. +template +class BeamSearchGpt : public BeamSearchBase { + public: + BeamSearchGpt(OpKernelContextInternal& context, + const SessionState& decoder_session_state, + GptSubgraph& gpt_subgraph, + concurrency::ThreadPool* thread_pool, + void* cuda_stream, + IConsoleDumper* cuda_dumper, + BeamSearchParameters& params, + const BeamSearchDeviceHelper::CreateGptInputsFunc& create_inputs_func, + const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + const BeamSearchDeviceHelper::TopkFunc& topk_func, + const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, + const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func, + const BeamSearchDeviceHelper::UpdateGptFeedsFunc& update_feeds_func) + : BeamSearchBase(context, decoder_session_state, thread_pool, + cuda_stream, cuda_dumper, params, + topk_func, process_logits_func, device_copy_func, device_copy_int32_func), + gpt_subgraph_(gpt_subgraph), + create_inputs_func_(create_inputs_func), + add_to_feeds_func_(add_to_feeds_func), + init_beam_state_func_(init_beam_state_func), + update_feeds_func_(update_feeds_func) { + } + + // Execute beam search in iterations util stopping criteria is reached. + // In each iteration, GPT subgraph is called, and next token for each sequence is generated. + Status Execute(const FeedsFetchesManager& feeds_fetches_manager); + + private: + // Prepare the inputs for first inference of subgraph + Status CreateInitialFeeds(gsl::span& sequence_lengths, + OrtValue& expanded_input_ids, + std::vector& feeds, + IAllocatorUniquePtr& buffer); + + // Update the input for next iteration. + Status UpdateFeeds( + const std::vector& last_outputs, + std::vector& next_inputs, + int current_length, + OrtValue& position_ids, + gsl::span beam_next_tokens, + gsl::span beam_indices); + + GptSubgraph& gpt_subgraph_; + + // Device specific functions + BeamSearchDeviceHelper::CreateGptInputsFunc create_inputs_func_; + BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_; + BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_; + BeamSearchDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; +}; + +template +Status BeamSearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths, + OrtValue& expanded_input_ids, + std::vector& feeds, + IAllocatorUniquePtr& buffer) { + const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0); + const Tensor& input_ids = input_ids_value->Get(); + return gpt_subgraph_.CreateInitialFeeds(input_ids, + this->implicit_inputs_, + this->parameters_->num_beams, + this->parameters_->pad_token_id, + sequence_lengths, + expanded_input_ids, + feeds, + this->create_inputs_func_, + this->add_to_feeds_func_, + buffer); +} + +template +Status BeamSearchGpt::UpdateFeeds( + const std::vector& last_outputs, + std::vector& next_inputs, + int current_length, + OrtValue& position_ids, + gsl::span beam_next_tokens, + gsl::span beam_indices) { + return update_feeds_func_(this->temp_space_allocator_, + this->cuda_stream_, + last_outputs, + next_inputs, + current_length, + position_ids, + beam_next_tokens, + beam_indices, + this->parameters_->num_beams, + this->GetConsoleDumper()); +} + +template +Status BeamSearchGpt::Execute(const FeedsFetchesManager& feeds_fetches_manager) { + auto status = Status::OK(); + const BeamSearchParameters* parameters = this->parameters_; + int64_t sequences_dims[] = {parameters->batch_size, parameters->num_return_sequences, parameters->max_length}; + TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0])); + Tensor* output_sequences = this->context_.Output(0, sequences_shape); + + int64_t sequences_scores_dims[] = {parameters->batch_size, parameters->num_return_sequences}; + TensorShape sequences_scores_shape(&sequences_scores_dims[0], 2); + Tensor* output_sequences_scores = this->context_.Output(1, sequences_scores_shape); + + int64_t scores_dims[] = { + static_cast(parameters->max_length) - static_cast(parameters->sequence_length), + parameters->batch_size, parameters->num_beams, parameters->vocab_size}; + TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); + Tensor* output_scores = this->context_.Output(2, scores_shape); + + // Update the flag to indicate whether scores exists in output + this->parameters_->output_scores = (output_scores != nullptr); + + std::vector feeds; + // TODO(tianleiwu): allocate fetches. use ping-pong buffers for past state. + std::vector fetches; + + // Initialize resources + onnxruntime::OrtStlAllocator hypothesis_score_allocator(this->cpu_allocator_); + onnxruntime::OrtStlAllocator beam_hyps_allocator(this->cpu_allocator_); + this->beam_scorer_ = std::make_unique(static_cast(parameters->batch_size), + static_cast(parameters->num_beams), + static_cast(parameters->max_length), + parameters->length_penalty, + parameters->early_stopping, + static_cast(parameters->num_return_sequences), + parameters->pad_token_id, + parameters->eos_token_id, + hypothesis_score_allocator, + beam_hyps_allocator); + this->beam_scorer_->Initialize(this->cpu_allocator_, parameters->sequence_length); + + BeamSearchCpuState cpu_state; + cpu_state.Init(this->cpu_allocator_, + static_cast(parameters->BatchBeamSize()), + parameters->max_length, + parameters->sequence_length, + this->IsCuda()); + + // buffer in GPU for input_ids, position_ids and attention_mask + IAllocatorUniquePtr buffer; + OrtValue expanded_input_ids_in_cpu; + ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer)); + + BeamSearchState beam_state; + constexpr bool use_position = true; + beam_state.Init(this->temp_space_allocator_, + parameters->batch_size, + parameters->num_beams, + parameters->vocab_size, + parameters->sequence_length, + parameters->max_length, + parameters->output_scores, + use_position); + + init_beam_state_func_(&beam_state, + cpu_state.sequence_lengths, + parameters->batch_size, + parameters->num_beams, + this->cuda_stream_); + + gsl::span input_ids = expanded_input_ids_in_cpu.Get().DataAsSpan(); + cpu_state.SetSequence(input_ids, + static_cast(parameters->BatchBeamSize()), + parameters->max_length, + parameters->sequence_length); + +#ifdef DEBUG_BEAM_SEARCH + const IConsoleDumper* dumper = this->GetConsoleDumper(); + dumper->Print("input_ids", feeds[0]); + dumper->Print("position_ids", feeds[1]); + dumper->Print("attention_mask", feeds[2]); +#endif + + // Position ids for all iterations except the first. It uses memory buffer owned by next_positions. + OrtValue position_ids; + int64_t dims[] = {parameters->BatchBeamSize(), 1}; + TensorShape shape(&dims[0], 2); + Tensor::InitOrtValue(DataTypeImpl::GetType(), + shape, + beam_state.next_positions.data(), + this->temp_space_allocator_->Info(), + position_ids); + + int current_length = parameters->sequence_length; + int iteration_counter = 0; + while (current_length < parameters->max_length) { + iteration_counter++; +#ifdef DEBUG_BEAM_SEARCH + auto cur_len = std::to_string(current_length); + dumper->Print("***CurrentLength", cur_len, true); +#endif + + status = utils::ExecuteSubgraph(this->decoder_session_state_, + feeds_fetches_manager, + feeds, + fetches, + {}, + ExecutionMode::ORT_SEQUENTIAL, + this->context_.GetTerminateFlag(), + this->context_.Logger()); + + ORT_RETURN_IF_ERROR(status); + + const OrtValue& logits = fetches[0]; + gsl::span beam_next_tokens; + gsl::span beam_indices; + ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits, + beam_next_tokens, + beam_indices, + beam_state, + cpu_state, + iteration_counter)); + + // When all batches are finished, stop earlier to avoid wasting computation. + if (this->beam_scorer_->IsDone()) { + break; + } + + // Increase sequence length after a new token is generated. + ++current_length; + + // Prepare inputs for next round of subgraph call. + if (current_length < parameters->max_length) { + ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length, + position_ids, + beam_next_tokens.as_span(), + beam_indices.as_span())); + } + fetches.clear(); + } + + gsl::span final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size()); + if (this->IsCuda()) { + ORT_RETURN_IF_ERROR(this->device_copy_func_(cpu_state.final_beam_scores, + final_beam_scores, + nullptr, + DeviceCopyDirection::deviceToHost)); + final_beam_scores = gsl::make_span(cpu_state.final_beam_scores.data(), + cpu_state.final_beam_scores.size()); + } + + this->beam_scorer_->Finalize(&(cpu_state.sequences), + final_beam_scores, + output_sequences, + output_sequences_scores); + + // Output per token scores + if (output_scores != nullptr) { + gsl::span target = output_scores->MutableDataAsSpan(); + gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size()); + assert(target.length() == source.length()); + ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); + } + + return status; +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h new file mode 100644 index 0000000000..491defe314 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -0,0 +1,307 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/beam_search_shared.h" // for DEBUG_BEAM_SEARCH +#include "contrib_ops/cpu/transformers/beam_search_impl_base.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" + +namespace onnxruntime { +namespace contrib { + +namespace transformers { + +// Beam search implementation for T5 model. +template +class BeamSearchT5 : public BeamSearchBase { + public: + BeamSearchT5(OpKernelContextInternal& context, + const SessionState& encoder_session_state, + const SessionState& decoder_session_state, + T5EncoderSubgraph& encoder_subgraph, + T5DecoderSubgraph& decoder_subgraph, + concurrency::ThreadPool* thread_pool, + void* cuda_stream, + IConsoleDumper* cuda_dumper, + BeamSearchParameters& params, + const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + const BeamSearchDeviceHelper::TopkFunc& topk_func, + const BeamSearchDeviceHelper::ProcessLogitsFunc& process_logits_func, + const BeamSearchDeviceHelper::InitBeamStateFunc& init_beam_state_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_func, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func, + const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func, + const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc& update_decoder_feeds_func) + : BeamSearchBase(context, decoder_session_state, thread_pool, + cuda_stream, cuda_dumper, params, + topk_func, process_logits_func, device_copy_func, device_copy_int32_func), + encoder_session_state_(encoder_session_state), + encoder_subgraph_(encoder_subgraph), + decoder_subgraph_(decoder_subgraph), + add_to_feeds_func_(add_to_feeds_func), + init_beam_state_func_(init_beam_state_func), + create_encoder_inputs_func_(create_encoder_inputs_func), + update_decoder_feeds_func_(update_decoder_feeds_func) { + } + + // Execute beam search in iterations util stopping criteria is reached. + Status Execute(const FeedsFetchesManager& encoder_feeds_fetches_manager, + const FeedsFetchesManager& decoder_feeds_fetches_manager); + + private: + const SessionState& encoder_session_state_; + + T5EncoderSubgraph& encoder_subgraph_; + T5DecoderSubgraph& decoder_subgraph_; + + // Device specific functions + BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_; + BeamSearchDeviceHelper::InitBeamStateFunc init_beam_state_func_; + + BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_; + BeamSearchDeviceHelper::UpdateDecoderFeedsFunc update_decoder_feeds_func_; +}; + +template +Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches_manager, + const FeedsFetchesManager& decoder_feeds_fetches_manager) { + auto status = Status::OK(); + + const BeamSearchParameters* parameters = this->parameters_; + ORT_ENFORCE(parameters->sequence_length == 1); + + // Allocate output tensors. + int64_t sequences_dims[] = {parameters->batch_size, parameters->num_return_sequences, parameters->max_length}; + TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0])); + Tensor* output_sequences = this->context_.Output(0, sequences_shape); + + int64_t sequences_scores_dims[] = {parameters->batch_size, parameters->num_return_sequences}; + constexpr int64_t dims = sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0]); + TensorShape sequences_scores_shape(&sequences_scores_dims[0], dims); + Tensor* output_sequences_scores = this->context_.Output(1, sequences_scores_shape); + + int64_t scores_dims[] = { + static_cast(parameters->max_length) - static_cast(parameters->sequence_length), + parameters->batch_size, parameters->num_beams, parameters->vocab_size}; + TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0])); + Tensor* output_scores = this->context_.Output(2, scores_shape); + + // Update the flag to indicate whether scores exists in output + this->parameters_->output_scores = (output_scores != nullptr); + + // ------------------------------------ + // Call encoder subgraph. + // ------------------------------------ + std::vector encoder_feeds; + std::vector encoder_fetches; + + const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0); + const Tensor& encoder_input_ids = encoder_input_ids_value->Get(); + + BeamSearchCpuState cpu_state; + cpu_state.Init(this->cpu_allocator_, + static_cast(parameters->BatchBeamSize()), + parameters->max_length, + parameters->sequence_length, + this->IsCuda()); + + IAllocatorUniquePtr buffer; + OrtValue expanded_decoder_input_ids; // Tensor in CPU, and it will be used to initialize sequence in cpu_state + ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds( + encoder_input_ids, + this->implicit_inputs_, + parameters->num_beams, + parameters->pad_token_id, + parameters->decoder_start_token_id, + encoder_feeds, + this->create_encoder_inputs_func_, + this->add_to_feeds_func_, + buffer, + expanded_decoder_input_ids)); + + ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(this->encoder_session_state_, + encoder_feeds_fetches_manager, + encoder_feeds, + encoder_fetches, + {}, + ExecutionMode::ORT_SEQUENTIAL, + this->context_.GetTerminateFlag(), + this->context_.Logger())); + +#ifdef DEBUG_BEAM_SEARCH + const IConsoleDumper* dumper = this->GetConsoleDumper(); + for (size_t i = 0; i < encoder_feeds.size(); i++) { + dumper->Print("encoder_feeds", static_cast(i), true); + dumper->Print("", encoder_feeds[i]); + } + + for (int i = 0; i <= T5EncoderSubgraph::kFirstPresentOutputIndex; i++) { + dumper->Print("encoder_fetches", i, true); + dumper->Print("", encoder_fetches[i]); + } +#endif + + // ------------------------------------ + // Initialize resources + // ------------------------------------ + + // Copy expanded_decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam. + cpu_state.SetSequence(expanded_decoder_input_ids.Get().DataAsSpan(), + static_cast(parameters->BatchBeamSize()), + parameters->max_length, + parameters->sequence_length); + + onnxruntime::OrtStlAllocator hypothesis_score_allocator(this->cpu_allocator_); + onnxruntime::OrtStlAllocator beam_hyps_allocator(this->cpu_allocator_); + this->beam_scorer_ = std::make_unique(static_cast(parameters->batch_size), + static_cast(parameters->num_beams), + static_cast(parameters->max_length), + parameters->length_penalty, + parameters->early_stopping, + static_cast(parameters->num_return_sequences), + parameters->pad_token_id, + parameters->eos_token_id, + hypothesis_score_allocator, + beam_hyps_allocator); + this->beam_scorer_->Initialize(this->cpu_allocator_, parameters->sequence_length); + + BeamSearchState beam_state; + constexpr bool use_position = false; + beam_state.Init(this->temp_space_allocator_, + parameters->batch_size, + parameters->num_beams, + parameters->vocab_size, + parameters->sequence_length, + parameters->max_length, + parameters->output_scores, + use_position); + + init_beam_state_func_(&beam_state, + cpu_state.sequence_lengths, + parameters->batch_size, + parameters->num_beams, + this->cuda_stream_); + + // ------------------------------------------------------------------------------ + // Generate next token from logits output from encoder, and initialize decoder inputs. + // ------------------------------------------------------------------------------ + gsl::span beam_next_tokens; + gsl::span beam_indices; + + int iteration_counter = 0; + std::vector decoder_feeds; + int current_length = parameters->sequence_length; + if (current_length + 1 < parameters->max_length) { + ++iteration_counter; + ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0], + beam_next_tokens, + beam_indices, + beam_state, + cpu_state, + iteration_counter)); + ++current_length; // Increase sequence length after a new token is generated. + ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(beam_next_tokens.as_span(), + this->implicit_inputs_, + encoder_feeds, + encoder_fetches, + decoder_feeds, + this->device_copy_int32_func_, + this->cuda_stream_)); + } + + // TODO(tianleiwu): allocate fetches. use ping-pong buffers for past state. + std::vector decoder_fetches; + while (current_length < parameters->max_length) { + iteration_counter++; +#ifdef DEBUG_BEAM_SEARCH + auto cur_len = std::to_string(current_length); + dumper->Print("***CurrentLength", cur_len, true); + + for (int i = 0; i <= T5DecoderSubgraph::kFirstPastInputIndex; i++) { + dumper->Print("decoder_feeds", i, true); + dumper->Print("", decoder_feeds[i]); + } +#endif + + status = utils::ExecuteSubgraph(this->decoder_session_state_, + decoder_feeds_fetches_manager, + decoder_feeds, + decoder_fetches, + {}, + ExecutionMode::ORT_SEQUENTIAL, + this->context_.GetTerminateFlag(), + this->context_.Logger()); + + ORT_RETURN_IF_ERROR(status); + +#ifdef DEBUG_BEAM_SEARCH + for (int i = 0; i <= T5DecoderSubgraph::kFirstPresentOutputIndex; i++) { + dumper->Print("decoder_fetches", i, true); + dumper->Print("", decoder_fetches[i]); + } +#endif + + const OrtValue& logits = decoder_fetches[0]; + ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits, + beam_next_tokens, + beam_indices, + beam_state, + cpu_state, + iteration_counter)); + + // When all batches are finished, stop earlier to avoid wasting computation. + if (this->beam_scorer_->IsDone()) { + break; + } + + // Increase sequence length after a new token is generated. + ++current_length; + + // Prepare inputs for next round of subgraph call. + if (current_length < parameters->max_length) { + const int num_present_outputs = 2 * parameters->num_layers; // number of outputs with name like present_* + ORT_RETURN_IF_ERROR(this->update_decoder_feeds_func_( + this->temp_space_allocator_, + this->cuda_stream_, + decoder_fetches, + decoder_feeds, + num_present_outputs, + beam_next_tokens.as_span(), + beam_indices.as_span(), + parameters->num_beams, + this->GetConsoleDumper())); + } + decoder_fetches.clear(); + } + + gsl::span final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size()); + if (this->IsCuda()) { + ORT_RETURN_IF_ERROR(this->device_copy_func_(cpu_state.final_beam_scores, + final_beam_scores, + nullptr, + DeviceCopyDirection::deviceToHost)); + final_beam_scores = gsl::make_span(cpu_state.final_beam_scores.data(), + cpu_state.final_beam_scores.size()); + } + + this->beam_scorer_->Finalize(&(cpu_state.sequences), + final_beam_scores, + output_sequences, + output_sequences_scores); + + // Output per token scores + if (output_scores != nullptr) { + gsl::span target = output_scores->MutableDataAsSpan(); + gsl::span source = gsl::span(beam_state.scores.data(), beam_state.scores.size()); + assert(target.length() == source.length()); + ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice)); + } + + return status; +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 4fc9f2f383..d5375e6be5 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "beam_search_parameters.h" + +#include "contrib_ops/cpu/transformers/beam_search_parameters.h" namespace onnxruntime { namespace contrib { @@ -17,10 +18,11 @@ Status BeamSearchParameters::Validate() const { } void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { - model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); + model_type = static_cast(info.GetAttrOrDefault("model_type", IBeamSearchParameters::kModelTypeGpt)); early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1; eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); + decoder_start_token_id = static_cast(info.GetAttrOrDefault("decoder_start_token_id", -1)); no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); } @@ -30,25 +32,32 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { const auto& dims = input_ids->Shape().GetDims(); ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size()); batch_size = static_cast(dims[0]); - sequence_length = static_cast(dims[1]); + + // For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1 + sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1; auto* max_length_tensor = context->Input(1); max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : kMaxSequenceLength; - ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")"); - ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength); + ORT_ENFORCE(max_length > sequence_length, + "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")"); + ORT_ENFORCE(max_length <= kMaxSequenceLength, + "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength); auto* min_length_tensor = context->Input(2); min_length = min_length_tensor ? static_cast(*min_length_tensor->Data()) : 0; auto* num_beams_tensor = context->Input(3); num_beams = num_beams_tensor ? static_cast(*num_beams_tensor->Data()) : 1; - // TODO: limit num_beams > 1 when we can have another operator for greedy search. - ORT_ENFORCE(num_beams >= 1 && num_beams <= kMaxNumBeams, "num_beams shall be a positive integer no more than ", kMaxNumBeams, ", got ", num_beams); + // TODO(tianleiwu): limit num_beams > 1 when we can have another operator for greedy search. + ORT_ENFORCE(num_beams >= 1 && num_beams <= kMaxNumBeams, + "num_beams shall be a positive integer no more than ", kMaxNumBeams, ", got ", num_beams); auto* num_return_sequences_tensor = context->Input(4); - num_return_sequences = num_return_sequences_tensor ? static_cast(*num_return_sequences_tensor->Data()) : 1; - ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences); - ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")"); + num_return_sequences = num_return_sequences_tensor ? *num_return_sequences_tensor->Data() : 1; + ORT_ENFORCE(num_return_sequences >= 1, + "num_return_sequences shall be a positive integer, got ", num_return_sequences); + ORT_ENFORCE(num_beams >= num_return_sequences, + "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")"); auto* temperature_tensor = context->Input(5); temperature = temperature_tensor ? static_cast(*temperature_tensor->Data()) : 1; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index caaaf53751..5932bf54c9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -4,7 +4,7 @@ #pragma once #include "core/common/common.h" #include "core/framework/op_kernel.h" -#include "beam_search_shared.h" +#include "contrib_ops/cpu/transformers/beam_search_shared.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc index 77a900c506..81ebb74830 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.cc @@ -10,7 +10,7 @@ #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" #include "core/providers/cpu/rnn/rnn_helpers.h" -#include "beam_search_scorer.h" +#include "contrib_ops/cpu/transformers/beam_search_scorer.h" namespace onnxruntime { namespace contrib { @@ -69,7 +69,7 @@ void BeamHypotheses::Output( } // Since pop get the worst sequence, so output it in the reverse order. - // The frist (worst) beam shall be put at the last position among top_k sequences. + // The first (worst) beam shall be put at the last position among top_k sequences. int index = top_k - 1; while (!beams_.empty()) { auto item = beams_.top(); @@ -124,7 +124,7 @@ void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length) ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once. size_t batch_beam_size = batch_size_ * num_beams_; - constexpr bool no_fill = false; // do not fill values after allocation + constexpr bool no_fill = false; // Do not fill values after allocation done_ = Allocate(allocator, batch_size_, done_ptr_, no_fill); std::fill_n(done_.data(), done_.size(), false); @@ -134,8 +134,8 @@ void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length) next_beam_indices_ = Allocate(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill); // Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length. - size_t buffer_per_beam = (SafeInt(max_length_) * (max_length_ + 1) - SafeInt(sequence_length - 1) * sequence_length) / 2; - hypothesis_buffer_length_ = batch_beam_size * buffer_per_beam; + size_t per_beam = (SafeInt(max_length_) * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2; + hypothesis_buffer_length_ = batch_beam_size * per_beam; hypothesis_buffer_ = Allocate(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill); } @@ -155,7 +155,8 @@ void BeamSearchScorer::Process(ISequences* sequences, for (size_t batch = 0; batch < batch_size_; batch++) { BeamHypotheses& beam_hyp = beam_hyps_[batch]; if (done_[batch]) { - ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast(num_beams_), "Batch can only be done if all beams have been generated"); + ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast(num_beams_), + "Batch can only be done if all beams have been generated"); // Pad the batch. for (size_t j = 0; j < num_beams_; j++) { @@ -203,7 +204,7 @@ void BeamSearchScorer::Process(ISequences* sequences, } ORT_ENFORCE(beam_idx == num_beams_); - ORT_ENFORCE(hypothesis_buffer_offset_ <= batch_size_ * num_beams_ * max_length_); + ORT_ENFORCE(hypothesis_buffer_offset_ <= hypothesis_buffer_length_); // Check if we are done so that we can save a pad step if all(done) if (!done_[batch]) { @@ -258,7 +259,8 @@ void BeamSearchScorer::Finalize(ISequences* sequences, BeamHypotheses& beam_hyp = beam_hyps_[batch_index]; const size_t num_return_sequences = num_beam_hyps_to_keep_; - auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_); + auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, + num_return_sequences * max_length_); if (output_sequence_scores != nullptr) { batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences); @@ -274,4 +276,4 @@ void BeamSearchScorer::Finalize(ISequences* sequences, } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h index 0293e51d3a..70f944e533 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_scorer.h @@ -12,8 +12,8 @@ #include "core/framework/utils.h" #include "core/providers/cpu/tensor/utils.h" #include "core/providers/cpu/containers.h" -#include "sequences.h" -#include "beam_search_shared.h" +#include "contrib_ops/cpu/transformers/sequences.h" +#include "contrib_ops/cpu/transformers/beam_search_shared.h" namespace onnxruntime { namespace contrib { @@ -52,7 +52,7 @@ class BeamHypotheses { // Output results. Note that it will clear all beams. void Output(int top_k, // number of sequences to return int max_length, // max sequence length - gsl::span& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length) + gsl::span& sequences, // buffer with pad token, shape (num_return_sequences, max_length) gsl::span& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences) private: @@ -60,7 +60,9 @@ class BeamHypotheses { float length_penalty_; bool early_stopping_; float worst_score_; - std::priority_queue, HypothesisScoreCompare> beams_; // min-heap for top k + + // Min-heap for top k + std::priority_queue, HypothesisScoreCompare> beams_; }; class BeamSearchScorer : public IBeamScorer { @@ -103,7 +105,7 @@ class BeamSearchScorer : public IBeamScorer { int eos_token_id_; IAllocatorUniquePtr done_ptr_; // Allocated buffer for done_ - gsl::span done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size). + gsl::span done_; // Flags indicates whether each batch is finished or not. Shape is (batch_size). IAllocatorUniquePtr next_beam_scores_ptr_; gsl::span next_beam_scores_; @@ -117,11 +119,11 @@ class BeamSearchScorer : public IBeamScorer { IAllocatorUniquePtr hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses gsl::span hypothesis_buffer_; // Span of the allocated buffer size_t hypothesis_buffer_length_; // Total number of elements - size_t hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer. + size_t hypothesis_buffer_offset_; // Offset of available buffer, or length of used buffer. onnxruntime::FastAllocVector beam_hyps_; }; } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h index 3cf4a48f55..2643ca3c1b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_shared.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "gsl/gsl" @@ -23,10 +26,10 @@ struct IBeamSearchState { gsl::span next_token_scores; // shape (batch_size, num_beams * vocab_size) gsl::span next_tokens; // shape (batch_size, 2 * num_beams) gsl::span next_indices; // shape (batch_size, 2 * num_beams) - gsl::span next_positions; // shape (batch_size, num_beams). Next position value for position_ids. + gsl::span next_positions; // shape (batch_size, num_beams), empty for T5. Next position for position_ids. gsl::span beam_scores; // shape (batch_size, num_beams) gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) - gsl::span remaining_scores; // portion of scores that is avaiable for appending next token scores. + gsl::span remaining_scores; // portion of scores that is available for appending next token scores. }; struct IBeamSearchCpuState { @@ -72,10 +75,14 @@ class IBeamScorer { }; struct IBeamSearchParameters { + static constexpr int kModelTypeGpt = 0; + static constexpr int kModelTypeT5 = 1; + // Parameters from node attributes - int model_type; + int model_type; // 0 for GPT-2; 1 for encoder-decoder like T5 int eos_token_id; int pad_token_id; + int decoder_start_token_id; int no_repeat_ngram_size; bool early_stopping; @@ -88,7 +95,7 @@ struct IBeamSearchParameters { float length_penalty; float repetition_penalty; int batch_size; // deduce from first dimension of input_ids - int sequence_length; // deduce from second dimension of input_ids + int sequence_length; // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5 gsl::span vocab_mask; gsl::span prefix_vocab_mask; @@ -128,4 +135,4 @@ class IConsoleDumper { } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc index 7efc4726e1..9cb634512f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.cc @@ -70,12 +70,17 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { void DumpCpuTensor(const char* name, const Tensor& tensor) { const auto& shape = tensor.Shape(); + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + std::cout << "Shape:" << shape << std::endl; + size_t num_dims = shape.NumDimensions(); if (num_dims >= 3) { int dim0 = static_cast(shape.SizeToDimension(num_dims - 2)); int dim1 = static_cast(shape[num_dims - 2]); int dim2 = static_cast(shape[num_dims - 1]); - DumpCpuTensor(name, tensor, dim0, dim1, dim2); + DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2); return; } @@ -85,7 +90,7 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) { num_rows = static_cast(shape[0]); } size_t row_size = num_items / num_rows; - DumpCpuTensor(name, tensor, static_cast(num_rows), static_cast(row_size)); + DumpCpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { @@ -209,4 +214,4 @@ void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h index d3a936da34..5e5f39954c 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/dump_tensor.h @@ -4,7 +4,7 @@ #pragma once #include #include "core/framework/tensorprotoutils.h" -#include "beam_search_shared.h" +#include "contrib_ops/cpu/transformers/beam_search_shared.h" namespace onnxruntime { namespace contrib { @@ -30,4 +30,4 @@ class CpuTensorConsoleDumper : public IConsoleDumper { } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc b/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc deleted file mode 100644 index 52e35d006b..0000000000 --- a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.cc +++ /dev/null @@ -1,248 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/framework/framework_common.h" -#include "core/framework/session_state.h" -#include "core/framework/tensorprotoutils.h" -#include "core/framework/utils.h" -#include "core/providers/cpu/tensor/utils.h" -#include "gsl/gsl" -#include "gpt_subgraph.h" -#include "dump_tensor.h" - -using namespace ONNX_NAMESPACE; -using namespace onnxruntime::common; - -namespace onnxruntime { -namespace contrib { -namespace transformers { - -GptSubgraph::GptSubgraph( - const onnxruntime::Node& node_in, - const std::string& attribute_name, - const GraphViewer& subgraph_in) - : node(node_in), attribute(attribute_name), subgraph(subgraph_in), allocator_(nullptr), is_output_float16_(false) { - num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); - - auto& subgraph_inputs = subgraph.GetInputs(); - auto& subgraph_outputs = subgraph.GetOutputs(); - - // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ... - // outputs: logits, present_0, present_1, ... - num_subgraph_inputs = static_cast(subgraph_inputs.size()); - num_subgraph_outputs = static_cast(subgraph_outputs.size()); - - // CheckSubgraph will verify inputs and outputs later. - subgraph_input_names.reserve(num_subgraph_inputs); - for (int i = 0; i < num_subgraph_inputs; ++i) { - subgraph_input_names.push_back(subgraph_inputs[i]->Name()); - } - - subgraph_output_names.reserve(num_subgraph_outputs); - for (int i = 0; i < num_subgraph_outputs; ++i) { - subgraph_output_names.push_back(subgraph_outputs[i]->Name()); - } -} - -Status GptSubgraph::Validate(const std::vector& subgraph_inputs, - const std::vector& subgraph_outputs) { - ORT_RETURN_IF(num_subgraph_outputs <= 1, - "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs)."); - - ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2, - "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2"); - - ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ", - subgraph_inputs[0]->Name()); - ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ", - subgraph_inputs[1]->Name()); - ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ", - subgraph_inputs[2]->Name()); - ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ", - subgraph_inputs[3]->Name()); - - // Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads. - const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); - ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ", - past_shape->dim_size()); - - ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, - "subgraph past state dimension 0 shall have length of 2"); - - ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for number of heads"); - - ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, - "subgraph past state dimension 4 shall have a positive value for hidden size per head"); - - // check subgraph outputs - ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ", - subgraph_outputs[0]->Name()); - - ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ", - subgraph_outputs[1]->Name()); - - // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. - const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); - ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ", - logits_shape->dim_size()); - - ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, - "subgraph past state dimension 2 shall have a positive value for vocabulary size"); - - // Save parameters related to the subgraph. - num_heads = static_cast(past_shape->dim(2).dim_value()); - head_size = static_cast(past_shape->dim(4).dim_value()); - vocab_size = static_cast(logits_shape->dim(2).dim_value()); - num_layers = static_cast(subgraph_outputs.size()) - 1; - - ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, - "subgraph input 0 (input_ids) shall have int32 type"); - ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, - "subgraph input 1 (position_ids) shall have int32 type"); - ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32, - "subgraph input 2 (attention_mask) shall have int32 type"); - - auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type(); - ORT_RETURN_IF(output_type != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && output_type != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16, - "subgraph output 0 (logits) shall be float or float16 data type"); - - ORT_RETURN_IF(subgraph_inputs[3]->TypeAsProto()->tensor_type().elem_type() != output_type, - "subgraph input 3 (past_0) shall shall have same data type of logits output"); - ORT_RETURN_IF(subgraph_outputs[1]->TypeAsProto()->tensor_type().elem_type() != output_type, - "subgraph output 1 (present_0) shall shall have same data type of logits output"); - - is_output_float16_ = (output_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16); - - return Status::OK(); -} - -Status GptSubgraph::Setup(const SessionState& session_state, - const SessionState& subgraph_session_state) { - session_state_ = &session_state; - subgraph_session_state_ = &subgraph_session_state; - - std::vector feed_names; - feed_names.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); - - // Currently, input_ids is in CPU even for CUDA operator, so we have to use logits location as default. - const OrtMemoryInfo& default_location = utils::FindMemoryInfoForValue(subgraph_session_state, "logits"); - - // position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. - // as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids - feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); - - for (auto& entry : node.ImplicitInputDefs()) { - feed_names.push_back(entry->Name()); - } - - std::vector feed_locations; - feed_locations.resize(feed_names.size()); - - for (size_t i = 0, end = feed_names.size(); i < end; ++i) { - if (i >= subgraph_input_names.size()) { // implicit inputs - const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]); - feed_locations[i] = location.device; - } else { - feed_locations[i] = default_location.device; - } - } - - std::unique_ptr ffm; - ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names, - subgraph_session_state.GetOrtValueNameIdxMap(), ffm)); - ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm)); - - // setup the locations where we want the subgraph output to end up on - std::vector fetch_locations; - fetch_locations.reserve(num_subgraph_outputs); - - // past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location. - for (int i = 0; i < num_subgraph_outputs; ++i) { - fetch_locations.push_back(&default_location); - } - - utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations); - - feeds_fetches_manager_ = std::move(ffm); - - // Check subgraph only need once so put in Setup function. - auto& inputs = subgraph.GetInputs(); - auto& outputs = subgraph.GetOutputs(); - ORT_RETURN_IF_ERROR(Validate(inputs, outputs)); - - return Status::OK(); -} - -const IExecutionProvider* GptSubgraph::GetProvider() const { - const ExecutionProviders& providers = session_state_->GetExecutionProviders(); - const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider); - const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider); - const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider; - return provider; -} - -Status GptSubgraph::CreateInitialFeeds( - const Tensor& input_ids, - const std::vector& implicit_inputs, - int num_beams, - int pad_token_id, - gsl::span& sequence_lengths, - OrtValue& expanded_input_ids, - std::vector& feeds, - const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func, - const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, - IAllocatorUniquePtr& buffer) { - ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); - - const IExecutionProvider* provider = GetProvider(); - - const TensorShape& input_ids_shape = input_ids.Shape(); - ORT_ENFORCE(input_ids_shape.NumDimensions() == 2); - const int64_t& batch_size = input_ids_shape[0]; - - // Subgraph inputs: - // input_ids: shape (B, S) wher B is batch size, and S is sequence length - // position_ids: shape (B, S) - // attention_mask: shape (B, P+S), where past_sequence_length (P) is 0 - // After expansion, their shapes will become (B, M*S), where M is num_beams. - - // Allocate subgraph inputs to be same device as input_ids - AllocatorPtr cpu_alloactor = session_state_->GetAllocator(input_ids.Location()); - - // Store allocator, which will be used in remaining feeds - auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault); - allocator_ = default_allocator; - - // Initialize empty past state - auto past_type = IsOutputFloat16() ? DataTypeImpl::GetType() : DataTypeImpl::GetType(); - int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size}; - TensorShape past_shape(&past_state_dims[0], 5); - OrtValue empty_past; - Tensor::InitOrtValue(past_type, past_shape, default_allocator, empty_past); - - // The ordering is the same as used in Setup - feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); - - OrtValue expanded_position_ids; - OrtValue expanded_attention_mask; - ORT_RETURN_IF_ERROR(create_inputs_func(&input_ids, num_beams, pad_token_id, sequence_lengths, cpu_alloactor, expanded_input_ids, expanded_position_ids, expanded_attention_mask)); - - ORT_RETURN_IF_ERROR(add_to_feeds_func(provider, expanded_input_ids, expanded_position_ids, expanded_attention_mask, feeds, buffer)); - - // The remaing inputs are past state. - for (int i = 3; i < num_subgraph_inputs; ++i) { - feeds.push_back(empty_past); - } - - // pass in implicit inputs - for (const auto* entry : implicit_inputs) { - feeds.push_back(*entry); - } - - return Status::OK(); -} - -} // namespace transformers -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 4c2af1bb92..69e37e643f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -1,7 +1,11 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include #include -#include "logits_processor.h" -#include "dump_tensor.h" #include "core/common/safeint.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" namespace onnxruntime { namespace contrib { @@ -102,7 +106,8 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, std::unordered_set blocked_word_ids; for (int j = 0; j <= static_cast(sequence.length()) - ngram_size_; j++) { // Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length) - // TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching. + // TODO(tianleiwu): build N-Gram index (hash table with prefix of length NGram - 1 as key, + // and list of last word of NGram as value) for fast matching. if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) { blocked_word_ids.insert(sequence[static_cast(j) + prefix_length]); } @@ -119,7 +124,8 @@ void NoRepeatNGramLogitsProcessor::Process(const ISequences* sequences, } template -VocabMaskLogitsProcessor::VocabMaskLogitsProcessor(const gsl::span& vocab_mask) : vocab_mask_(vocab_mask) { +VocabMaskLogitsProcessor::VocabMaskLogitsProcessor(const gsl::span& vocab_mask) + : vocab_mask_(vocab_mask) { } template @@ -145,8 +151,10 @@ void VocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, } template -PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask, int batch_size) - : prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) { +PrefixVocabMaskLogitsProcessor::PrefixVocabMaskLogitsProcessor(const gsl::span& prefix_vocab_mask, + int batch_size) + : prefix_vocab_mask_(prefix_vocab_mask), + batch_size_(batch_size) { } template @@ -159,7 +167,7 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, assert(num_beams * batch_size_ == next_token_scores.batch_beam_size); // Process prefix vocabulary mask and set tokens with mask value 0 to -inf. - // prefix_vocab_mask shape (batch_szie, vocab_size). + // prefix_vocab_mask shape (batch_size, vocab_size). T* p = next_token_scores.scores.data(); for (int i = 0; i < batch_size_; i++) { size_t prefix_vocab_mask_offset = SafeInt(i) * next_token_scores.vocab_size; @@ -181,7 +189,8 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { processor_list_.clear(); if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty - repetition_penalty_processor_ = std::make_unique>(parameters.repetition_penalty); + repetition_penalty_processor_ = std::make_unique>( + parameters.repetition_penalty); processor_list_.push_back(repetition_penalty_processor_.get()); } @@ -196,12 +205,14 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { } if (!parameters.prefix_vocab_mask.empty()) { - prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask, parameters.batch_size); + prefix_vocab_mask_processor_ = std::make_unique>(parameters.prefix_vocab_mask, + parameters.batch_size); processor_list_.push_back(prefix_vocab_mask_processor_.get()); } if (parameters.min_length > 0) { - min_length_processor_ = std::make_unique>(parameters.min_length, parameters.eos_token_id); + min_length_processor_ = std::make_unique>(parameters.min_length, + parameters.eos_token_id); processor_list_.push_back(min_length_processor_.get()); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 30a3ac627f..8473b039be 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -1,9 +1,12 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once -#include "sequences.h" -#include "beam_search_parameters.h" #include "core/common/inlined_containers.h" -#include "beam_search_shared.h" +#include "contrib_ops/cpu/transformers/sequences.h" +#include "contrib_ops/cpu/transformers/beam_search_parameters.h" +#include "contrib_ops/cpu/transformers/beam_search_shared.h" namespace onnxruntime { namespace contrib { diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc index 382672f3e5..9ae94dec68 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.cc @@ -1,5 +1,8 @@ -#include "sequences.h" +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "core/common/safeint.h" +#include "contrib_ops/cpu/transformers/sequences.h" namespace onnxruntime { namespace contrib { @@ -20,8 +23,10 @@ void Sequences::Init(gsl::span buffer, int batch_beam_size, int sequenc } gsl::span Sequences::GetSequence(int beam_index) const { - gsl::span buffer(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size()); - gsl::span sequence = buffer.subspan(SafeInt(beam_index) * max_length_, static_cast(current_length_)); + gsl::span buffer(sequences[current_sequences_buffer].data(), + sequences[current_sequences_buffer].size()); + gsl::span sequence = buffer.subspan(SafeInt(beam_index) * max_length_, + static_cast(current_length_)); return sequence; } @@ -42,13 +47,16 @@ void Sequences::PrintSequences(const IConsoleDumper* dumper) const { void Sequences::AppendNextTokenToSequences( gsl::span& beam_indices, gsl::span& beam_next_tokens) { - gsl::span input(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size()); + gsl::span input(sequences[current_sequences_buffer].data(), + sequences[current_sequences_buffer].size()); gsl::span output = sequences[1 - current_sequences_buffer]; for (int i = 0; i < batch_beam_size_; i++) { int beam_index = beam_indices[i]; - gsl::span source = input.subspan(SafeInt(beam_index) * max_length_, static_cast(current_length_)); - gsl::span target = output.subspan(SafeInt(i) * max_length_, static_cast(current_length_)); + gsl::span source = input.subspan(SafeInt(beam_index) * max_length_, + static_cast(current_length_)); + gsl::span target = output.subspan(SafeInt(i) * max_length_, + static_cast(current_length_)); gsl::copy(source, target); } @@ -65,4 +73,4 @@ void Sequences::AppendNextTokenToSequences( } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sequences.h b/onnxruntime/contrib_ops/cpu/transformers/sequences.h index 5d61722e96..d353f8d354 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/sequences.h +++ b/onnxruntime/contrib_ops/cpu/transformers/sequences.h @@ -1,7 +1,10 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "gsl/gsl" -#include "beam_search_shared.h" +#include "contrib_ops/cpu/transformers/beam_search_shared.h" namespace onnxruntime { namespace contrib { @@ -47,4 +50,4 @@ class Sequences : public ISequences { } // namespace transformers } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc new file mode 100644 index 0000000000..f2e5dc23dd --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc @@ -0,0 +1,163 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/framework/framework_common.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "gsl/gsl" +#include "contrib_ops/cpu/transformers/subgraph_base.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +Subgraph::Subgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in) + : node(node_in), + attribute(attribute_name), + subgraph(subgraph_in), + num_heads(0), + head_size(0), + vocab_size(0), + num_layers(0), + allocator_(nullptr), + is_output_float16_(false) { + num_implicit_inputs = static_cast(node.ImplicitInputDefs().size()); + + auto& subgraph_inputs = subgraph.GetInputs(); + auto& subgraph_outputs = subgraph.GetOutputs(); + + // inputs: input_ids, position_ids, attention_mask, past_0, past_1, ... + // outputs: logits, present_0, present_1, ... + num_subgraph_inputs = static_cast(subgraph_inputs.size()); + num_subgraph_outputs = static_cast(subgraph_outputs.size()); + + // CheckSubgraph will verify inputs and outputs later. + subgraph_input_names.reserve(num_subgraph_inputs); + for (int i = 0; i < num_subgraph_inputs; ++i) { + subgraph_input_names.push_back(subgraph_inputs[i]->Name()); + } + + subgraph_output_names.reserve(num_subgraph_outputs); + for (int i = 0; i < num_subgraph_outputs; ++i) { + subgraph_output_names.push_back(subgraph_outputs[i]->Name()); + } +} + +Status Subgraph::Setup(const SessionState& session_state, + const SessionState& subgraph_session_state) { + session_state_ = &session_state; + subgraph_session_state_ = &subgraph_session_state; + + std::vector feed_names; + feed_names.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); + + // Use the first output (logits) to find device location. + const OrtMemoryInfo& default_location = utils::FindMemoryInfoForValue(subgraph_session_state, + subgraph_output_names[0]); + + // The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter. + feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end()); + + for (auto& entry : node.ImplicitInputDefs()) { + feed_names.push_back(entry->Name()); + } + + std::vector feed_locations; + feed_locations.resize(feed_names.size()); + + for (size_t i = 0, end = feed_names.size(); i < end; ++i) { + if (i >= subgraph_input_names.size()) { // Implicit inputs + const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]); + feed_locations[i] = location.device; + } else { + feed_locations[i] = default_location.device; + } + } + + std::unique_ptr ffm; + ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names, + subgraph_session_state.GetOrtValueNameIdxMap(), ffm)); + ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm)); + + // Setup the locations where we want the subgraph output to end up on + std::vector fetch_locations; + fetch_locations.reserve(num_subgraph_outputs); + + // Past state need to be where we can feed them in to the next iteration, so set the location to match the feed. + for (int i = 0; i < num_subgraph_outputs; ++i) { + fetch_locations.push_back(&default_location); + } + + utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations); + + feeds_fetches_manager_ = std::move(ffm); + + // Check subgraph only need once so put in Setup function. + auto& inputs = subgraph.GetInputs(); + auto& outputs = subgraph.GetOutputs(); + ORT_RETURN_IF_ERROR(Validate(inputs, outputs)); + + return Status::OK(); +} + +const IExecutionProvider* Subgraph::GetProvider() const { + const ExecutionProviders& providers = session_state_->GetExecutionProviders(); + const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider); + const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider); + const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider; + return provider; +} + +Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape, + const ONNX_NAMESPACE::TensorShapeProto* logits_shape, + bool merged_past) { + if (merged_past) { + // Merged past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads) + ORT_RETURN_IF(past_shape->dim_size() != 5, + "subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size()); + ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, + "subgraph past state dimension 0 shall have length of 2"); + + ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for number of heads"); + + ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, + "subgraph past state dimension 4 shall have a positive value for hidden size per head"); + this->num_heads = static_cast(past_shape->dim(2).dim_value()); + this->head_size = static_cast(past_shape->dim(4).dim_value()); + } else { + // Past state shape is like (batch_size, num_heads, past_seq_len, hidden_size/num_heads). + ORT_RETURN_IF(past_shape->dim_size() != 4, + "subgraph output present_key_self_0 is expected to have 4 dimension, got ", past_shape->dim_size()); + + ORT_RETURN_IF(!past_shape->dim(1).has_dim_value() || past_shape->dim(1).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for number of heads"); + + ORT_RETURN_IF(!past_shape->dim(3).has_dim_value() || past_shape->dim(3).dim_value() <= 0, + "subgraph past state dimension 4 shall have a positive value for hidden size per head"); + this->num_heads = static_cast(past_shape->dim(1).dim_value()); + this->head_size = static_cast(past_shape->dim(3).dim_value()); + } + + // Logits shape is like (batch_size, seq_len, vocabulary_size) + ORT_RETURN_IF(logits_shape->dim_size() != 3, + "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size()); + + ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for vocabulary size"); + + this->vocab_size = static_cast(logits_shape->dim(2).dim_value()); + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h similarity index 57% rename from onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h rename to onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h index 8600bf765c..69ddf84a6b 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/gpt_subgraph.h +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_base.h @@ -2,6 +2,9 @@ // Licensed under the MIT License. #pragma once + +#include +#include #include "gsl/gsl" #include "core/framework/allocator.h" #include "core/framework/feeds_fetches_manager.h" @@ -15,21 +18,22 @@ namespace onnxruntime { namespace contrib { namespace transformers { -// A class for GPT-2 subgraph inputs and outputs preparation. -struct GptSubgraph { - GptSubgraph( +class Subgraph { + public: + Subgraph( const onnxruntime::Node& node_in, const std::string& attribute_name, const GraphViewer& subgraph_in); + virtual ~Subgraph() {} - const onnxruntime::Node& node; // node that contains the subgraph - const std::string& attribute; // attribute of th node that contains the subgraph. Not used yet. - const GraphViewer& subgraph; // the subgraph + const onnxruntime::Node& node; // Node that contains the subgraph + const std::string& attribute; // Attribute of th node that contains the subgraph. Not used yet. + const GraphViewer& subgraph; // The subgraph int num_implicit_inputs; - int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience. - int num_subgraph_outputs; // same as subgraph_output_names.size() + int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience. + int num_subgraph_outputs; // Same as subgraph_output_names.size() std::vector subgraph_input_names; std::vector subgraph_output_names; @@ -40,32 +44,23 @@ struct GptSubgraph { int vocab_size; int num_layers; - // Setup exectuion + // Setup execution Status Setup(const SessionState& session_state, const SessionState& subgraph_session_state); - // Create inputs for first inference of subgraph. - Status CreateInitialFeeds( - const Tensor& input_ids, - const std::vector& implicit_inputs, - int num_beams, - int pad_token_id, - gsl::span& sequence_lengths, - OrtValue& expanded_input_ids, - std::vector& feeds, - const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func, - const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, - IAllocatorUniquePtr& buffer); - FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); } const IExecutionProvider* GetProvider() const; bool IsOutputFloat16() const { return is_output_float16_; } + virtual Status Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) = 0; + protected: - Status Validate(const std::vector& subgraph_inputs, - const std::vector& subgraph_outputs); + Status GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape, + const ONNX_NAMESPACE::TensorShapeProto* logits_shape, + bool merged_past); AllocatorPtr allocator_; const SessionState* session_state_; diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc new file mode 100644 index 0000000000..c69e78e223 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc @@ -0,0 +1,167 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/framework_common.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "gsl/gsl" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +Status GptSubgraph::CreateInitialFeeds( + const Tensor& input_ids, + const std::vector& implicit_inputs, + int num_beams, + int pad_token_id, + gsl::span& sequence_lengths, + OrtValue& expanded_input_ids, + std::vector& feeds, + const BeamSearchDeviceHelper::CreateGptInputsFunc& create_gpt_inputs_func, + const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + IAllocatorUniquePtr& buffer) { + ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); + + const IExecutionProvider* provider = GetProvider(); + + const TensorShape& input_ids_shape = input_ids.Shape(); + ORT_ENFORCE(input_ids_shape.NumDimensions() == 2); + const int64_t& batch_size = input_ids_shape[0]; + + // Subgraph inputs: + // input_ids: shape (B, S) where B is batch size, and S is sequence length + // position_ids: shape (B, S) + // attention_mask: shape (B, P+S), where past_sequence_length (P) is 0 + // After expansion, their shapes will become (B, M*S), where M is num_beams. + + // Allocate subgraph inputs to be same device as input_ids + AllocatorPtr cpu_allocator = session_state_->GetAllocator(input_ids.Location()); + + // Store allocator, which will be used in remaining feeds + auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault); + allocator_ = default_allocator; + + // Initialize empty past state + auto past_type = IsOutputFloat16() ? DataTypeImpl::GetType() : DataTypeImpl::GetType(); + int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size}; + TensorShape past_shape(&past_state_dims[0], 5); + OrtValue empty_past; + Tensor::InitOrtValue(past_type, past_shape, default_allocator, empty_past); + + // The ordering is the same as used in Setup + feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); + + OrtValue expanded_position_ids; + OrtValue expanded_attention_mask; + ORT_RETURN_IF_ERROR(create_gpt_inputs_func(&input_ids, + num_beams, + pad_token_id, + sequence_lengths, + cpu_allocator, + expanded_input_ids, + expanded_position_ids, + expanded_attention_mask)); + + ORT_RETURN_IF_ERROR(add_to_feeds_func(provider, + {expanded_input_ids, expanded_position_ids, expanded_attention_mask}, + feeds, + buffer)); + + // The remaining inputs are past state. + for (int i = kFirstPastInputIndex; i < num_subgraph_inputs; ++i) { + feeds.push_back(empty_past); + } + + // Pass in implicit inputs + for (const auto* entry : implicit_inputs) { + feeds.push_back(*entry); + } + + return Status::OK(); +} + +Status GptSubgraph::Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) { + ORT_RETURN_IF(num_subgraph_outputs <= kFirstPresentOutputIndex, + "Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs)."); + + ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2, + "Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2"); + + ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", + "subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", + "subgraph input 1 shall be named as position_ids, got: ", subgraph_inputs[1]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", + "subgraph input 2 shall be named as attention_mask, got: ", subgraph_inputs[2]->Name()); + ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", + "subgraph input 3 shall be named as past_0, got: ", subgraph_inputs[3]->Name()); + + // Past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads). + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape(); + ORT_RETURN_IF(past_shape->dim_size() != 5, + "subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size()); + + ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2, + "subgraph past state dimension 0 shall have length of 2"); + + ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for number of heads"); + + ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0, + "subgraph past state dimension 4 shall have a positive value for hidden size per head"); + + // check subgraph outputs + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", + "subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name()); + + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", + "subgraph input 1 shall be named as present_0, got: ", subgraph_outputs[1]->Name()); + + // Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size. + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + ORT_RETURN_IF(logits_shape->dim_size() != 3, + "subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size()); + + ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0, + "subgraph past state dimension 2 shall have a positive value for vocabulary size"); + + // Save parameters related to the subgraph. + num_heads = static_cast(past_shape->dim(2).dim_value()); + head_size = static_cast(past_shape->dim(4).dim_value()); + vocab_size = static_cast(logits_shape->dim(2).dim_value()); + num_layers = static_cast(subgraph_outputs.size()) - 1; + + constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; + constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; + constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16; + + ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "subgraph input 0 (input_ids) shall have int32 type"); + ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "subgraph input 1 (position_ids) shall have int32 type"); + ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "subgraph input 2 (attention_mask) shall have int32 type"); + + auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type(); + ORT_RETURN_IF(output_type != float32_type && output_type != float16_type, + "subgraph output 0 (logits) shall be float or float16 data type"); + + ORT_RETURN_IF(subgraph_inputs[kFirstPastInputIndex]->TypeAsProto()->tensor_type().elem_type() != output_type, + "subgraph input 3 (past_0) shall shall have same data type of logits output"); + ORT_RETURN_IF(subgraph_outputs[kFirstPresentOutputIndex]->TypeAsProto()->tensor_type().elem_type() != output_type, + "subgraph output 1 (present_0) shall shall have same data type of logits output"); + + is_output_float16_ = (output_type == float16_type); + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h new file mode 100644 index 0000000000..728c446dd4 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h @@ -0,0 +1,42 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/subgraph_base.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +// A class for GPT-2 subgraph inputs and outputs preparation. +class GptSubgraph : public Subgraph { + public: + GptSubgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {} + + // Create inputs for first inference of subgraph. + Status CreateInitialFeeds( + const Tensor& input_ids, + const std::vector& implicit_inputs, + int num_beams, + int pad_token_id, + gsl::span& sequence_lengths, + OrtValue& expanded_input_ids, + std::vector& feeds, + const BeamSearchDeviceHelper::CreateGptInputsFunc& create_gpt_inputs_func, + const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + IAllocatorUniquePtr& buffer); + + Status Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) override; + + constexpr static int kFirstPastInputIndex = 3; + constexpr static int kFirstPresentOutputIndex = 1; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc new file mode 100644 index 0000000000..8a55cb3f2f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc @@ -0,0 +1,154 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/framework_common.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "gsl/gsl" +#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" +#include "contrib_ops/cpu/transformers/beam_search_device_helper.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +/* T5 Decoder Subgraph. + + Inputs: + input_ids: int32 (B, 1) + encoder_attention_mask: int32 (B, encode_sequence_length) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + + past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) + past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) + ... (for each self attention layer) + + past_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + past_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + ... (for each cross attention layer) + + Outputs: + logits: (B, 1, vocab_size) + + present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) + present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) + ... (for each self attention layer) + + Note: + B = batch_size * num_beams + Data type of input or output is float or float16 if not specified. +*/ + +Status T5DecoderSubgraph::Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) { + ORT_RETURN_IF(num_subgraph_inputs < 7 || (num_subgraph_inputs - kFirstPastInputIndex) % 4 != 0, + "number of outputs expected to be 3 + 4 * layers, got:", num_subgraph_inputs); + ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - kFirstPresentOutputIndex) % 2 != 0, + "number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs); + + ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", + "decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name()); + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask", + "decoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states", + "decoder subgraph input 2 shall be named as encoder_hidden_states, got: ", subgraph_inputs[2]->Name()); + + // check subgraph outputs + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", + "decoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name()); + + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[kFirstPresentOutputIndex]->Shape(); + + // Save parameters related to the subgraph. + ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); + num_layers = (static_cast(subgraph_outputs.size()) - kFirstPresentOutputIndex) / 2; + + constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; + constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; + constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16; + + ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "decoder subgraph input 0 (input_ids) shall have int32 type"); + ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "decoder subgraph input 1 (encoder_attention_mask) shall have int32 type"); + + auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type(); + ORT_RETURN_IF(float_type != float32_type && float_type != float16_type, + "decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type"); + + for (int i = kFirstPastInputIndex; i < num_subgraph_inputs; i++) { + ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, + "decoder subgraph past inputs shall have same data type as that of encoder_hidden_states"); + } + + for (int i = 0; i < num_subgraph_outputs; i++) { + ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type, + "decoder subgraph output shall have same data type as that of encoder_hidden_states"); + } + + is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type); + + return Status::OK(); +} + +// Create inputs for decoder from the following data sources: +// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens) +// encoder fetches: logits, +// encoder_hidden_states, +// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ... +// decoder_feeds: input_ids, +// encoder_attention_mask, +// encoder_hidden_states, +// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ... +Status T5DecoderSubgraph::CreateInitialFeeds( + gsl::span beam_next_tokens, + const std::vector& implicit_inputs, + const std::vector& encoder_feeds, + const std::vector& encoder_fetches, + std::vector& decoder_feeds, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func, + void* stream) { + ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); + + // Allocate subgraph inputs from same device as inputs of encoder subgraph. + AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get().Location()); + + // Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP). + int batch_beam_size = static_cast(beam_next_tokens.length()); + int64_t dims[] = {batch_beam_size, 1}; + TensorShape input_ids_shape(&dims[0], 2); + OrtValue input_ids; + Tensor::InitOrtValue(DataTypeImpl::GetType(), input_ids_shape, allocator, input_ids); + ORT_RETURN_IF_ERROR(device_copy_int32_func( + input_ids.GetMutable()->MutableDataAsSpan(), + beam_next_tokens, + stream, + DeviceCopyDirection::hostToDevice)); + + // The ordering is the same as used in Setup. + decoder_feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); + decoder_feeds.push_back(input_ids); + + // The encoder_attention_mask is copied from the second input of encoder. + decoder_feeds.push_back(encoder_feeds[1]); + + // The encoder_hidden_states and past states are copied from the second output of encoder. + for (size_t j = 1; j < encoder_fetches.size(); j++) { + decoder_feeds.push_back(encoder_fetches[j]); + } + + // Pass through implicit inputs. + for (const auto* entry : implicit_inputs) { + decoder_feeds.push_back(*entry); + } + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h new file mode 100644 index 0000000000..3e59181a65 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.h @@ -0,0 +1,39 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/subgraph_base.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +// A class for T5 decoder subgraph inputs and outputs preparation. +class T5DecoderSubgraph : public Subgraph { + public: + T5DecoderSubgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {} + + // Create inputs for first inference of decoder subgraph. + Status CreateInitialFeeds( + gsl::span beam_next_tokens, + const std::vector& implicit_inputs, + const std::vector& encoder_feeds, + const std::vector& encoder_fetches, + std::vector& decoder_feeds, + const BeamSearchDeviceHelper::DeviceCopyFunc& device_copy_int32_func, + void* stream); + + Status Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) override; + + constexpr static int kFirstPastInputIndex = 3; + constexpr static int kFirstPresentOutputIndex = 1; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc new file mode 100644 index 0000000000..9b8292ed6f --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc @@ -0,0 +1,145 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/framework/framework_common.h" +#include "core/framework/session_state.h" +#include "core/framework/tensorprotoutils.h" +#include "core/framework/utils.h" +#include "core/providers/cpu/tensor/utils.h" +#include "gsl/gsl" +#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +/* T5 Encoder Subgraph (It also contains decoder initialization where decoder_input_ids are filled with start token ID). + + Inputs: + encoder_input_ids: int32 (B, encode_sequence_length) + encoder_attention_mask: int32 (B, encode_sequence_length) + decoder_input_ids: int32 (B, 1) + + Outputs: + logits: (B, 1, vocab_size) + encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + + present_key_self_0: (B, num_heads, 1, head_size) + present_value_self_0: (B, num_heads, 1, head_size) + ... (for each self attention layer) + + present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + ... (for each cross attention layer) + + Note: + Here, B = batch_size * num_beams since we expand the inputs. + Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams. + Data type of input or output is float or float16 if not specified. +*/ + +Status T5EncoderSubgraph::Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) { + ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs); + + ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs); + ORT_RETURN_IF((static_cast(subgraph_outputs.size()) - kFirstPresentOutputIndex) % 4 != 0, + "number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs); + + ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids", + "encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name()); + ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask", + "encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name()); + ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids", + "encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name()); + + ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", + "encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name()); + ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states", + "encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name()); + ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0", + "encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name()); + ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0", + "encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name()); + + const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape(); + const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape(); + + // Save parameters related to the subgraph. + ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false)); + num_layers = (static_cast(subgraph_outputs.size()) - kFirstPresentOutputIndex) / 4; + + constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32; + constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT; + constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16; + + ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "encoder subgraph input 0 (encoder_input_ids) shall have int32 type"); + ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "encoder subgraph input 1 (encoder_attention_mask) shall have int32 type"); + ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type, + "encoder subgraph input 2 (decoder_input_ids) shall have int32 type"); + + auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type(); + ORT_RETURN_IF(output_type != float32_type && output_type != float16_type, + "encoder subgraph output 0 (logits) shall be float or float16 data type"); + + for (int i = 1; i < num_subgraph_outputs; i++) { + ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type, + "encoder subgraph outputs 1, 2, ... shall have same data type"); + } + + is_output_float16_ = (output_type == float16_type); + + return Status::OK(); +} + +// Create inputs for first inference of subgraph. +Status T5EncoderSubgraph::CreateInitialFeeds( + const Tensor& encoder_input_ids, + const std::vector& implicit_inputs, + int num_beams, + int pad_token_id, + int start_token_id, + std::vector& feeds, + const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func, + const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + IAllocatorUniquePtr& buffer, + OrtValue& expanded_decoder_input_ids) { + ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds"); + + // The ordering is the same as used in Setup. + feeds.reserve(static_cast(num_subgraph_inputs) + static_cast(num_implicit_inputs)); + + // Allocate subgraph inputs to be same device as encoder_input_ids. + AllocatorPtr cpu_allocator = session_state_->GetAllocator(encoder_input_ids.Location()); + + // TODO(tianleiwu): expand the outputs instead of inputs to save computation. + OrtValue expanded_encoder_input_ids; + OrtValue expanded_encoder_attention_mask; + ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&encoder_input_ids, + num_beams, + pad_token_id, + start_token_id, + cpu_allocator, + expanded_encoder_input_ids, + expanded_encoder_attention_mask, + expanded_decoder_input_ids)); + + const IExecutionProvider* provider = GetProvider(); + ORT_RETURN_IF_ERROR(add_to_feeds_func( + provider, + {expanded_encoder_input_ids, expanded_encoder_attention_mask, expanded_decoder_input_ids}, + feeds, + buffer)); + + for (const auto* entry : implicit_inputs) { + feeds.push_back(*entry); + } + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h new file mode 100644 index 0000000000..30ff3840a7 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.h @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/subgraph_base.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +// A class for T5 encoder subgraph inputs and outputs preparation. +class T5EncoderSubgraph : public Subgraph { + public: + T5EncoderSubgraph( + const onnxruntime::Node& node_in, + const std::string& attribute_name, + const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {} + + // Create inputs for first inference of subgraph. + Status CreateInitialFeeds( + const Tensor& encoder_input_ids, + const std::vector& implicit_inputs, + int num_beams, + int pad_token_id, + int start_token_id, + std::vector& feeds, + const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func, + const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + IAllocatorUniquePtr& buffer, + OrtValue& expanded_decoder_input_ids); + + Status Validate(const std::vector& subgraph_inputs, + const std::vector& subgraph_outputs) override; + + constexpr static int kFirstPresentOutputIndex = 2; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc index 03a4f77755..b8afb13919 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search.cc @@ -4,8 +4,8 @@ #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cuda_execution_provider.h" #include "contrib_ops/cuda/transformers/beam_search.h" -#include "beam_search_device_helper.h" -#include "dump_cuda_tensor.h" +#include "contrib_ops/cuda/transformers/beam_search_device_helper.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" namespace onnxruntime { namespace contrib { @@ -38,16 +38,19 @@ BeamSearch::BeamSearch(const OpKernelInfo& info) SetComputeStream(static_cast(info.GetExecutionProvider()->GetComputeStream())); SetDeviceHelpers(BeamSearchCudaDeviceHelper::AddToFeeds, - BeamSearchCudaDeviceHelper::TopK); - - SetDeviceHelpers(BeamSearchCudaDeviceHelper::ProcessLogits, - BeamSearchCudaDeviceHelper::InitBeamState, + BeamSearchCudaDeviceHelper::TopK, BeamSearchCudaDeviceHelper::DeviceCopy, - BeamSearchCudaDeviceHelper::UpdateFeeds); + BeamSearchCudaDeviceHelper::DeviceCopy, + BeamSearchCudaDeviceHelper::ProcessLogits, + BeamSearchCudaDeviceHelper::ProcessLogits, + BeamSearchCudaDeviceHelper::InitBeamState, + BeamSearchCudaDeviceHelper::InitBeamState); - SetDeviceHelpers(BeamSearchCudaDeviceHelper::ProcessLogits, - BeamSearchCudaDeviceHelper::InitBeamState, - BeamSearchCudaDeviceHelper::UpdateFeeds); + SetDeviceHelpers_Gpt(BeamSearchCudaDeviceHelper::UpdateGptFeeds, + BeamSearchCudaDeviceHelper::UpdateGptFeeds); + + SetDeviceHelpers_EncoderDecoder(BeamSearchCudaDeviceHelper::UpdateDecoderFeeds, + BeamSearchCudaDeviceHelper::UpdateDecoderFeeds); SetConsoleDumper(&g_cuda_dumper); } @@ -71,4 +74,4 @@ Status BeamSearch::Compute(OpKernelContext* context) const { } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc index b8d871f320..3134d3a020 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.cc @@ -1,16 +1,17 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #include "core/providers/shared_library/provider_api.h" #include "core/providers/cuda/math/topk_impl.h" #include "core/providers/cuda/math/softmax.h" #include "core/providers/cuda/shared_inc/accumulation_type.h" #include "core/framework/ort_value.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" -#include "beam_search_impl.h" #include -#include "dump_cuda_tensor.h" - -#ifdef DEBUG_BEAM_SEARCH -using namespace onnxruntime::contrib::cuda::transformers; -#endif +#include "contrib_ops/cuda/transformers/beam_search_impl.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" +#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" namespace onnxruntime { namespace concurrency { @@ -52,7 +53,7 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, output_indices = Tensor::Create(DataTypeImpl::GetType(), output_shape, allocator); if (input->IsDataType()) { - return TopKImpl(nullptr, // We limit number of beams in BeamSearchParameters, so that K <= 256 and kernel is not needed + return TopKImpl(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here reinterpret_cast(stream), input->Data(), static_cast(output_values->MutableDataRaw()), @@ -87,35 +88,53 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, } Status AddToFeeds(const IExecutionProvider* execution_provider, - OrtValue& input_ids, - OrtValue& position_ids, - OrtValue& attention_mask, + std::initializer_list inputs, std::vector& feeds, IAllocatorUniquePtr& buffer) { // Copy tensors to GPU, then add to feeds const CUDAExecutionProvider* provider = reinterpret_cast(execution_provider); - const TensorShape& shape = input_ids.Get().Shape(); - ORT_ENFORCE(shape.NumDimensions() == 2); - const int64_t elements = shape[0] * shape[1]; + size_t total_bytes = 0; + for (auto& input : inputs) { + if (input.IsAllocated()) { + total_bytes += input.Get().Shape().Size() * input.Type()->Size(); + } + } + + ORT_ENFORCE(total_bytes > 0); AllocatorPtr pinned_allocator = provider->GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU); cudaStream_t stream = static_cast(provider->GetComputeStream()); - - size_t bytes = (sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * elements; - auto pinned_buffer = IAllocator::MakeUniquePtr(pinned_allocator, bytes); + auto pinned_buffer = IAllocator::MakeUniquePtr(pinned_allocator, total_bytes); char* pinned_data = static_cast(pinned_buffer.get()); // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) - memcpy(pinned_data, input_ids.Get().Data(), sizeof(int32_t) * elements); - memcpy(pinned_data + sizeof(int32_t) * elements, position_ids.Get().Data(), sizeof(int32_t) * elements); - memcpy(pinned_data + 2 * sizeof(int32_t) * elements, attention_mask.Get().Data(), sizeof(int32_t) * elements); + char* destination = pinned_data; + for (auto& input : inputs) { + if (input.IsAllocated()) { + const Tensor& tensor = input.Get(); + const size_t bytes = input.Type()->Size() * tensor.Shape().Size(); + MLDataType dataType = tensor.DataType(); + if (dataType == DataTypeImpl::GetType()) { + memcpy(destination, input.Get().Data(), bytes); + } else if (dataType == DataTypeImpl::GetType()) { + memcpy(destination, input.Get().Data(), bytes); + } else { + return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, + "AddToFeeds: An implementation for the input type ", + dataType, " is not supported yet"); + } + + // Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs. + destination += bytes; + } + } if (!buffer) { - buffer = provider->GetScratchBuffer(bytes); + buffer = provider->GetScratchBuffer(total_bytes); } char* gpu_data = buffer.get(); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, bytes, cudaMemcpyHostToDevice, stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream)); // Create an event to make sure the async copy is finished before reading the data. onnxruntime::contrib::cuda::AutoDestoryCudaEvent new_event; @@ -124,33 +143,32 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream)); CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone)); - // TODO: allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call. - OrtValue device_input_ids; - OrtValue device_position_ids; - OrtValue device_attention_mask; + // TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call. const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info(); - Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, gpu_data, location, device_input_ids); - Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, gpu_data + sizeof(int32_t) * elements, location, device_position_ids); - Tensor::InitOrtValue(DataTypeImpl::GetType(), shape, gpu_data + 2 * sizeof(int32_t) * elements, location, device_attention_mask); + for (auto& input : inputs) { + if (input.IsAllocated()) { + const Tensor& tensor = input.Get(); + const TensorShape& shape = tensor.Shape(); + const size_t bytes = input.Type()->Size() * shape.Size(); + MLDataType dataType = tensor.DataType(); - feeds.push_back(device_input_ids); - feeds.push_back(device_position_ids); - feeds.push_back(device_attention_mask); + OrtValue device_input; + Tensor::InitOrtValue(dataType, shape, gpu_data, location, device_input); + gpu_data += bytes; + feeds.push_back(device_input); + } + } return Status::OK(); } template void InitBeamState(transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream) { - // TODO: we can use another stream to avoid blocking subgraph execution. + // TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution. cudaStream_t cuda_stream = reinterpret_cast(stream); cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream); cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream); @@ -162,18 +180,9 @@ void InitBeamState(transformers::IBeamSearchState* beam_state, // copy sequence lengths to GPU // since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here. - // cudaMemsetAsync(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes(), cuda_stream); - cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), cudaMemcpyHostToDevice, cuda_stream); - - memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes()); - - // Copy input_ids to sequences[0] - gsl::span sequences_0 = cpu_state->sequences_space; - int batch_beam_size = batch_size * num_beams; - for (int i = 0; i < batch_beam_size; i++) { - for (int j = 0; j < sequence_length; j++) { - sequences_0[SafeInt(i) * max_length + j] = static_cast(input_ids_in_cpu[SafeInt(i) * sequence_length + j]); - } + if (!beam_state->next_positions.empty()) { // next_positions is empty for T5 + cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), + cudaMemcpyHostToDevice, cuda_stream); } } @@ -220,12 +229,13 @@ Status ProcessLogits(const OrtValue& logits, // // When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1. gsl::span& next_token_logits = beam_state->next_token_logits; if (input_length > 1) { - // TODO: use one kernel to replace a loop of memory copy. + // TODO(tianleiwu): use one kernel to replace a loop of memory copy. const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size; for (int i = 0; i < batch_beam_size; i++) { gsl::span source(reinterpret_cast(current_logits), vocab_size); gsl::span target = next_token_logits.subspan(i * vocab_size, vocab_size); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size, + cudaMemcpyDeviceToDevice, cuda_stream)); current_logits += input_length * vocab_size; } } @@ -253,12 +263,14 @@ Status ProcessLogits(const OrtValue& logits, // // Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel BufferUniquePtr sequences_buffer; int current_sequence_length = sequences->GetSequenceLength(); - if (parameters->repetition_penalty != 1.0f || (parameters->no_repeat_ngram_size > 0 && current_sequence_length >= parameters->no_repeat_ngram_size)) { + bool run_ngram = parameters->no_repeat_ngram_size > 0 && current_sequence_length >= parameters->no_repeat_ngram_size; + if (parameters->repetition_penalty != 1.0f || run_ngram) { size_t bytes = SafeInt(sizeof(int32_t)) * batch_beam_size * parameters->max_length; void* data = allocator->Alloc(bytes); BufferUniquePtr temp_buffer(data, BufferDeleter(allocator)); sequences_buffer = std::move(temp_buffer); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, + cudaMemcpyHostToDevice, cuda_stream)); } cuda::LaunchLogitsProcessKernel( @@ -277,20 +289,25 @@ Status ProcessLogits(const OrtValue& logits, // cuda_stream); #ifdef DEBUG_BEAM_SEARCH - dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size); + dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif // Add beam score to next token scores. Corresponding python code is like: // next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores) - cuda::LaunchAddProbsKernel(next_token_scores.data(), beam_state->beam_scores.data(), batch_size, num_beams, vocab_size, cuda_stream); + cuda::LaunchAddProbsKernel(next_token_scores.data(), beam_state->beam_scores.data(), + batch_size, num_beams, vocab_size, cuda_stream); #ifdef DEBUG_BEAM_SEARCH - dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size); + dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size); #endif if (output_scores) { // Append next token scores to the scores output. - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state->remaining_scores.data(), next_token_scores.data(), next_token_scores.size_bytes(), cudaMemcpyDeviceToDevice, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state->remaining_scores.data(), + next_token_scores.data(), + next_token_scores.size_bytes(), + cudaMemcpyDeviceToDevice, + cuda_stream)); beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size()); } @@ -301,7 +318,8 @@ Status ProcessLogits(const OrtValue& logits, // TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); auto element_type = DataTypeImpl::GetType(); OrtValue next_token_scores_value; - Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value); + Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), + next_token_scores_value); const Tensor& input = next_token_scores_value.Get(); constexpr int axis = 1; @@ -311,7 +329,8 @@ Status ProcessLogits(const OrtValue& logits, // std::unique_ptr topk_scores; std::unique_ptr topk_indices; - ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices)); + ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, + topk_scores, topk_indices)); #ifdef DEBUG_BEAM_SEARCH dumper->Print("topk_scores", *(topk_scores.get())); @@ -322,7 +341,8 @@ Status ProcessLogits(const OrtValue& logits, // // next_indices = (next_tokens / vocab_size).long() // next_tokens = next_tokens % vocab_size const int64_t* next_token_indices = topk_indices->Data(); - cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), batch_size, top_k, vocab_size, cuda_stream); + cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), + batch_size, top_k, vocab_size, cuda_stream); const float* data = topk_scores->Data(); @@ -332,12 +352,26 @@ Status ProcessLogits(const OrtValue& logits, // dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k); #endif - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), data, topk_scores->Shape().Size() * sizeof(float), cudaMemcpyDeviceToHost, cuda_stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream)); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_indices.data(), beam_state->next_indices.data(), beam_state->next_indices.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), + data, + topk_scores->Shape().Size() * sizeof(float), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), + beam_state->next_tokens.data(), + beam_state->next_tokens.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_indices.data(), + beam_state->next_indices.data(), + beam_state->next_indices.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); - gsl::span next_scores = gsl::make_span(cpu_state->topk_scores.data(), static_cast::index_type>(topk_scores->Shape().Size())); + gsl::span next_scores = gsl::make_span( + cpu_state->topk_scores.data(), + static_cast::index_type>(topk_scores->Shape().Size())); gsl::span next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size()); gsl::span next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size()); @@ -354,55 +388,98 @@ template Status DeviceCopy(gsl::span target, gsl::span source, void* stream, int copyDirection) { assert(copyDirection >= 0 && copyDirection <= 3); if (stream == nullptr) { - CUDA_RETURN_IF_ERROR(cudaMemcpy(target.data(), source.data(), source.size_bytes(), static_cast(copyDirection))); + CUDA_RETURN_IF_ERROR(cudaMemcpy(target.data(), source.data(), source.size_bytes(), + static_cast(copyDirection))); } else { cudaStream_t cuda_stream = reinterpret_cast(stream); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), static_cast(copyDirection), cuda_stream)); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), + static_cast(copyDirection), cuda_stream)); CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); } return Status::OK(); } template -Status PickPastState(const std::vector& last_outputs, - std::vector& next_inputs, - gsl::span& beam_indices, - AllocatorPtr allocator, - void* stream) { - for (size_t i = 1; i < last_outputs.size(); ++i) { - const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64) +Status PickGptPastState(const std::vector& last_outputs, + std::vector& next_inputs, + gsl::span& beam_indices, + AllocatorPtr allocator, + void* stream) { + int num_present_tensors = static_cast(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex; + for (int i = 0; i < num_present_tensors; ++i) { + const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i]; + + // shape is like (2, batch_beam_size, 12, past_seq_len, 64) const TensorShape& past_shape = present.Get().Shape(); + auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; + auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; // Create a tensor with same shape. - // TODO: allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data. + // TODO(tianleiwu): allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data. OrtValue past; auto past_type = DataTypeImpl::GetType(); Tensor::InitOrtValue(past_type, past_shape, allocator, past); - auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4]; - auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4]; - gsl::span past_span = gsl::make_span(past.GetMutable()->MutableData(), past_shape.Size()); gsl::span present_span = gsl::make_span(present.Get().Data(), past_shape.Size()); for (gsl::index j = 0; j < beam_indices.length(); j++) { int32_t beam_index = beam_indices[j]; gsl::span present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); - gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam); + gsl::span present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, + block_size_per_beam); gsl::span past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam); gsl::span past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(), cudaMemcpyDeviceToDevice, reinterpret_cast(stream))); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(), cudaMemcpyDeviceToDevice, reinterpret_cast(stream))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(), + cudaMemcpyDeviceToDevice, reinterpret_cast(stream))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(), + cudaMemcpyDeviceToDevice, reinterpret_cast(stream))); } - next_inputs[i + 2] = past; + next_inputs[transformers::GptSubgraph::kFirstPastInputIndex + i] = past; + } + + return Status::OK(); +} + +// Copy present state to past state for T5 model +template +Status PickT5PastState(const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span& beam_indices, + AllocatorPtr allocator, + void* stream) { + for (int i = 0; i < num_present_tensors; ++i) { + const OrtValue& present = last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i]; + + // shape is like (batch_beam_size, 12, past_seq_len, 64) + const TensorShape& past_shape = present.Get().Shape(); + auto block_size_per_beam = past_shape[1] * past_shape[2] * past_shape[3]; + + // Create a tensor with same shape. + // TODO(tianleiwu): allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data. + OrtValue past; + Tensor::InitOrtValue(DataTypeImpl::GetType(), past_shape, allocator, past); + + gsl::span past_span = gsl::make_span(past.GetMutable()->MutableData(), past_shape.Size()); + gsl::span present_span = gsl::make_span(present.Get().Data(), past_shape.Size()); + for (gsl::index j = 0; j < beam_indices.length(); j++) { + int32_t beam_index = beam_indices[j]; + gsl::span present_beam = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam); + gsl::span past_beam = past_span.subspan(j * block_size_per_beam, block_size_per_beam); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_beam.data(), present_beam.data(), present_beam.size_bytes(), + cudaMemcpyDeviceToDevice, reinterpret_cast(stream))); + } + + next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = past; } return Status::OK(); } template -Status UpdateFeeds( +Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -421,7 +498,8 @@ Status UpdateFeeds( OrtValue input_ids; Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids); int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); - CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, reinterpret_cast(stream))); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), + cudaMemcpyHostToDevice, reinterpret_cast(stream))); next_inputs[0] = input_ids; // Update position IDs @@ -439,7 +517,8 @@ Status UpdateFeeds( int32_t* mask_data = attention_mask.GetMutable()->MutableData(); // Launch kernel to update position_ids and attention_mask for next iteration - cuda::LaunchUpdateKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, reinterpret_cast(stream)); + cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, + reinterpret_cast(stream)); next_inputs[2] = attention_mask; @@ -453,12 +532,13 @@ Status UpdateFeeds( // Update past state if (num_beams == 1) { + const int k = transformers::GptSubgraph::kFirstPastInputIndex - transformers::GptSubgraph::kFirstPresentOutputIndex; // feed present_* output to past_* inputs one by one - for (size_t i = 1; i < last_outputs.size(); ++i) { - next_inputs[i + 2] = last_outputs[i]; + for (size_t i = transformers::GptSubgraph::kFirstPresentOutputIndex; i < last_outputs.size(); ++i) { + next_inputs[i + k] = last_outputs[i]; } } else { - ORT_RETURN_IF_ERROR(PickPastState(last_outputs, next_inputs, beam_indices, allocator, stream)); + ORT_RETURN_IF_ERROR(PickGptPastState(last_outputs, next_inputs, beam_indices, allocator, stream)); } // Make sure data is ready before next subgraph execution. @@ -466,15 +546,63 @@ Status UpdateFeeds( return Status::OK(); } +// Update decoder inputs given decoder outputs of last iteration. +template +Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper) { + // last_outputs: logits, present_key_self_0, present_value_self_0, ... + // next_inputs: input_ids, + // encoder_attention_mask, encoder_hidden_states, + // past_key_self_0, past_value_self_0, ... + // past_key_cross_0, past_value_cross_0, ... + // Only need copy beam next tokens to input_ids, and copy present_*_self_* to past_*_self_*, + + // Update input_ids with next tokens. + int batch_beam_size = static_cast(beam_next_tokens.length()); + int64_t dims[] = {batch_beam_size, 1}; + TensorShape input_ids_shape(&dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue input_ids; + Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids); + int32_t* input_ids_data = input_ids.GetMutable()->MutableData(); + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), + cudaMemcpyHostToDevice, reinterpret_cast(stream))); + next_inputs[0] = input_ids; + +#ifdef DEBUG_BEAM_SEARCH + dumper->Print("input_ids", input_ids); +#else + ORT_UNUSED_PARAMETER(dumper); +#endif + + // Update past state + ORT_ENFORCE(last_outputs.size() >= static_cast(1 + num_present_tensors)); + // TODO(tianleiwu): remove num_beams==1 once GreedySearch operator is available. + if (num_beams == 1) { + // feed present_* output to past_* inputs one by one + for (int i = 0; i < num_present_tensors; ++i) { + next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = + last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i]; + return Status::OK(); + } + } + + return PickT5PastState(last_outputs, next_inputs, num_present_tensors, beam_indices, allocator, stream); +} + // Explicit template instantiations of functions template void InitBeamState(transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream); template Status ProcessLogits(const OrtValue& logits, @@ -494,9 +622,15 @@ template Status DeviceCopy( gsl::span target, gsl::span source, void* stream, - int copyDirectionn); + int copyDirection); -template Status UpdateFeeds( +template Status DeviceCopy( + gsl::span target, + gsl::span source, + void* stream, + int copyDirection); + +template Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -510,13 +644,9 @@ template Status UpdateFeeds( // Float16 template void InitBeamState(transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream); template Status ProcessLogits(const OrtValue& logits, @@ -532,7 +662,7 @@ template Status ProcessLogits(const OrtValue& logits, void* stream, const transformers::IConsoleDumper* dumper); -template Status UpdateFeeds( +template Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -544,6 +674,28 @@ template Status UpdateFeeds( int num_beams, const transformers::IConsoleDumper* dumper); +template Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper); + +template Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper); + } // namespace BeamSearchCudaDeviceHelper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h index 24eb1d4ebd..dde2b9e3f5 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_device_helper.h @@ -1,3 +1,6 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + #pragma once #include "core/providers/cuda/shared_inc/fpgeneric.h" @@ -26,21 +29,15 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, std::unique_ptr& output_indices); Status AddToFeeds(const IExecutionProvider* execution_provider, - OrtValue& input_ids, - OrtValue& position_ids, - OrtValue& attention_mask, + std::initializer_list inputs, std::vector& feeds, IAllocatorUniquePtr& buffer); template void InitBeamState(transformers::IBeamSearchState* beam_state, - transformers::IBeamSearchCpuState* cpu_state, gsl::span& sequence_lengths, int batch_size, int num_beams, - gsl::span input_ids_in_cpu, - int sequence_length, - int max_length, void* stream); template @@ -64,7 +61,7 @@ Status DeviceCopy(gsl::span target, int copyDirection); template -Status UpdateFeeds( +Status UpdateGptFeeds( AllocatorPtr allocator, void* stream, const std::vector& last_outputs, @@ -76,6 +73,23 @@ Status UpdateFeeds( int num_beams, const transformers::IConsoleDumper* dumper); +// --------------------------------------------------------------- +// Functions for encoder-decoder model like T5 +// --------------------------------------------------------------- + +// Update decoder inputs given decoder outputs of last iteration. +template +Status UpdateDecoderFeeds( + AllocatorPtr allocator, + void* stream, + const std::vector& last_outputs, + std::vector& next_inputs, + int num_present_tensors, + gsl::span beam_next_tokens, + gsl::span beam_indices, + int num_beams, + const transformers::IConsoleDumper* dumper); + } // namespace BeamSearchCudaDeviceHelper } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu index 4e171c0c46..4f93b1dded 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu @@ -1,10 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. -#include "beam_search_impl.h" #include "core/providers/cuda/cuda_common.h" #include "core/providers/cuda/cu_inc/common.cuh" #include "cub/util_type.cuh" +#include "contrib_ops/cuda/transformers/beam_search_impl.h" namespace onnxruntime { namespace contrib { @@ -52,7 +52,11 @@ void LaunchNextTokenKernel(const int64_t* next_token_indices, int total_elements = batch_size * top_k; constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; - NextTokenKernel<<>>(next_token_indices, next_indices, next_tokens, vocab_size, total_elements); + NextTokenKernel<<>>(next_token_indices, + next_indices, + next_tokens, + vocab_size, + total_elements); } template @@ -153,8 +157,19 @@ void LaunchLogitsProcessKernel( int total_elements = batch_size * num_beams * vocab_size; constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; - LogitsProcessKernel<<>>(next_token_scores, vocab_mask, prefix_vocab_mask, num_beams, vocab_size, total_elements, demote_token_id, - sequences, max_sequence_length, current_sequence_length, repetition_penalty, no_repeat_ngram_size); + LogitsProcessKernel<<>>( + next_token_scores, + vocab_mask, + prefix_vocab_mask, + num_beams, + vocab_size, + total_elements, + demote_token_id, + sequences, + max_sequence_length, + current_sequence_length, + repetition_penalty, + no_repeat_ngram_size); } // Instantiation @@ -221,11 +236,11 @@ template void LaunchAddProbsKernel( cudaStream_t stream); template -__global__ void UpdateInputsKernel(const T* old_mask_data, - T* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length) { +__global__ void UpdateGptInputsKernel(const T* old_mask_data, + T* mask_data, + int32_t* next_positions, + int batch_beam_size, + int current_length) { int index = blockIdx.x * blockDim.x + threadIdx.x; if (index < batch_beam_size * current_length) { // Update attention mask. @@ -240,19 +255,20 @@ __global__ void UpdateInputsKernel(const T* old_mask_data, } } -void LaunchUpdateKernel(const int32_t* old_mask_data, - int32_t* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length, - cudaStream_t stream) { +void LaunchUpdateGptKernel(const int32_t* old_mask_data, + int32_t* mask_data, + int32_t* next_positions, + int batch_beam_size, + int current_length, + cudaStream_t stream) { assert(current_length > 0); int total_elements = batch_beam_size * current_length; constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; - UpdateInputsKernel<<>>(old_mask_data, mask_data, next_positions, batch_beam_size, current_length); + UpdateGptInputsKernel<<>>( + old_mask_data, mask_data, next_positions, batch_beam_size, current_length); } } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h index de3d1898b8..b1685326a1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h +++ b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h @@ -2,8 +2,10 @@ // Licensed under the MIT License. #pragma once + #include #include + namespace onnxruntime { namespace contrib { namespace cuda { @@ -46,12 +48,12 @@ void LaunchNextTokenKernel(const int64_t* next_token_indices, int vocab_size, cudaStream_t stream); -void LaunchUpdateKernel(const int32_t* old_mask_data, - int32_t* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length, - cudaStream_t stream); +void LaunchUpdateGptKernel(const int32_t* old_mask_data, + int32_t* mask_data, + int32_t* next_positions, + int batch_beam_size, + int current_length, + cudaStream_t stream); } // namespace cuda } // namespace contrib diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc index ab5d66443d..2788b798ea 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.cc @@ -1,9 +1,10 @@ // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. + #include #include "core/providers/cuda/cuda_common.h" -#include "dump_cuda_tensor.h" #include "core/framework/print_tensor_utils.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" namespace onnxruntime { namespace contrib { @@ -38,11 +39,13 @@ class PinnedHostBuffer { }; template -void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1) { +void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool is_gpu_tensor) { + // Occasionally, user will need dump CPU tensor in CUDA EP. + // In that case, we copy tensor data as well. It is not needed, but it keeps code simple. int num_items = dim0 * dim1; auto data = std::make_shared>(num_items); cudaDeviceSynchronize(); - cudaMemcpy(*data, tensor, num_items * sizeof(T), cudaMemcpyDeviceToHost); + cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost); if (nullptr != name) { std::cout << std::string(name) << std::endl; @@ -56,11 +59,11 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1) { } template -void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) { +void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) { int num_items = dim0 * dim1 * dim2; auto data = std::make_shared>(num_items); cudaDeviceSynchronize(); - cudaMemcpy(*data, tensor, num_items * sizeof(T), cudaMemcpyDeviceToHost); + cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost); if (nullptr != name) { std::cout << std::string(name) << std::endl; @@ -75,14 +78,15 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) { MLDataType dataType = tensor.DataType(); + bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU); if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, dim2, is_gpu_tensor); } else { assert(0); } @@ -90,14 +94,15 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { MLDataType dataType = tensor.DataType(); + bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU); if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else if (dataType == DataTypeImpl::GetType()) { - DumpGpuTensor(name, tensor.Data(), dim0, dim1); + DumpGpuTensor(name, tensor.Data(), dim0, dim1, is_gpu_tensor); } else { assert(0); } @@ -106,12 +111,18 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) { void DumpGpuTensor(const char* name, const Tensor& tensor) { const auto& shape = tensor.Shape(); + if (nullptr != name) { + std::cout << std::string(name) << std::endl; + } + std::cout << "Shape:" << shape << std::endl; + std::cout << tensor.Location().ToString() << std::endl; + size_t num_dims = shape.NumDimensions(); if (num_dims >= 3) { int dim0 = static_cast(shape.SizeToDimension(num_dims - 2)); int dim1 = static_cast(shape[num_dims - 2]); int dim2 = static_cast(shape[num_dims - 1]); - DumpGpuTensor(name, tensor, dim0, dim1, dim2); + DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2); return; } @@ -121,47 +132,47 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) { num_rows = static_cast(shape[0]); } size_t row_size = num_items / num_rows; - DumpGpuTensor(name, tensor, static_cast(num_rows), static_cast(row_size)); + DumpGpuTensor(nullptr, tensor, static_cast(num_rows), static_cast(row_size)); } void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1); + DumpGpuTensor(name, tensor, dim0, dim1, true); } void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const { if (is_enabled_) - DumpGpuTensor(name, tensor, dim0, dim1, dim2); + DumpGpuTensor(name, tensor, dim0, dim1, dim2, true); } void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const { @@ -235,4 +246,4 @@ void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const } // namespace transformers } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h index 0b225570ca..2780c2cc5c 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h +++ b/onnxruntime/contrib_ops/cuda/transformers/dump_cuda_tensor.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once + #include #include "core/framework/tensorprotoutils.h" #include "core/framework/ort_value.h" @@ -33,4 +34,4 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons } // namespace transformers } // namespace cuda } // namespace contrib -} // namespace onnxruntime \ No newline at end of file +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 0a107a02c6..06fc4f2273 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -957,10 +957,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1, .SetDoc("Beam Search for text generation. Supports GPT-2 decoder.") .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) .Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast(0)) .Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast(0)) - .Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") diff --git a/onnxruntime/python/tools/transformers/__init__.py b/onnxruntime/python/tools/transformers/__init__.py index 89ed09bcca..4200447eef 100644 --- a/onnxruntime/python/tools/transformers/__init__.py +++ b/onnxruntime/python/tools/transformers/__init__.py @@ -12,3 +12,5 @@ import convert_to_onnx # added for backward compatible import gpt2_helper + +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5")) diff --git a/onnxruntime/python/tools/transformers/bert_test_data.py b/onnxruntime/python/tools/transformers/bert_test_data.py index dc818ae87c..12c2145fe3 100644 --- a/onnxruntime/python/tools/transformers/bert_test_data.py +++ b/onnxruntime/python/tools/transformers/bert_test_data.py @@ -9,9 +9,8 @@ import argparse import os import random -import sys from pathlib import Path -from typing import Dict, List, Tuple, Union +from typing import Dict, Optional, Tuple import numpy as np from onnx import ModelProto, TensorProto, numpy_helper @@ -27,7 +26,7 @@ def fake_input_ids_data( input_ids (TensorProto): graph input of the input_ids input tensor batch_size (int): batch size sequence_length (int): sequence length - dictionary_size (int): vacaburary size of dictionary + dictionary_size (int): vocabulary size of dictionary Returns: np.ndarray: the input tensor created @@ -115,28 +114,28 @@ def fake_input_mask_data( return data -def output_test_data(dir: str, inputs: np.ndarray): +def output_test_data(directory: str, inputs: Dict[str, np.ndarray]): """Output input tensors of test data to a directory Args: - dir (str): path of a directory - inputs (numpy.ndarray): numpy array + directory (str): path of a directory + inputs (Dict[str, np.ndarray]): map from input name to value """ - if not os.path.exists(dir): + if not os.path.exists(directory): try: - os.mkdir(dir) + os.mkdir(directory) except OSError: - print("Creation of the directory %s failed" % dir) + print("Creation of the directory %s failed" % directory) else: - print("Successfully created the directory %s " % dir) + print("Successfully created the directory %s " % directory) else: - print("Warning: directory %s existed. Files will be overwritten." % dir) + print("Warning: directory %s existed. Files will be overwritten." % directory) index = 0 for name, data in inputs.items(): tensor = numpy_helper.from_array(data, name) - with open(os.path.join(dir, "input_{}.pb".format(index)), "wb") as f: - f.write(tensor.SerializeToString()) + with open(os.path.join(directory, "input_{}.pb".format(index)), "wb") as file: + file.write(tensor.SerializeToString()) index += 1 @@ -158,7 +157,7 @@ def fake_test_data( batch_size (int): batch size sequence_length (int): sequence length test_cases (int): number of test cases - dictionary_size (int): vocaburary size of dictionary for input_ids + dictionary_size (int): vocabulary size of dictionary for input_ids verbose (bool): print more information or not random_seed (int): random seed input_ids (TensorProto): graph input of input IDs @@ -167,7 +166,8 @@ def fake_test_data( random_mask_length (bool): whether mask random number of words at the end Returns: - List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value + List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary + with input name as key and a tensor as value """ assert input_ids is not None @@ -202,7 +202,7 @@ def generate_test_data( input_mask: TensorProto, random_mask_length: bool, ): - """Create given number of minput data for testing + """Create given number of input data for testing Args: batch_size (int): batch size @@ -216,7 +216,8 @@ def generate_test_data( random_mask_length (bool): whether mask random number of words at the end Returns: - List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value + List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary + with input name as key and a tensor as value """ dictionary_size = 10000 all_inputs = fake_test_data( @@ -251,12 +252,13 @@ def get_graph_input_from_embed_node(onnx_model, embed_node, input_index): def find_bert_inputs( onnx_model: OnnxModel, - input_ids_name: str = None, - segment_ids_name: str = None, - input_mask_name: str = None, -) -> Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: + input_ids_name: Optional[str] = None, + segment_ids_name: Optional[str] = None, + input_mask_name: Optional[str] = None, +) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: """Find graph inputs for BERT model. - First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming. + First, we will deduce inputs from EmbedLayerNormalization node. + If not found, we will guess the meaning of graph inputs based on naming. Args: onnx_model (OnnxModel): onnx model object @@ -266,10 +268,12 @@ def find_bert_inputs( Raises: ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name - ValueError: Exptected graph input number does not match with specifeid input_ids_name, segment_ids_name and input_mask_name + ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name + and input_mask_name Returns: - Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: input tensors of input_ids, segment_ids and input_mask + Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids, + segment_ids and input_mask """ graph_inputs = onnx_model.get_graph_inputs_excluding_initializers() @@ -340,12 +344,13 @@ def find_bert_inputs( def get_bert_inputs( onnx_file: str, - input_ids_name: str = None, - segment_ids_name: str = None, - input_mask_name: str = None, -): + input_ids_name: Optional[str] = None, + segment_ids_name: Optional[str] = None, + input_mask_name: Optional[str] = None, +) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: """Find graph inputs for BERT model. - First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming. + First, we will deduce inputs from EmbedLayerNormalization node. + If not found, we will guess the meaning of graph inputs based on naming. Args: onnx_file (str): onnx model path @@ -354,11 +359,12 @@ def get_bert_inputs( input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None. Returns: - Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: input tensors of input_ids, segment_ids and input_mask + Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids, + segment_ids and input_mask """ model = ModelProto() - with open(onnx_file, "rb") as f: - model.ParseFromString(f.read()) + with open(onnx_file, "rb") as file: + model.ParseFromString(file.read()) onnx_model = OnnxModel(model) return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name) @@ -447,9 +453,9 @@ def create_and_save_test_data( test_cases: int, seed: int, verbose: bool, - input_ids_name: str, - segment_ids_name: str, - input_mask_name: str, + input_ids_name: Optional[str], + segment_ids_name: Optional[str], + input_mask_name: Optional[str], only_input_tensors: bool, ): """Create test data for a model, and save test data to a directory. @@ -482,24 +488,24 @@ def create_and_save_test_data( ) for i, inputs in enumerate(all_inputs): - dir = os.path.join(output_dir, "test_data_set_" + str(i)) - output_test_data(dir, inputs) + directory = os.path.join(output_dir, "test_data_set_" + str(i)) + output_test_data(directory, inputs) if only_input_tensors: return import onnxruntime - sess = onnxruntime.InferenceSession(model) - output_names = [output.name for output in sess.get_outputs()] + session = onnxruntime.InferenceSession(model) + output_names = [output.name for output in session.get_outputs()] for i, inputs in enumerate(all_inputs): - dir = os.path.join(output_dir, "test_data_set_" + str(i)) - result = sess.run(output_names, inputs) + directory = os.path.join(output_dir, "test_data_set_" + str(i)) + result = session.run(output_names, inputs) for i, output_name in enumerate(output_names): - tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_names[i]) - with open(os.path.join(dir, "output_{}.pb".format(i)), "wb") as f: - f.write(tensor_result.SerializeToString()) + tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_name) + with open(os.path.join(directory, "output_{}.pb".format(i)), "wb") as file: + file.write(tensor_result.SerializeToString()) def main(): diff --git a/onnxruntime/python/tools/transformers/convert_beam_search.py b/onnxruntime/python/tools/transformers/convert_beam_search.py index ddfc43e3c8..70283d2852 100644 --- a/onnxruntime/python/tools/transformers/convert_beam_search.py +++ b/onnxruntime/python/tools/transformers/convert_beam_search.py @@ -6,11 +6,20 @@ This converts GPT2 or T5 model to onnx with beam search operator. Example 1: convert gpt2 model with beam search: - python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores - -Example 2: convert T5 model with beam search: - python ./models/t5/convert_to_onnx.py -m t5-small -s - python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx + python convert_beam_search.py -m gpt2 --decoder_onnx ./onnx_models/gpt2_past_fp32.onnx \ + --output ./onnx_models/gpt2_beam_search.onnx --output_sequences_scores + +Example 2: convert T5 model with beam search in two steps: + cd ./models/t5 + python convert_to_onnx.py -m t5-small + cd ../.. + python convert_beam_search.py -m t5-small --model_type t5 \ + --decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \ + --encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \ + --output ./models/t5/onnx_models/t5_small_beam_search.onnx + +Example 3: convert T5 model with beam search. All in one step: + python convert_beam_search.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx """ import argparse @@ -19,27 +28,36 @@ import os import sys import time from pathlib import Path -from typing import List, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import onnx import torch from benchmark_helper import Precision -from onnx import helper from onnx import onnx_pb as onnx_proto -from packaging import version -from transformers import GPT2Config, T5Config +from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, T5Config, T5ForConditionalGeneration, T5Tokenizer + +from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2")) -from convert_to_onnx import main as convert_gpt2_to_onnx -from gpt2_helper import PRETRAINED_GPT2_MODELS +from gpt2_helper import PRETRAINED_GPT2_MODELS # noqa: E402 +from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx # noqa: E402 -config: Union[GPT2Config, T5Config] = None +sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5")) +from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models # noqa: E402 logger = logging.getLogger("") -def parse_arguments(argv=None): +def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: + """Parse arguments + + Args: + argv (Optional[List[str]], optional): _description_. Defaults to None. + + Returns: + argparse.Namespace: Parsed arguments. + """ parser = argparse.ArgumentParser() parser.add_argument( @@ -69,9 +87,10 @@ def parse_arguments(argv=None): parser.add_argument( "--decoder_onnx", - required=True, + required=False, type=str, - help="Output directory for decoder onnx model, or model path ends with .onnx", + default="", + help="Path of onnx model for decoder. Required for gpt2 model type.", ) parser.add_argument( @@ -79,14 +98,14 @@ def parse_arguments(argv=None): required=False, type=str, default="", - help="path of ONNX model for encoder and decoder initialization. Required for t5 model type.", + help="Path of ONNX model for encoder and decoder initialization. For t5 model type.", ) parser.add_argument( "--output", required=False, type=str, - help="Output directory for beam search model, or model path ends with .onnx", + help="Output path for onnx model with beam search.", ) parser.add_argument( @@ -113,6 +132,14 @@ def parse_arguments(argv=None): ) parser.set_defaults(disable_parity=False) + parser.add_argument( + "--verbose", + required=False, + action="store_true", + help="Print more information", + ) + parser.set_defaults(verbose=False) + parser.add_argument( "--torch_performance", required=False, @@ -185,7 +212,7 @@ def parse_arguments(argv=None): type=float, required=False, default=1, - help="Positive. >1 to penalize and <1 to encorage short sentence.", + help="Positive. >1 to penalize and <1 to encourage short sentence.", ) beam_search_group.add_argument( @@ -193,7 +220,7 @@ def parse_arguments(argv=None): type=float, required=False, default=1, - help="Positive. >1 to penalize and <1 to encorage.", + help="Positive. >1 to penalize and <1 to encourage.", ) beam_search_group.add_argument( @@ -217,10 +244,14 @@ def parse_arguments(argv=None): return args -def gpt2_to_onnx(args): +def gpt2_to_onnx(args: argparse.Namespace): + """Convert GPT-2 model to onnx + + Args: + args (argparse.Namespace): arguments parsed from command line + """ model_name = args.model_name_or_path - print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...") arguments = [ "--model_name_or_path", model_name, @@ -233,7 +264,7 @@ def gpt2_to_onnx(args): "1", "--test_cases", "10", - "--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, postion_ids and attention_mask + "--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, position_ids and attention_mask ] if args.use_gpu: arguments.append("--use_gpu") @@ -242,39 +273,78 @@ def gpt2_to_onnx(args): if args.precision == Precision.FLOAT16: assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu" - # TODO: Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision') + # TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision') # Need change cuda kernel to support a combination of fp32 logits and fp16 past state. # Currently logits and past state shall be same data type. arguments.extend(["--op_block_list", "Add", "LayerNormalization", "FastGelu"]) - convert_gpt2_to_onnx(arguments) + + if args.verbose: + print(f"arguments for convert_to_onnx:{arguments}") + + convert_gpt2_to_onnx(argv=arguments) -def shape_inference(decoder_onnx_path): - if version.parse(onnx.__version__) >= version.parse("1.11.0"): - logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.") +def t5_to_onnx(args: argparse.Namespace): + """Convert T5 model to onnx + Args: + args (argparse.Namespace): arguments parsed from command line + """ + paths = export_t5_onnx_models( + args.model_name_or_path, + args.cache_dir, + Path(args.output).parent, + use_gpu=args.use_gpu, + use_external_data_format=args.use_external_data_format, + optimize_onnx=False, + precision=args.precision, + verbose=False, + use_decoder_start_token=False, + merge_encoder_and_decoder_init=True, + overwrite=True, + disable_auto_mixed_precision=False, + use_int32_inputs=True, + ) + args.encoder_decoder_init_onnx = paths[0] + args.decoder_onnx = paths[1] + + +def shape_inference(onnx_path: str): + """Shape inference on an onnx file, which will be overwritten. + + Args: + onnx_path (str): Path of onnx model + """ # Run symbolic shape inference to walk around ORT shape inference issue for subgraph. from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference - out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False) + out = SymbolicShapeInference.infer_shapes(onnx.load(onnx_path), auto_merge=True, guess_output_rank=False) if out: - # TODO: Use external format if input has extra data. - onnx.save(out, decoder_onnx_path) + # TODO(tianleiwu): Use external format if input has extra data. + onnx.save(out, onnx_path) else: print("Failed to run symbolic shape inference on the model.") -def create_ort_session(model_path, use_gpu): - from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions - from onnxruntime import __version__ as ort_version - from onnxruntime import get_available_providers +def create_ort_session(model_path: str, use_gpu: bool) -> InferenceSession: + """Create OnnxRuntime session. + Args: + model_path (str): onnx model path + use_gpu (bool): use GPU or not + + Raises: + RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified. + + Returns: + onnxruntime.InferenceSession: The created session. + """ sess_options = SessionOptions() sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"] if use_gpu: if "CUDAExecutionProvider" not in get_available_providers(): - raise RuntimeError("CUDAExecutionProvider is not avaiable for --use_gpu!") + raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!") else: print("use CUDAExecutionProvider") @@ -282,11 +352,26 @@ def create_ort_session(model_path, use_gpu): return ort_session -def verify_gpt2_subgraph(graph, precision): +def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision): + """Verify GPT-2 subgraph + + Args: + graph (onnx.GraphProto): onnx graph of GPT-2 + precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. + + Raises: + ValueError: Number of inputs not expected. + ValueError: Input name is not expected. + ValueError: Input data type is not expected. + ValueError: Number of outputs not expected. + ValueError: Output name is not expected. + ValueError: Output data type is not expected. + """ is_float16 = Precision.FLOAT16 == precision input_count = len(graph.input) layer_count = input_count - 3 + assert layer_count >= 1 expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)] if len(graph.input) != len(expected_inputs): @@ -300,10 +385,9 @@ def verify_gpt2_subgraph(graph, precision): if i >= 3: expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT - if graph.input[i].type.tensor_type.elem_type != expected_type: - raise ValueError( - f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.input[i].type.tensor_type.elem_type}" - ) + input_type = graph.input[i].type.tensor_type.elem_type + if input_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") print("Verifying GPT-2 graph inputs: name and data type are good.") expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)] @@ -315,49 +399,197 @@ def verify_gpt2_subgraph(graph, precision): raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT - if graph.output[i].type.tensor_type.elem_type != expected_type: - raise ValueError( - f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}" - ) + output_type = graph.output[i].type.tensor_type.elem_type + if output_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}") print("Verifying GPT-2 graph outputs: name and data type are good.") - # TODO: verify shapes of inputs and outputs. + # TODO(tianleiwu): verify shapes of inputs and outputs. return -def verify_t5_decoder_subgraph(graph, precision): - # TODO: implement it - pass +def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision): + """Verify T5 decoder subgraph + + Args: + graph (onnx.GraphProto): onnx graph of T5 decoder + precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. + + Raises: + ValueError: Number of inputs not expected. + ValueError: Input name is not expected. + ValueError: Input data type is not expected. + ValueError: Number of outputs not expected. + ValueError: Output name is not expected. + ValueError: Output data type is not expected. + """ + is_float16 = Precision.FLOAT16 == precision + float_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT + + input_count = len(graph.input) + layer_count = (input_count - 3) // 4 + assert layer_count >= 1 + + # Expect inputs: + # input_ids: int32 (B, 1) + # encoder_attention_mask: int32 (B, encode_sequence_length) + # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + + # past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size) + # past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size) + # ... (for each self attention layer) + + # past_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + # past_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + # ... (for each cross attention layer) + expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"] + for i in range(layer_count): + expected_inputs.append(f"past_key_self_{i}") + expected_inputs.append(f"past_value_self_{i}") + for i in range(layer_count): + expected_inputs.append(f"past_key_cross_{i}") + expected_inputs.append(f"past_value_cross_{i}") + + if len(graph.input) != len(expected_inputs): + raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") + + for i, expected_input in enumerate(expected_inputs): + if graph.input[i].name != expected_input: + raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") + + expected_type = onnx_proto.TensorProto.INT32 if i < 2 else float_type + input_type = graph.input[i].type.tensor_type.elem_type + if input_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") + + # Expect outputs: + # logits: (B, 1, vocab_size) + # present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) + # present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size) + # ... (for each self attention layer) + expected_outputs = ["logits"] + for i in range(layer_count): + expected_outputs.append(f"present_key_self_{i}") + expected_outputs.append(f"present_value_self_{i}") + + if len(graph.output) != len(expected_outputs): + raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") + + for i, expected_output in enumerate(expected_outputs): + if graph.output[i].name != expected_output: + raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") + output_type = graph.output[i].type.tensor_type.elem_type + if output_type != float_type: + raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}") -def verify_t5_encoder_decoder_init_subgraph(graph, precision): - # TODO: implement it - pass +def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision): + """Verify T5 decoder subgraph + + Args: + graph (onnx.GraphProto): onnx graph of T5 decoder + precision (Precision): Precision (FLOAT16 or FLOAT32) of the model. + + Raises: + ValueError: Number of inputs not expected. + ValueError: Input name is not expected. + ValueError: Input data type is not expected. + ValueError: Number of outputs not expected. + ValueError: Output name is not expected. + ValueError: Output data type is not expected. + """ + is_float16 = Precision.FLOAT16 == precision + layer_count = (len(graph.output) - 2) // 4 + assert layer_count >= 1 + + # Expect 3 inputs: + # encoder_input_ids: int32 (B, encode_sequence_length) + # encoder_attention_mask: int32 (B, encode_sequence_length) + # decoder_input_ids: int32 (B, 1) + expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"] + if len(graph.input) != len(expected_inputs): + raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}") + + for i, expected_input in enumerate(expected_inputs): + if graph.input[i].name != expected_input: + raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}") + + expected_type = onnx_proto.TensorProto.INT32 + input_type = graph.input[i].type.tensor_type.elem_type + if input_type != expected_type: + raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}") + + # Expected outputs: + # logits: (B, 1, vocab_size) + # encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size) + # present_key_self_0: (B, num_heads, 1, head_size) + # present_value_self_0: (B, num_heads, 1, head_size) + # ... (for each self attention layer) + # present_key_cross_0: (B, num_heads, encode_sequence_length, head_size) + # present_value_cross_0: (B, num_heads, encode_sequence_length, head_size) + # ... (for each cross attention layer) + expected_outputs = ["logits", "encoder_hidden_states"] + for i in range(layer_count): + expected_outputs.append(f"present_key_self_{i}") + expected_outputs.append(f"present_value_self_{i}") + for i in range(layer_count): + expected_outputs.append(f"present_key_cross_{i}") + expected_outputs.append(f"present_value_cross_{i}") + + if len(graph.output) != len(expected_outputs): + raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}") + + for i, expected_output in enumerate(expected_outputs): + if graph.output[i].name != expected_output: + raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}") + + expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT + output_type = graph.output[i].type.tensor_type.elem_type + if output_type != expected_type: + raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}") + + print("T5 encoder graph verified: name and data type of inputs and outputs are good.") -def convert_model(args): - if os.path.exists(args.decoder_onnx): - print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}") - else: - assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2" - gpt2_to_onnx(args) +def convert_model(args: argparse.Namespace): + """Convert model according to command line arguments. - # TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken. + Args: + args (argparse.Namespace): arguments parsed from command line + """ + is_gpt2: bool = args.model_type == "gpt2" + if is_gpt2: + if os.path.exists(args.decoder_onnx): + print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}") + else: + print(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...") + gpt2_to_onnx(args) + else: # t5 + if args.decoder_onnx and args.encoder_decoder_init_onnx: + print( + f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}" + ) + else: + print(f"Convert T5 model {args.model_name_or_path} to onnx ...") + t5_to_onnx(args) + + # TODO(tianleiwu): fix shape inference for T5. Currently symbolic shape inference on T5 is broken. enable_shape_inference = args.model_type == "gpt2" if enable_shape_inference: print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.") shape_inference(args.decoder_onnx) - global config - if args.model_type == "gpt2": + if is_gpt2: config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) else: config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) - print(config) + + if args.verbose: + print(config) eos_token_id = config.eos_token_id - pad_token_id = config.eos_token_id + pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id vocab_size = config.vocab_size # if vocab_size is given in parameters use that. @@ -394,7 +626,7 @@ def convert_model(args): assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") - node = helper.make_node( + node = onnx.helper.make_node( "BeamSearch", inputs=inputs, outputs=outputs, @@ -403,12 +635,12 @@ def convert_model(args): node.domain = "com.microsoft" node.attribute.extend( [ - helper.make_attribute("eos_token_id", eos_token_id), - helper.make_attribute("pad_token_id", pad_token_id), - helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), - helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), - helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), - helper.make_attribute("decoder", model.graph), + onnx.helper.make_attribute("eos_token_id", eos_token_id), + onnx.helper.make_attribute("pad_token_id", pad_token_id), + onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), + onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + onnx.helper.make_attribute("decoder", model.graph), ] ) @@ -421,22 +653,25 @@ def convert_model(args): verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision) node.attribute.extend( [ - helper.make_attribute("encoder_decoder_init", init_model.graph), + onnx.helper.make_attribute("encoder", init_model.graph), + onnx.helper.make_attribute( + "decoder_start_token_id", config.decoder_start_token_id if len(init_model.graph.input) == 3 else -1 + ), ] ) from onnx import TensorProto # graph inputs - input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"]) - max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1]) - min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1]) - num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1]) - num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) - temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) - length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) - repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) - vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) + input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"]) + max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1]) + min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1]) + num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1]) + num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1]) + temperature = onnx.helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1]) + length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) + repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) + vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) graph_inputs = [ input_ids, @@ -451,23 +686,23 @@ def convert_model(args): ] if args.prefix_vocab_mask: - prefix_vocab_mask = helper.make_tensor_value_info( + prefix_vocab_mask = onnx.helper.make_tensor_value_info( "prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size] ) graph_inputs.append(prefix_vocab_mask) # graph outputs - sequences = helper.make_tensor_value_info( + sequences = onnx.helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"], ) - sequences_scores = helper.make_tensor_value_info( + sequences_scores = onnx.helper.make_tensor_value_info( "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] ) - scores = helper.make_tensor_value_info( + scores = onnx.helper.make_tensor_value_info( "scores", TensorProto.FLOAT, ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], @@ -483,7 +718,7 @@ def convert_model(args): if args.output_token_scores: graph_outputs.append(scores) - new_graph = helper.make_graph( + new_graph = onnx.helper.make_graph( [node], f"{args.model_type}-beam-search", graph_inputs, @@ -492,18 +727,44 @@ def convert_model(args): ) # Create the model - new_model = helper.make_model( + new_model = onnx.helper.make_model( new_graph, producer_name="onnxruntime.transformers", opset_imports=model.opset_import, ) + + # TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory. onnx.save(new_model, args.output) -def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, pad_token_id, bad_words_ids): +def test_torch_performance( + args: argparse.Namespace, + model: Union[GPT2LMHeadModel, T5ForConditionalGeneration], + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + eos_token_id: int, + pad_token_id: int, + bad_words_ids: List[List[int]], +) -> Dict[str, Any]: + """Test PyTorch performance of text generation. + + Args: + args (argparse.Namespace): arguments parsed from command line + model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model + input_ids (torch.Tensor): input_ids + attention_mask (torch.Tensor): Attention mask + eos_token_id (int): EOS token ID + pad_token_id (int): Padding token ID + bad_words_ids (List[List[int]]): Words shall not be generated. + + Raises: + RuntimeError: PyTorch with CUDA is not available for --use_gpu + + Returns: + Dict[str, Any]: A dictionary with string with metric name, and value can be integer or string. + """ if args.use_gpu and not torch.cuda.is_available(): - logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.") - return None + raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.") if args.precision == Precision.FLOAT16: model.half() @@ -543,23 +804,27 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, return get_latency_result(torch_latency, batch_size) -def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): - if args.model_type != "gpt2": - print( - f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime" - ) - return True +def test_gpt_model(args: argparse.Namespace, use_vocab_mask: bool = False, sentences: Optional[List[str]] = None): + """Test GPT-2 model + + Args: + args (argparse.Namespace): arguments parsed from command line + use_vocab_mask (bool, optional): use vocabulary mask. Defaults to False. + sentences (Optional[List[str]], optional): input text. Defaults to None. + + Returns: + Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. + """ + assert args.model_type == "gpt2" if args.temperature != 1.0: - # TODO: implement temperature in BeamSearch operator. + # TODO(tianleiwu): implement temperature in BeamSearch operator. print("Skipping parity test as temperature is not implemented in BeamSearch operator") - return True + return None if args.prefix_vocab_mask: print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") - return True - - from transformers import GPT2LMHeadModel, GPT2Tokenizer + return None tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) tokenizer.padding_side = "left" @@ -589,14 +854,200 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): if use_vocab_mask: print("bad_words_ids", bad_words_ids) else: - bad_words_ids = None + bad_words_ids = [] - global config config = model.config eos_token_id = config.eos_token_id pad_token_id = config.eos_token_id vocab_size = config.vocab_size + torch_decoded_sequences = [] + beam_outputs = None + if not args.disable_parity: + print("-" * 50) + print("Test PyTorch model and beam search with huggingface transformers...") + beam_outputs = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + max_length=args.max_length, + min_length=args.min_length, + num_beams=args.num_beams, + early_stopping=args.early_stopping, + no_repeat_ngram_size=args.no_repeat_ngram_size, + eos_token_id=eos_token_id, + pad_token_id=pad_token_id, + num_return_sequences=args.num_return_sequences, + temperature=args.temperature, + length_penalty=args.length_penalty, + repetition_penalty=args.repetition_penalty, + bad_words_ids=bad_words_ids, + return_dict_in_generate=True, + output_scores=args.output_sequences_scores or args.output_token_scores, + ) + print("input_ids", input_ids) + print("huggingface transformers outputs:") + print("sequences", beam_outputs.sequences) + if args.output_sequences_scores: + print("sequences_scores", beam_outputs.sequences_scores) + if args.output_token_scores: + print("scores", beam_outputs.scores) + for i, sequence in enumerate(beam_outputs.sequences): + decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True) + torch_decoded_sequences.append(decoded_sequence) + print(f"{i}: {decoded_sequence}") + + print("-" * 50) + print("Test ONNX model and bream search with onnxruntime...") + + ort_session = create_ort_session(args.output, args.use_gpu) + + vocab_mask = np.ones((vocab_size), dtype=np.int32) + if use_vocab_mask: + for bad_word_id in bad_words_ids: + vocab_mask[bad_word_id] = 0 + + inputs = { + "input_ids": input_ids.cpu().numpy().astype(np.int32), + "max_length": np.array([args.max_length], dtype=np.int32), + "min_length": np.array([args.min_length], dtype=np.int32), + "num_beams": np.array([args.num_beams], dtype=np.int32), + "num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32), + "temperature": np.array([args.temperature], dtype=np.float32), + "length_penalty": np.array([args.length_penalty], dtype=np.float32), + "repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32), + "vocab_mask": vocab_mask, + } + + print("inputs", inputs) + result = ort_session.run(None, inputs) + + test_data_dir = Path(args.output).parent.as_posix() + print("test_data_dir", test_data_dir) + from bert_test_data import output_test_data + + all_inputs = [inputs] + for i, inputs in enumerate(all_inputs): + dir = os.path.join(test_data_dir, "test_data_set_" + str(i)) + output_test_data(dir, inputs) + + # Test performance + latency = [] + for _ in range(args.total_runs): + start = time.time() + _ = ort_session.run(None, inputs) + latency.append(time.time() - start) + batch_size = input_ids.shape[0] + from benchmark_helper import get_latency_result + + output = get_latency_result(latency, batch_size) + + print("ORT outputs:") + sequences = result[0] + print("sequences", sequences) + if args.output_sequences_scores: + print("sequences_scores", result[1]) + if args.output_token_scores: + print("scores", result[2]) + + (batch_size, num_sequences, max_length) = sequences.shape + ort_decoded_sequences = [] + for i in range(batch_size): + for j in range(num_sequences): + decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True) + ort_decoded_sequences.append(decoded_sequence) + print(f"batch {i} sequence {j}: {decoded_sequence}") + + if beam_outputs: + torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1) + ort_sequences = torch.LongTensor(sequences) + print("-" * 50) + print("Torch Sequences:") + print(torch_sequences) + print(torch_decoded_sequences) + print("-" * 50) + print("ORT Sequences:") + print(ort_sequences) + print(ort_decoded_sequences) + print("-" * 50) + # Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not. + is_same = torch_decoded_sequences == ort_decoded_sequences + print("Torch and ORT result is ", "same" if is_same else "different") + output["parity"] = is_same + + if args.torch_performance: + torch_latency_output = test_torch_performance( + args, + model, + input_ids, + attention_mask, + eos_token_id, + pad_token_id, + bad_words_ids, + ) + print("Torch Latency", torch_latency_output) + + print("ORT", output) + return output + + +def test_t5_model(args: argparse.Namespace, use_vocab_mask: bool = False, sentences: Optional[List[str]] = None): + """Test T5 model + + Args: + args (argparse.Namespace): arguments parsed from command line + use_vocab_mask (bool, optional): use vocabulary mask. Defaults to False. + sentences (Optional[List[str]], optional): input text. Defaults to None. + + Returns: + Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. + """ + assert args.model_type == "t5" + + if args.temperature != 1.0: + # TODO(tianleiwu): implement temperature in BeamSearch operator. + print("Skipping parity test as temperature is not implemented in BeamSearch operator") + return None + + if args.prefix_vocab_mask: + print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face") + return None + + tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir) + tokenizer.padding_side = "left" + + model = T5ForConditionalGeneration.from_pretrained( + args.model_name_or_path, + cache_dir=args.cache_dir, + ) + + # Use different length sentences to test batching + if sentences is None: + sentences = [ + "translate English to French: The product is released", + "summarize: research continues to show that pets bring real health benefits to their owners." + + "Having a dog around can lead to lower levels of stress for both adults and kids.", + # "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. " + # + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.", + ] + + inputs = tokenizer(sentences, return_tensors="pt", padding=True) + input_ids = inputs["input_ids"] + attention_mask = inputs["attention_mask"] + + bad_words = "walk in park" + bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS) + bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list + if use_vocab_mask: + print("bad_words_ids", bad_words_ids) + else: + bad_words_ids = [] + + config = model.config + eos_token_id = config.eos_token_id + pad_token_id = config.pad_token_id + vocab_size = config.vocab_size + print(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}") + torch_decoded_sequences = [] if not args.disable_parity: print("-" * 50) @@ -619,6 +1070,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): return_dict_in_generate=True, output_scores=args.output_sequences_scores or args.output_token_scores, ) + print("input_ids", input_ids) print("huggingface transformers outputs:") print("sequences", beam_outputs.sequences) @@ -724,17 +1176,44 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None): return output -def main(argv=None, sentences=None): +def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None): + """Main entry function + + Args: + argv (Optional[List[str]], optional): _description_. Defaults to None. + sentences (Optional[List[str]], optional): input text. Defaults to None. + + Raises: + ValueError: --decoder_onnx is not specified for GPT2 model + ValueError: Path does not exist: --encoder_decoder_init_onnx + ValueError: Path does not exist: --decoder_onnx + ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5 + + Returns: + Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string. + """ + args = parse_arguments(argv) + + if args.model_type == "gpt2": + if not args.decoder_onnx: + raise ValueError("--decoder_onnx shall be specified for gpt2 model") + elif args.model_type == "t5": + if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx): + raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}") + if args.decoder_onnx and not os.path.exists(args.decoder_onnx): + raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}") + if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or ( + args.decoder_onnx and not args.encoder_decoder_init_onnx + ): + raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx") + + convert_model(args) + if args.model_type == "t5": - assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool" - - if os.path.exists(args.output): - print(f"skip conversion since path existed: {args.output}") + return test_t5_model(args, use_vocab_mask=True, sentences=sentences) else: - convert_model(args) - - return test_model(args, use_vocab_mask=True, sentences=sentences) + return test_gpt_model(args, use_vocab_mask=True, sentences=sentences) if __name__ == "__main__": diff --git a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py index 231f08cfc4..5eb381b038 100644 --- a/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/gpt2/convert_to_onnx.py @@ -1,5 +1,4 @@ # ------------------------------------------------------------------------- -# ------------------------------------------------------------------------- # Copyright (c) Microsoft Corporation. All rights reserved. # Licensed under the MIT License. See License.txt in the project root for # license information. diff --git a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py index 72be075089..7095ab244a 100644 --- a/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/t5/convert_to_onnx.py @@ -14,7 +14,7 @@ import torch from t5_helper import PRETRAINED_T5_MODELS, T5Helper sys.path.append(os.path.join(os.path.dirname(__file__), "..", "..")) -from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger +from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger # noqa: E402 logger = logging.getLogger("") @@ -80,7 +80,7 @@ def parse_arguments(): "--use_decoder_start_token", required=False, action="store_true", - help="Use config.decoder_start_token_id in decoding. Otherwise, add an extra graph input for decoder_input_ids.", + help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.", ) parser.set_defaults(use_decoder_start_token=False) @@ -101,6 +101,22 @@ def parse_arguments(): ) parser.set_defaults(disable_auto_mixed_precision=False) + parser.add_argument( + "--separate_encoder_and_decoder_init", + required=False, + action="store_true", + help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.", + ) + parser.set_defaults(separate_encoder_and_decoder_init=False) + + parser.add_argument( + "--use_int64_inputs", + required=False, + action="store_true", + help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.", + ) + parser.set_defaults(use_int64_inputs=False) + args = parser.parse_args() return args @@ -115,10 +131,11 @@ def export_onnx_models( optimize_onnx, precision, verbose, - use_decoder_start_token: bool = True, + use_decoder_start_token: bool = False, merge_encoder_and_decoder_init: bool = True, overwrite: bool = False, disable_auto_mixed_precision: bool = False, + use_int32_inputs: bool = True, ): device = torch.device("cuda:0" if use_gpu else "cpu") @@ -126,7 +143,7 @@ def export_onnx_models( config = models["decoder"].config if (not use_external_data_format) and (config.num_layers > 24): - logger.info(f"Try use_external_data_format when model size > 2GB") + logger.info("Try use_external_data_format when model size > 2GB") output_paths = [] for name, model in models.items(): @@ -151,6 +168,7 @@ def export_onnx_models( verbose, use_external_data_format, use_decoder_input_ids=not use_decoder_start_token, + use_int32_inputs=use_int32_inputs, ) else: logger.info(f"Skip exporting: existed ONNX model {onnx_path}") @@ -185,10 +203,12 @@ def export_onnx_models( use_gpu=use_gpu, provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"], ) - max_diff = T5Helper.verify_onnx(model, ort_session, device) + + with torch.no_grad(): + max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs) logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}") if max_diff > 1e-4: - logger.warn(f"PyTorch and OnnxRuntime results are NOT close") + logger.warning("PyTorch and OnnxRuntime results are NOT close") output_paths.append(output_path) @@ -212,24 +232,23 @@ def main(): assert args.use_gpu, "fp16 requires --use_gpu" if args.optimize_onnx: - logger.warn(f"Graph optimization for T5 is not implemented yet.") + logger.warning("Graph optimization for T5 is not implemented yet.") - with torch.no_grad(): - merge_encoder_and_decoder_init = True # Merge encoder and decoder initialization into one model is recommended. - output_paths = export_onnx_models( - args.model_name_or_path, - cache_dir, - output_dir, - args.use_gpu, - args.use_external_data_format, - args.optimize_onnx, - args.precision, - args.verbose, - args.use_decoder_start_token, - merge_encoder_and_decoder_init, - args.overwrite, - args.disable_auto_mixed_precision, - ) + output_paths = export_onnx_models( + args.model_name_or_path, + cache_dir, + output_dir, + args.use_gpu, + args.use_external_data_format, + args.optimize_onnx, + args.precision, + args.verbose, + args.use_decoder_start_token, + not args.separate_encoder_and_decoder_init, + args.overwrite, + args.disable_auto_mixed_precision, + not args.use_int64_inputs, + ) logger.info(f"Done! Outputs: {output_paths}") diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py index 0f98b9f7ce..76941fcf41 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py @@ -135,6 +135,7 @@ class T5DecoderInputs: past_decode_sequence_length: int, device: torch.device, float16: bool = False, + use_int32_inputs: bool = False, ): # -> T5DecoderInputs: """Create dummy inputs for T5Decoder. @@ -145,6 +146,7 @@ class T5DecoderInputs: past_decode_sequence_length (int): past sequence length of input_ids for decoder device (torch.device): device of output tensors float16 (bool): whether the model uses float32 or float16 in input + use_int32_inputs(bool): whether use int32 instead of int64 for some inputs Returns: T5DecoderInputs: dummy inputs for decoder @@ -159,11 +161,17 @@ class T5DecoderInputs: low=0, high=vocab_size - 1, size=(batch_size, sequence_length), - dtype=torch.int64, + dtype=(torch.int32 if use_int32_inputs else torch.int64), device=device, ) - encoder_inputs = T5EncoderInputs.create_dummy(batch_size, encode_sequence_length, vocab_size, device) + encoder_inputs = T5EncoderInputs.create_dummy( + batch_size, + encode_sequence_length, + vocab_size, + device, + use_int32_inputs=use_int32_inputs, + ) float_type = torch.float16 if float16 else torch.float32 encoder_hidden_state = torch.rand( @@ -211,7 +219,7 @@ class T5DecoderInputs: def to_fp32(self): encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32) - past = [p.to(dtype=torch.float32) for p in self.past_key_values] + past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None return T5DecoderInputs( self.decoder_input_ids.clone(), self.encoder_attention_mask.clone(), @@ -228,6 +236,7 @@ class T5DecoderHelper: onnx_model_path: str, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export decoder to ONNX @@ -237,6 +246,7 @@ class T5DecoderHelper: onnx_model_path (str): onnx path verbose (bool, optional): print verbose information. Defaults to True. use_external_data_format (bool, optional): use external data format or not. Defaults to False. + use_int32_inputs (bool, optional): use int32 inputs """ assert isinstance(decoder, (T5Decoder, T5DecoderInit)) @@ -246,10 +256,9 @@ class T5DecoderHelper: encode_sequence_length=3, past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0, device=device, + use_int32_inputs=use_int32_inputs, ) input_list = inputs.to_list() - with torch.no_grad(): - outputs = decoder(*input_list) past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False) present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True) @@ -351,7 +360,8 @@ class T5DecoderHelper: model: Union[T5Decoder, T5DecoderInit], ort_session: InferenceSession, device: torch.device, - max_cases=4, + use_int32_inputs: bool, + max_cases: int = 4, ): """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)" @@ -373,6 +383,7 @@ class T5DecoderHelper: past_decode_sequence_length, device=device, float16=float16, + use_int32_inputs=use_int32_inputs, ) # We use fp32 PyTroch model as baseline even when ONNX model is fp16 diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py index 72a6b8585a..b9861b4568 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder.py @@ -42,28 +42,30 @@ class T5EncoderInputs: @staticmethod def create_dummy( - batch_size: int, sequence_length: int, vocab_size: int, device: torch.device + batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False ): # -> T5EncoderInputs """Create dummy inputs for T5 encoder. Args: batch_size (int): batch size sequence_length (int): sequence length - vocab_size (int): vocaburary size + vocab_size (int): vocabulary size device (torch.device): device of output tensors Returns: T5EncoderInputs: dummy inputs for encoder """ + dtype = torch.int32 if use_int32_inputs else torch.int64 + input_ids = torch.randint( low=0, high=vocab_size - 1, size=(batch_size, sequence_length), - dtype=torch.int64, + dtype=dtype, device=device, ) - attention_mask = torch.ones([batch_size, sequence_length], dtype=torch.int64, device=device) + attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device) if sequence_length >= 2: for i in range(batch_size): padding_position = random.randint(0, sequence_length - 1) @@ -83,6 +85,7 @@ class T5EncoderHelper: onnx_model_path: str, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export encoder to ONNX @@ -95,12 +98,13 @@ class T5EncoderHelper: """ config = encoder.config encoder_inputs = T5EncoderInputs.create_dummy( - batch_size=2, sequence_length=4, vocab_size=config.vocab_size, device=device + batch_size=2, + sequence_length=4, + vocab_size=config.vocab_size, + device=device, + use_int32_inputs=use_int32_inputs, ) - with torch.no_grad(): - outputs = encoder(encoder_inputs.input_ids, encoder_inputs.attention_mask) - Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) torch_onnx_export( encoder, @@ -131,13 +135,16 @@ class T5EncoderHelper: return ort_session.run(None, ort_inputs) @staticmethod - def verify_onnx(model: T5Encoder, ort_session: InferenceSession, device: torch.device): + def verify_onnx( + model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False + ): """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" inputs = T5EncoderInputs.create_dummy( batch_size=4, sequence_length=11, vocab_size=model.config.vocab_size, device=device, + use_int32_inputs=use_int32_inputs, ) input_list = inputs.to_list() torch_outputs = model(*input_list) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py index fe7552809f..b3e28c9d05 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py @@ -7,10 +7,12 @@ import logging import os import sys +import tempfile from pathlib import Path -from typing import List +from typing import List, Optional import numpy +import onnx import torch from past_helper import PastKeyValuesHelper from t5_decoder import T5DecoderInit @@ -34,7 +36,7 @@ class T5EncoderDecoderInit(torch.nn.Module): decoder: torch.nn.Module, lm_head: torch.nn.Module, config: T5Config, - decoder_start_token_id: int = None, + decoder_start_token_id: Optional[int] = None, ): super().__init__() self.config = config @@ -67,15 +69,19 @@ class T5EncoderDecoderInitInputs: encode_sequence_length: int, use_decoder_input_ids: int, device: torch.device, + use_int32_inputs: bool = False, ): # -> T5EncoderDecoderInitInputs: encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy( - batch_size, encode_sequence_length, config.vocab_size, device + batch_size, + encode_sequence_length, + config.vocab_size, + device, + use_int32_inputs=use_int32_inputs, ) decoder_input_ids = None if use_decoder_input_ids: - decoder_input_ids = ( - torch.ones((batch_size, 1), dtype=torch.long, device=device) * config.decoder_start_token_id - ) + dtype = torch.int32 if use_int32_inputs else torch.int64 + decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids) @@ -95,6 +101,7 @@ class T5EncoderDecoderInitHelper: use_decoder_input_ids: bool = True, verbose: bool = True, use_external_data_format: bool = False, + use_int32_inputs: bool = False, ): """Export decoder to ONNX @@ -113,9 +120,9 @@ class T5EncoderDecoderInitHelper: encode_sequence_length=3, use_decoder_input_ids=use_decoder_input_ids, device=device, + use_int32_inputs=use_int32_inputs, ) input_list = inputs.to_list() - outputs = model(*input_list) present_names = PastKeyValuesHelper.get_past_names(model.config.num_layers, present=True) @@ -135,7 +142,8 @@ class T5EncoderDecoderInitHelper: input_names = ["encoder_input_ids", "encoder_attention_mask"] - # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2'. Use more friendly string here. + # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference. + # We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value. sequence_length = "1" num_heads = str(model.config.num_heads) hidden_size = str(model.config.d_model) @@ -149,12 +157,18 @@ class T5EncoderDecoderInitHelper: 1: "encode_sequence_length", 2: hidden_size, }, - "logits": {0: "batch_size", 1: sequence_length}, + "logits": { + 0: "batch_size", + 1: sequence_length, + }, } if use_decoder_input_ids: input_names.append("decoder_input_ids") - dynamic_axes["decoder_input_ids"] = {0: "batch_size", 1: sequence_length} + dynamic_axes["decoder_input_ids"] = { + 0: "batch_size", + 1: sequence_length, + } for name in present_names: if "cross" in name: @@ -173,20 +187,47 @@ class T5EncoderDecoderInitHelper: 3: head_size, } - Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) - torch_onnx_export( - model, - args=tuple(input_list), - f=onnx_model_path, - export_params=True, - input_names=input_names, - output_names=output_names, - dynamic_axes=dynamic_axes, - opset_version=12, - do_constant_folding=True, - use_external_data_format=use_external_data_format, - verbose=verbose, - ) + with tempfile.TemporaryDirectory() as tmp_dir_name: + temp_onnx_model_path = os.path.join(tmp_dir_name, "model.onnx") + Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True) + torch_onnx_export( + model, + args=tuple(input_list), + f=temp_onnx_model_path, + export_params=True, + input_names=input_names, + output_names=output_names, + dynamic_axes=dynamic_axes, + opset_version=12, + do_constant_folding=True, + use_external_data_format=use_external_data_format, + verbose=verbose, + ) + + # Workaround as mentioned earlier: change numeric dim_param to dim_value + model = onnx.load(temp_onnx_model_path) + for tensor in model.graph.output: + for dim_proto in tensor.type.tensor_type.shape.dim: + if dim_proto.HasField("dim_param") and dim_proto.dim_param in [ + sequence_length, + num_heads, + hidden_size, + head_size, + ]: + dim_value = int(dim_proto.dim_param) + dim_proto.Clear() + dim_proto.dim_value = dim_value + + Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True) + onnx.save_model( + model, + onnx_model_path, + save_as_external_data=use_external_data_format, + all_tensors_to_one_file=True, + location=onnx_model_path + ".data", + size_threshold=4096, + convert_attribute=False, + ) @staticmethod def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs): @@ -208,7 +249,8 @@ class T5EncoderDecoderInitHelper: model: T5EncoderDecoderInit, ort_session: InferenceSession, device: torch.device, - max_cases=4, + use_int32_inputs: bool, + max_cases: int = 4, ): """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" ort_inputs = ort_session.get_inputs() @@ -223,6 +265,7 @@ class T5EncoderDecoderInitHelper: encode_sequence_length, use_decoder_input_ids=use_decoder_input_ids, device=device, + use_int32_inputs=use_int32_inputs, ) ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs) diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py index 389e6b33a3..5505203f12 100644 --- a/onnxruntime/python/tools/transformers/models/t5/t5_helper.py +++ b/onnxruntime/python/tools/transformers/models/t5/t5_helper.py @@ -26,7 +26,7 @@ from optimizer import optimize_model logger = logging.getLogger(__name__) -PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3B", "t5-11B"] +PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"] class T5Helper: @@ -110,9 +110,17 @@ class T5Helper: verbose: bool = True, use_external_data_format: bool = False, use_decoder_input_ids: bool = True, + use_int32_inputs: bool = False, ): if isinstance(model, T5Encoder): - T5EncoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format) + T5EncoderHelper.export_onnx( + model, + device, + onnx_model_path, + verbose, + use_external_data_format, + use_int32_inputs, + ) elif isinstance(model, T5EncoderDecoderInit): T5EncoderDecoderInitHelper.export_onnx( model, @@ -121,9 +129,17 @@ class T5Helper: use_decoder_input_ids, verbose, use_external_data_format, + use_int32_inputs, ) else: - T5DecoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format) + T5DecoderHelper.export_onnx( + model, + device, + onnx_model_path, + verbose, + use_external_data_format, + use_int32_inputs, + ) @staticmethod def auto_mixed_precision( @@ -234,11 +250,13 @@ class T5Helper: model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit], ort_session: InferenceSession, device: torch.device, + use_int32_inputs: bool, ): """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.""" if isinstance(model, T5Encoder): - return T5EncoderHelper.verify_onnx(model, ort_session, device) - elif isinstance(model, T5EncoderDecoderInit): - return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device) - else: - return T5DecoderHelper.verify_onnx(model, ort_session, device) + return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs) + + if isinstance(model, T5EncoderDecoderInit): + return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs) + + return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs) diff --git a/onnxruntime/python/tools/transformers/requirements.txt b/onnxruntime/python/tools/transformers/requirements.txt index b1908acdd3..a299aa2374 100644 --- a/onnxruntime/python/tools/transformers/requirements.txt +++ b/onnxruntime/python/tools/transformers/requirements.txt @@ -7,6 +7,7 @@ py3nvml packaging transformers >= 4.0 scipy +sentencepiece # please follow https://pytorch.org/ to install PyTorch for your OS torch >= 1.8 \ No newline at end of file diff --git a/onnxruntime/test/python/transformers/test_beam_search.py b/onnxruntime/test/python/transformers/test_beam_search.py index 56bf6a3b21..3aa3f8bdbf 100644 --- a/onnxruntime/test/python/transformers/test_beam_search.py +++ b/onnxruntime/test/python/transformers/test_beam_search.py @@ -10,63 +10,189 @@ import os import unittest import pytest +import torch from parity_utilities import find_transformers_source -if find_transformers_source(): +from onnxruntime import get_available_providers + +if find_transformers_source() and find_transformers_source(["models", "t5"]): + from benchmark_helper import Precision from convert_beam_search import main as run + from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models else: + from onnxruntime.transformers.benchmark_helper import Precision from onnxruntime.transformers.convert_beam_search import main as run + from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models -class TestBeamSearch(unittest.TestCase): +class TestBeamSearchGpt(unittest.TestCase): + """Test BeamSearch for GPT-2 model""" + def setUp(self): - # TODO: use a smaller model and enable tests in CI pipeline self.model_name = "gpt2" self.gpt2_onnx_path = os.path.join(".", "onnx_models", "gpt2_past_fp32_shape.onnx") self.beam_search_onnx_path = os.path.join(".", "onnx_models", "gpt2_beam_search.onnx") - self.cpu_params = f"-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0" + self.default_arguments = [ + f"-m {self.model_name}", + f"--decoder_onnx {self.gpt2_onnx_path}", + f"--output {self.beam_search_onnx_path}", + "--output_sequences_score", + "--repetition_penalty 2.0", + ] + self.sentences = [ + "The product is released", + "I enjoy walking in the park", + "Test best way to invest", + ] + self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() + self.remove_onnx_files() - def run_beam_search(self, arguments: str, sentences=None): - return run(arguments.split(), sentences=sentences) + def tearDown(self): + self.remove_onnx_files() + + def remove_onnx_files(self): + if os.path.exists(self.gpt2_onnx_path): + os.remove(self.gpt2_onnx_path) + + if os.path.exists(self.beam_search_onnx_path): + os.remove(self.beam_search_onnx_path) + + def run_beam_search(self, extra_arguments: str, sentences=None): + arguments = " ".join(self.default_arguments + [extra_arguments]).split() + + # Test CPU + result = run(arguments, sentences=self.sentences if sentences is None else sentences) + self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}") + + # Test GPU + if self.enable_cuda: + if "--use_gpu" not in arguments: + arguments.append("--use_gpu") + result = run(arguments, sentences=self.sentences if sentences is None else sentences) + self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}") + + os.remove(self.beam_search_onnx_path) @pytest.mark.slow - def test_cpu(self): - result = self.run_beam_search( - self.cpu_params + " --num_return_sequences 2", - sentences=["The product is released"], - ) - os.remove(self.gpt2_onnx_path) - os.remove(self.beam_search_onnx_path) - self.assertTrue(result["parity"], "ORT and PyTorch result is different") + def test_return_sequences(self): + for return_sequences in [1, 2]: + self.run_beam_search(f"--num_return_sequences {return_sequences}") @pytest.mark.slow def test_early_stopping(self): - result = self.run_beam_search(self.cpu_params + " --early_stopping") - os.remove(self.gpt2_onnx_path) - os.remove(self.beam_search_onnx_path) - self.assertTrue(result["parity"], "ORT and PyTorch result is different") + self.run_beam_search("--early_stopping") - @pytest.mark.slow - def test_temperature(self): - result = self.run_beam_search(self.cpu_params + " --temperature 0.5") - os.remove(self.gpt2_onnx_path) - os.remove(self.beam_search_onnx_path) - self.assertTrue(result["parity"], "ORT and PyTorch result is different") + # @pytest.mark.slow + # def test_temperature(self): + # self.run_beam_search("--temperature 0.5") @pytest.mark.slow def test_length_penalty(self): - result = self.run_beam_search(self.cpu_params + " --length_penalty 0.5") - os.remove(self.gpt2_onnx_path) - os.remove(self.beam_search_onnx_path) - self.assertTrue(result["parity"], "ORT and PyTorch result is different") + for length_penalty in [0.5, 2.0]: + self.run_beam_search(f"--length_penalty {length_penalty}") @pytest.mark.slow def test_no_repeat_ngram(self): for ngram_size in [1, 2]: - result = self.run_beam_search(self.cpu_params + f" --no_repeat_ngram_size {ngram_size}") - os.remove(self.gpt2_onnx_path) + self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}") + + +class TestBeamSearchT5(unittest.TestCase): + """Test BeamSearch for T5 model""" + + def setUp(self): + self.model_name = "t5-small" + self.decoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_decoder.onnx") + self.encoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_encoder_decoder_init.onnx") + self.beam_search_onnx_path = os.path.join(".", "onnx_models", "t5_small_beam_search.onnx") + self.default_arguments = [ + f"-m {self.model_name}", + "--model_type t5", + f"--decoder_onnx {self.decoder_onnx_path}", + f"--encoder_decoder_init_onnx {self.encoder_onnx_path}", + f"--output {self.beam_search_onnx_path}", + "--output_sequences_score", + "--repetition_penalty 2.0", + ] + + self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers() + + export_t5_onnx_models( + self.model_name, + os.path.join(".", "cache_models"), + os.path.join(".", "onnx_models"), + use_gpu=False, + use_external_data_format=False, + optimize_onnx=False, + precision=Precision.FLOAT32, + verbose=False, + use_decoder_start_token=False, + merge_encoder_and_decoder_init=True, + overwrite=True, + disable_auto_mixed_precision=False, + use_int32_inputs=True, + ) + + self.sentences = [ + "translate English to French: The product is released", + "summarize: research continues to show that pets bring real health benefits to their owners." + + "Having a dog around can lead to lower levels of stress for both adults and kids.", + ] + + if os.path.exists(self.beam_search_onnx_path): os.remove(self.beam_search_onnx_path) - self.assertTrue(result["parity"], "ORT and PyTorch result is different") + + def tearDown(self): + self.remove_onnx_files() + + def remove_onnx_files(self): + if os.path.exists(self.beam_search_onnx_path): + os.remove(self.beam_search_onnx_path) + + if os.path.exists(self.decoder_onnx_path): + os.remove(self.decoder_onnx_path) + + if os.path.exists(self.encoder_onnx_path): + os.remove(self.encoder_onnx_path) + + def run_beam_search(self, extra_arguments: str, sentences=None): + arguments = " ".join(self.default_arguments + [extra_arguments]).split() + + # Test CPU + result = run(arguments, sentences=self.sentences if sentences is None else sentences) + self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}") + + # Test GPU + if self.enable_cuda: + if "--use_gpu" not in arguments: + arguments.append("--use_gpu") + result = run(arguments, sentences=self.sentences if sentences is None else sentences) + self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}") + + os.remove(self.beam_search_onnx_path) + + @pytest.mark.slow + def test_return_sequences(self): + for return_sequences in [1, 2]: + self.run_beam_search(f"--num_return_sequences {return_sequences}") + + @pytest.mark.slow + def test_early_stopping(self): + self.run_beam_search("--early_stopping") + + # @pytest.mark.slow + # def test_temperature(self): + # self.run_beam_search("--temperature 0.5") + + @pytest.mark.slow + def test_length_penalty(self): + for length_penalty in [0.5, 2.0]: + self.run_beam_search(f"--length_penalty {length_penalty}") + + @pytest.mark.slow + def test_no_repeat_ngram(self): + for ngram_size in [1, 2]: + self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}") if __name__ == "__main__":