mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-31 23:27:43 +00:00
Support T5 in BeamSearch operator (#11450)
(1) Support T5 in BeamSearch operator, and add both CPU and CUDA implementation. (2) Change BeamSearch op: rename encoder_decoder_init attribute to encoder, and add decoder_start_token_id attribute (3) Update convert_to_onnx for T5 to use int32 instead of int64 inputs as default. (4) Add more tests in best_beam_search.py (5) fix ORT_ENFORCE of hypothesis_buffer_offset_ (6) Improve ONNX conversion: (a) Change encoder some dynamic axes to fixed dim value (b) add --separate_encoder_and_decoder_init (c) correct name t5-3B => t5-3b, t5-11B => t5-11b (d) Add --use_int32_inputs in convert t5 to onnx (e) Allow t5 beam search conversion in one step
This commit is contained in:
parent
768b9cfb60
commit
def78a1b81
47 changed files with 3708 additions and 1465 deletions
|
|
@ -355,10 +355,12 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dl>
|
||||
<dt><tt>decoder</tt> : graph (required)</dt>
|
||||
<dd>Decoder subgraph to execute in a loop.</dd>
|
||||
<dt><tt>decoder_start_token_id</tt> : int</dt>
|
||||
<dd>The id of the token that indicates decoding starts.</dd>
|
||||
<dt><tt>early_stopping</tt> : int</dt>
|
||||
<dd>early stop or not</dd>
|
||||
<dt><tt>encoder_decoder_init</tt> : graph</dt>
|
||||
<dd>subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
|
||||
<dt><tt>encoder</tt> : graph</dt>
|
||||
<dd>The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
|
||||
<dt><tt>eos_token_id</tt> : int (required)</dt>
|
||||
<dd>The id of the end-of-sequence token</dd>
|
||||
<dt><tt>model_type</tt> : int</dt>
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@
|
|||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
#include <memory>
|
||||
#include <assert.h>
|
||||
#include <functional>
|
||||
#include "core/common/safeint.h"
|
||||
|
|
@ -26,11 +27,13 @@
|
|||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "gsl/gsl"
|
||||
#include "beam_search.h"
|
||||
#include "logits_processor.h"
|
||||
#include "sequences.h"
|
||||
#include "dump_tensor.h"
|
||||
#include "beam_search_scorer.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search.h"
|
||||
#include "contrib_ops/cpu/transformers/logits_processor.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_impl_gpt.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_impl_t5.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
|
@ -53,590 +56,156 @@ REGISTER_KERNEL_TYPED(float)
|
|||
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
|
||||
BufferUniquePtr& buffer,
|
||||
size_t elements,
|
||||
bool fill = false,
|
||||
T fill_value = T{}) {
|
||||
size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
|
||||
void* data = allocator->Alloc(bytes);
|
||||
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
|
||||
buffer = std::move(temp_buffer);
|
||||
T* first = reinterpret_cast<T*>(buffer.get());
|
||||
auto span = gsl::make_span(first, elements);
|
||||
|
||||
if (fill) {
|
||||
std::fill_n(first, elements, fill_value);
|
||||
}
|
||||
|
||||
return span;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct BeamSearchState : public IBeamSearchState<T> {
|
||||
void Init(AllocatorPtr allocator,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
bool output_scores) {
|
||||
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
|
||||
|
||||
size_t next_token_size = SafeInt<size_t>(batch_beam_size) * vocab_size;
|
||||
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size);
|
||||
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size);
|
||||
|
||||
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size);
|
||||
|
||||
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size);
|
||||
|
||||
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size);
|
||||
|
||||
this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size);
|
||||
|
||||
if (output_scores) {
|
||||
size_t elements = SafeInt<size_t>(max_length - sequence_length) * batch_size * num_beams * vocab_size;
|
||||
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
|
||||
this->remaining_scores = this->scores;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr next_token_logits_buffer_;
|
||||
BufferUniquePtr next_token_scores_buffer_;
|
||||
BufferUniquePtr next_tokens_buffer_;
|
||||
BufferUniquePtr next_indices_buffer_;
|
||||
BufferUniquePtr next_positions_buffer_;
|
||||
BufferUniquePtr beam_scores_buffer_;
|
||||
BufferUniquePtr scores_buffer_;
|
||||
};
|
||||
|
||||
struct BeamSearchCpuState : public IBeamSearchCpuState {
|
||||
Sequences sequences;
|
||||
|
||||
void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, bool is_cuda) {
|
||||
this->sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size);
|
||||
this->sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, SafeInt<size_t>(2) * batch_beam_size * max_length);
|
||||
|
||||
if (is_cuda) {
|
||||
// buffers used by CUDA operator but not by CPU operator.
|
||||
this->topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * batch_beam_size);
|
||||
this->topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * batch_beam_size);
|
||||
this->topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * batch_beam_size);
|
||||
this->final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr final_beam_scores_buffer_;
|
||||
BufferUniquePtr sequence_lengths_buffer_;
|
||||
BufferUniquePtr topk_scores_buffer_;
|
||||
BufferUniquePtr topk_tokens_buffer_;
|
||||
BufferUniquePtr topk_indices_buffer_;
|
||||
BufferUniquePtr sequences_space_buffer_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class BeamSearchImpl {
|
||||
public:
|
||||
BeamSearchImpl(OpKernelContextInternal& context,
|
||||
const SessionState& session_state,
|
||||
GptSubgraph& gpt_subgraph,
|
||||
concurrency::ThreadPool* thread_pool,
|
||||
void* cuda_stream,
|
||||
IConsoleDumper* cuda_dumper,
|
||||
BeamSearchParameters& params,
|
||||
const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const BeamSearchDeviceHelper::TopkFunc& topk_func,
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const BeamSearchDeviceHelper::UpdateFeedsFunc<T>& update_feeds_func)
|
||||
: context_(context),
|
||||
session_state_(session_state),
|
||||
gpt_subgraph_(gpt_subgraph),
|
||||
thread_pool_(thread_pool),
|
||||
implicit_inputs_(context_.GetImplicitInputs()),
|
||||
cuda_stream_(cuda_stream),
|
||||
cuda_dumper_(cuda_dumper),
|
||||
parameters_(¶ms),
|
||||
cpu_allocator_(nullptr),
|
||||
temp_space_allocator_(nullptr),
|
||||
create_inputs_func_(create_inputs_func),
|
||||
add_to_feeds_func_(add_to_feeds_func),
|
||||
topk_func_(topk_func),
|
||||
process_logits_func_(process_logits_func),
|
||||
init_beam_state_func_(init_beam_state_func),
|
||||
device_copy_func_(device_copy_func),
|
||||
update_feeds_func_(update_feeds_func) {
|
||||
parameters_->ParseFromInputs(&context);
|
||||
|
||||
cpu_allocator_ = session_state.GetExecutionProviders()
|
||||
.Get(onnxruntime::kCpuExecutionProvider)
|
||||
->GetAllocator(0, OrtMemTypeDefault);
|
||||
}
|
||||
|
||||
// Initialize by validating all the inputs, and allocating the output tensors.
|
||||
Status Initialize();
|
||||
|
||||
// Execute beam search in iterations util stopping criteria is reached.
|
||||
// In each iteration, GPT subgraph is called, and next token for each sequence is generated.
|
||||
Status Execute(const FeedsFetchesManager& feeds_fetches_manager);
|
||||
|
||||
private:
|
||||
bool IsCuda() const { return cuda_stream_ != nullptr; }
|
||||
|
||||
// Validate inputs.
|
||||
Status CheckInputs(const OpKernelContextInternal& context);
|
||||
|
||||
// Prepare the inputs for first inference of subgraph
|
||||
Status CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths, OrtValue& expanded_input_ids, std::vector<OrtValue>& feeds, IAllocatorUniquePtr<char>& buffer);
|
||||
|
||||
// Update the input for next iteration.
|
||||
Status UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices);
|
||||
|
||||
// Process logits and append next tokens to sequences.
|
||||
Status GenerateNextToken(const OrtValue& logits,
|
||||
gsl::span<int32_t>& beam_next_tokens,
|
||||
gsl::span<int32_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
int counter);
|
||||
|
||||
// Calculate scores from logits, then apply filtering and select next token for each beam.
|
||||
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
AllocatorPtr& allocator,
|
||||
int counter);
|
||||
|
||||
const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); }
|
||||
|
||||
OpKernelContextInternal& context_;
|
||||
|
||||
const SessionState& session_state_;
|
||||
|
||||
GptSubgraph& gpt_subgraph_;
|
||||
|
||||
concurrency::ThreadPool* thread_pool_;
|
||||
|
||||
const std::vector<const OrtValue*>& implicit_inputs_;
|
||||
|
||||
void* cuda_stream_;
|
||||
|
||||
IConsoleDumper* cuda_dumper_;
|
||||
CpuTensorConsoleDumper cpu_dumper_;
|
||||
|
||||
BeamSearchParameters* parameters_;
|
||||
|
||||
LogitsProcessorList logits_processors_;
|
||||
|
||||
std::unique_ptr<BeamSearchScorer> beam_scorer_;
|
||||
|
||||
AllocatorPtr cpu_allocator_;
|
||||
AllocatorPtr temp_space_allocator_;
|
||||
|
||||
// Device specific functions
|
||||
BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
|
||||
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
BeamSearchDeviceHelper::TopkFunc topk_func_;
|
||||
BeamSearchDeviceHelper::ProcessLogitsFunc<T> process_logits_func_;
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
|
||||
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
|
||||
BeamSearchDeviceHelper::UpdateFeedsFunc<T> update_feeds_func_;
|
||||
};
|
||||
|
||||
void BeamSearch::Init(const OpKernelInfo& info) {
|
||||
// Make sure the decoder attribute was present even though we don't need it here.
|
||||
parameters_.ParseFromAttributes(info);
|
||||
|
||||
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5).
|
||||
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
|
||||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
|
||||
|
||||
ONNX_NAMESPACE::GraphProto proto;
|
||||
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
|
||||
}
|
||||
|
||||
// Make sure the decoder attribute was present even though we don't need it here.
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("decoder", &proto).IsOK());
|
||||
ORT_IGNORE_RETURN_VALUE(proto);
|
||||
|
||||
parameters_.ParseFromAttributes(info);
|
||||
}
|
||||
|
||||
Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
||||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) {
|
||||
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
// TODO: handle another subgraph with attribute name "encoder_decode_init"
|
||||
if (attribute_name == "decoder") {
|
||||
const auto& node = Node();
|
||||
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
|
||||
gpt_subgraph_->num_heads,
|
||||
gpt_subgraph_->head_size,
|
||||
gpt_subgraph_->num_layers);
|
||||
const auto& node = Node();
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (attribute_name == "decoder") {
|
||||
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
gpt_subgraph_ = std::make_unique<GptSubgraph>(node, attribute_name, subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(gpt_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
|
||||
parameters_.SetSubgraphParameters(gpt_subgraph_->vocab_size,
|
||||
gpt_subgraph_->num_heads,
|
||||
gpt_subgraph_->head_size,
|
||||
gpt_subgraph_->num_layers);
|
||||
}
|
||||
} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
|
||||
if (attribute_name == "encoder") {
|
||||
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
|
||||
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
t5_encoder_subgraph_ = std::make_unique<T5EncoderSubgraph>(node,
|
||||
attribute_name,
|
||||
subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(t5_encoder_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
encoder_feeds_fetches_manager_ = t5_encoder_subgraph_->GetFeedsFetchesManager();
|
||||
|
||||
if (parameters_.decoder_start_token_id < 0) {
|
||||
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 2,
|
||||
"Encoder subgraph shall have 2 inputs when decoder_start_token_id attribute is empty");
|
||||
} else {
|
||||
ORT_RETURN_IF(t5_encoder_subgraph_->num_subgraph_inputs != 3,
|
||||
"Encoder subgraph shall have 3 inputs when decoder_start_token_id attribute is available");
|
||||
}
|
||||
} else if (attribute_name == "decoder") {
|
||||
ORT_ENFORCE(t5_decoder_subgraph_ == nullptr,
|
||||
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
t5_decoder_subgraph_ = std::make_unique<T5DecoderSubgraph>(node,
|
||||
attribute_name,
|
||||
subgraph_session_state.GetGraphViewer());
|
||||
ORT_RETURN_IF_ERROR(t5_decoder_subgraph_->Setup(session_state, subgraph_session_state));
|
||||
decoder_feeds_fetches_manager_ = t5_decoder_subgraph_->GetFeedsFetchesManager();
|
||||
parameters_.SetSubgraphParameters(t5_decoder_subgraph_->vocab_size,
|
||||
t5_decoder_subgraph_->num_heads,
|
||||
t5_decoder_subgraph_->head_size,
|
||||
t5_decoder_subgraph_->num_layers);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
||||
if (parameters_.model_type != 0) {
|
||||
// TODO: support encoder decoder model like T5
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Support of 'model_type' != 0 is not implemented");
|
||||
}
|
||||
|
||||
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
|
||||
auto* session_state = ctx_internal->SubgraphSessionState("decoder");
|
||||
ORT_ENFORCE(session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
|
||||
ORT_ENFORCE(feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
|
||||
|
||||
auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder");
|
||||
ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
|
||||
ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
|
||||
|
||||
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
|
||||
|
||||
BeamSearchParameters parameters = parameters_; // make a copy since we will update the parameters based on inputs later
|
||||
// Make a copy of parameters since we will update it based on inputs later
|
||||
BeamSearchParameters parameters = parameters_;
|
||||
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
|
||||
BeamSearchGpt<float> impl{
|
||||
*ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
|
||||
BeamSearchCpuDeviceHelper::CreateGptInputs,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
|
||||
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState<float>,
|
||||
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
|
||||
device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
|
||||
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateGptFeeds<float>};
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
return impl.Execute(*decoder_feeds_fetches_manager_);
|
||||
} else { // Output float16
|
||||
BeamSearchGpt<MLFloat16> impl{
|
||||
*ctx_internal, *decoder_session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
|
||||
BeamSearchCpuDeviceHelper::CreateGptInputs,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
|
||||
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
|
||||
process_logits_fp16_func_,
|
||||
init_beam_state_fp16_func_,
|
||||
device_copy_func_,
|
||||
device_copy_int32_func_,
|
||||
update_gpt_feeds_fp16_func_};
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
return impl.Execute(*decoder_feeds_fetches_manager_);
|
||||
}
|
||||
}
|
||||
|
||||
auto* encoder_session_state = ctx_internal->SubgraphSessionState("encoder");
|
||||
ORT_ENFORCE(encoder_session_state, "Subgraph SessionState was not found for 'encoder' attribute.");
|
||||
ORT_ENFORCE(encoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
|
||||
|
||||
// Subgraph has constraint that the output is either float or float16
|
||||
if (!gpt_subgraph_->IsOutputFloat16()) {
|
||||
BeamSearchImpl<float> impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
|
||||
create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
|
||||
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState<float>,
|
||||
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
|
||||
update_feeds_func_ ? update_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateFeeds<float>};
|
||||
if (!t5_decoder_subgraph_->IsOutputFloat16()) {
|
||||
BeamSearchT5<float> impl{
|
||||
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
|
||||
*t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
|
||||
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : BeamSearchCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : BeamSearchCpuDeviceHelper::InitBeamState<float>,
|
||||
device_copy_func_ ? device_copy_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<float>,
|
||||
device_copy_int32_func_ ? device_copy_int32_func_ : BeamSearchCpuDeviceHelper::DeviceCopy<int32_t>,
|
||||
create_encoder_inputs_func_ ? create_encoder_inputs_func_ : BeamSearchCpuDeviceHelper::CreateEncoderInputs,
|
||||
update_decoder_feeds_func_ ? update_decoder_feeds_func_ : BeamSearchCpuDeviceHelper::UpdateDecoderFeeds<float>};
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
return impl.Execute(*feeds_fetches_manager_);
|
||||
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
|
||||
} else {
|
||||
BeamSearchImpl<MLFloat16> impl{*ctx_internal, *session_state, *gpt_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
|
||||
create_inputs_func_ ? create_inputs_func_ : BeamSearchCpuDeviceHelper::CreateInputs,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
|
||||
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
|
||||
process_logits_fp16_func_,
|
||||
init_beam_state_fp16_func_,
|
||||
device_copy_func_,
|
||||
update_feeds_fp16_func_};
|
||||
BeamSearchT5<MLFloat16> impl{
|
||||
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
|
||||
*t5_decoder_subgraph_, thread_pool, cuda_stream_, dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : BeamSearchCpuDeviceHelper::AddToFeeds,
|
||||
topk_func_ ? topk_func_ : BeamSearchCpuDeviceHelper::TopK,
|
||||
process_logits_fp16_func_,
|
||||
init_beam_state_fp16_func_,
|
||||
device_copy_func_,
|
||||
device_copy_int32_func_,
|
||||
create_encoder_inputs_func_,
|
||||
update_decoder_feeds_fp16_func_};
|
||||
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
return impl.Execute(*feeds_fetches_manager_);
|
||||
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::CheckInputs(const OpKernelContextInternal& context) {
|
||||
// Input shapes:
|
||||
// input_ids : (batch_size, sequence_length)
|
||||
// vocab_mask : (vocab_size) or nullptr
|
||||
|
||||
const Tensor* input_ids = context.Input<Tensor>(0);
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
if (dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'input_ids' is expected to have 2 dimensions, got ",
|
||||
dims.size());
|
||||
}
|
||||
|
||||
const Tensor* vocab_mask = context.Input<Tensor>(8);
|
||||
if (vocab_mask != nullptr) { // vocab_mask is optional
|
||||
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
|
||||
if (vocab_mask_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' is expected to have 1 dimension, got ",
|
||||
vocab_mask_dims.size());
|
||||
}
|
||||
|
||||
// There is dependency on vocab_size parameter, which shall be set before calling this function.
|
||||
if (static_cast<int>(vocab_mask_dims[0]) != parameters_->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'vocab_mask' shape does not match with vocab_size, got ",
|
||||
vocab_mask_dims[0]);
|
||||
}
|
||||
|
||||
// store vocab mask in parameters.
|
||||
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
const Tensor* prefix_vocab_mask = context.Input<Tensor>(9);
|
||||
if (prefix_vocab_mask != nullptr) {
|
||||
// prefix_vocab_mask is optional
|
||||
const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
|
||||
if (vocab_mask_dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' is expected to have 2 dimensions, got ",
|
||||
vocab_mask_dims.size());
|
||||
}
|
||||
|
||||
// prefix_vocab_mask first dimension should be same as the first dimension of input_ids
|
||||
if (static_cast<int>(vocab_mask_dims[0]) != static_cast<int>(dims[0])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "input_ids and prefix_vocab_mask must have the same batch_size");
|
||||
}
|
||||
|
||||
// There is dependency on vocab_size parameter, which shall be set before calling this function.
|
||||
if (static_cast<int>(vocab_mask_dims[1]) != parameters_->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 'prefix_vocab_mask' shape does not match with vocab_size, got ",
|
||||
vocab_mask_dims[0]);
|
||||
}
|
||||
|
||||
// store prefix vocab mask in parameters.
|
||||
parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::Initialize() {
|
||||
ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_));
|
||||
|
||||
#define CHECK_SCALAR_INPUT(name, index, required) \
|
||||
auto* name##_tensor = context_.Input<Tensor>(index); \
|
||||
if (name##_tensor) { \
|
||||
if (!name##_tensor->Shape().IsScalar()) { \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
|
||||
name##_tensor->Shape()); \
|
||||
} \
|
||||
} else if (required) { \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
|
||||
}
|
||||
|
||||
CHECK_SCALAR_INPUT(min_length, 1, false);
|
||||
|
||||
CHECK_SCALAR_INPUT(max_length, 2, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(num_beams, 3, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(temperature, 5, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(length_penalty, 6, true);
|
||||
|
||||
ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams, "'num_return_sequences' has to be smaller or equal to 'num_beams'.");
|
||||
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(context_));
|
||||
|
||||
// This flag will be updated later when the scores output exists.
|
||||
parameters_->output_scores = false;
|
||||
|
||||
if (!IsCuda()) {
|
||||
// Logits processor is used in CPU only. In CUDA, cuda kernels are used instead.
|
||||
// Initialize processsors after CheckInputs so that parameters_->vocab_mask is ready.
|
||||
logits_processors_.Init(*parameters_);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths, OrtValue& expanded_input_ids, std::vector<OrtValue>& feeds, IAllocatorUniquePtr<char>& buffer) {
|
||||
const OrtValue* input_ids_value = context_.GetInputOrtValue(0);
|
||||
const Tensor& input_ids = input_ids_value->Get<Tensor>();
|
||||
return gpt_subgraph_.CreateInitialFeeds(input_ids, implicit_inputs_, parameters_->num_beams, parameters_->pad_token_id, sequence_lengths, expanded_input_ids, feeds, create_inputs_func_, add_to_feeds_func_, buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::ProcessLogits(
|
||||
const OrtValue& logits,
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
AllocatorPtr& allocator,
|
||||
int counter) {
|
||||
return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator,
|
||||
thread_pool_, &logits_processors_, beam_scorer_.get(),
|
||||
parameters_, counter, cuda_stream_, GetConsoleDumper());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::GenerateNextToken(
|
||||
const OrtValue& logits,
|
||||
gsl::span<int32_t>& beam_next_tokens,
|
||||
gsl::span<int32_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
int counter) {
|
||||
// Process logits to get next token scores
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter));
|
||||
|
||||
gsl::span<float>& beam_scores = beam_scorer_->GetNextScores();
|
||||
// It is optional to clone beam_scores. Change it to use same buffer also works for CPU:
|
||||
// beam_state.beam_scores = beam_scores
|
||||
// Here we make a copy to reduce the coupling with little cost (the buffer size is small).
|
||||
ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores, beam_scores, cuda_stream_, DeviceCopyDirection::hostToDevice));
|
||||
|
||||
beam_next_tokens = beam_scorer_->GetNextTokens();
|
||||
beam_indices = beam_scorer_->GetNextIndices();
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
cpu_dumper_.Print("beam_scores after scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
cpu_dumper_.Print("beam_next_tokens after scorer", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
cpu_dumper_.Print("beam_indices after scorer", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
#endif
|
||||
|
||||
cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
cpu_state.sequences.PrintSequences(&cpu_dumper_);
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices) {
|
||||
return update_feeds_func_(temp_space_allocator_, cuda_stream_, last_outputs, next_inputs, current_length, position_ids,
|
||||
beam_next_tokens, beam_indices, parameters_->num_beams, GetConsoleDumper());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchImpl<T>::Execute(const FeedsFetchesManager& feeds_fetches_manager) {
|
||||
auto status = Status::OK();
|
||||
int64_t sequences_dims[] = {parameters_->batch_size, parameters_->num_return_sequences, parameters_->max_length};
|
||||
TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
|
||||
Tensor* output_sequences = context_.Output(0, sequences_shape);
|
||||
|
||||
int64_t sequences_scores_dims[] = {parameters_->batch_size, parameters_->num_return_sequences};
|
||||
TensorShape sequences_scores_shape(&sequences_scores_dims[0], sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0]));
|
||||
Tensor* output_sequences_scores = context_.Output(1, sequences_scores_shape);
|
||||
|
||||
int64_t scores_dims[] = {
|
||||
static_cast<int64_t>(parameters_->max_length) - static_cast<int64_t>(parameters_->sequence_length),
|
||||
parameters_->batch_size, parameters_->num_beams, parameters_->vocab_size};
|
||||
TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
|
||||
Tensor* output_scores = context_.Output(2, scores_shape);
|
||||
|
||||
// Update the flag to indicate whether scores exists in output
|
||||
parameters_->output_scores = (output_scores != nullptr);
|
||||
|
||||
std::vector<OrtValue> feeds;
|
||||
// TODO: allocate fetches. use ping-pong buffers for past state.
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// Initialize resources
|
||||
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(cpu_allocator_);
|
||||
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(cpu_allocator_);
|
||||
beam_scorer_ = std::make_unique<BeamSearchScorer>(static_cast<size_t>(parameters_->batch_size),
|
||||
static_cast<size_t>(parameters_->num_beams),
|
||||
static_cast<size_t>(parameters_->max_length),
|
||||
parameters_->length_penalty,
|
||||
parameters_->early_stopping,
|
||||
static_cast<size_t>(parameters_->num_return_sequences),
|
||||
parameters_->pad_token_id,
|
||||
parameters_->eos_token_id,
|
||||
hypothesis_score_allocator,
|
||||
beam_hyps_allocator);
|
||||
beam_scorer_->Initialize(cpu_allocator_, parameters_->sequence_length);
|
||||
|
||||
BeamSearchCpuState cpu_state;
|
||||
cpu_state.Init(cpu_allocator_, static_cast<size_t>(parameters_->BatchBeamSize()), parameters_->max_length, IsCuda());
|
||||
|
||||
// buffer in GPU for input_ids, position_ids and attention_mask
|
||||
// size_t buffer_bytes = SafeInt<size_t>(sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * parameters_->batch_size * parameters_->num_beams * parameters_->sequence_length;
|
||||
// IAllocatorUniquePtr<char> buffer = gpt_subgraph_.GetProvider()->GetScratchBuffer<char>(buffer_bytes);
|
||||
IAllocatorUniquePtr<char> buffer;
|
||||
OrtValue expanded_input_ids_in_cpu;
|
||||
ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
|
||||
|
||||
BeamSearchState<T> beam_state;
|
||||
beam_state.Init(temp_space_allocator_,
|
||||
parameters_->batch_size,
|
||||
parameters_->num_beams,
|
||||
parameters_->vocab_size,
|
||||
parameters_->sequence_length,
|
||||
parameters_->max_length,
|
||||
parameters_->output_scores);
|
||||
|
||||
cpu_state.sequences.Init(cpu_state.sequences_space,
|
||||
parameters_->BatchBeamSize(),
|
||||
parameters_->sequence_length,
|
||||
parameters_->max_length);
|
||||
|
||||
gsl::span<const int32_t> input_ids = expanded_input_ids_in_cpu.Get<Tensor>().DataAsSpan<int32_t>();
|
||||
init_beam_state_func_(&beam_state,
|
||||
&cpu_state,
|
||||
cpu_state.sequence_lengths,
|
||||
parameters_->batch_size,
|
||||
parameters_->num_beams,
|
||||
input_ids,
|
||||
parameters_->sequence_length,
|
||||
parameters_->max_length,
|
||||
cuda_stream_);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
const IConsoleDumper* dumper = GetConsoleDumper();
|
||||
dumper->Print("input_ids", feeds[0]);
|
||||
dumper->Print("position_ids", feeds[1]);
|
||||
dumper->Print("attention_mask", feeds[2]);
|
||||
#endif
|
||||
|
||||
// position ids for all iterations except the first. It uses memory buffer owned by next_positions.
|
||||
OrtValue position_ids;
|
||||
int64_t dims[] = {parameters_->BatchBeamSize(), 1};
|
||||
TensorShape shape(&dims[0], 2);
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, beam_state.next_positions.data(), temp_space_allocator_->Info(), position_ids);
|
||||
|
||||
int current_length = parameters_->sequence_length;
|
||||
int iteration_counter = 0;
|
||||
while (current_length < parameters_->max_length) {
|
||||
iteration_counter++;
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
auto cur_len = std::to_string(current_length);
|
||||
dumper->Print("***CurrentLength", cur_len, true);
|
||||
#endif
|
||||
|
||||
status = utils::ExecuteSubgraph(session_state_, feeds_fetches_manager, feeds, fetches, {},
|
||||
ExecutionMode::ORT_SEQUENTIAL, context_.GetTerminateFlag(), context_.Logger());
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
const OrtValue& logits = fetches[0];
|
||||
gsl::span<int32_t> beam_next_tokens;
|
||||
gsl::span<int32_t> beam_indices;
|
||||
ORT_RETURN_IF_ERROR(GenerateNextToken(logits, beam_next_tokens, beam_indices, beam_state, cpu_state, iteration_counter));
|
||||
|
||||
// When all batches are finished, stop earlier to avoid wasting computation.
|
||||
if (beam_scorer_->IsDone()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Increase sequence length after a new token is generated.
|
||||
++current_length;
|
||||
|
||||
// Prepare inputs for next round of subgraph call.
|
||||
if (current_length < parameters_->max_length) {
|
||||
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
|
||||
position_ids,
|
||||
beam_next_tokens.as_span<const int32_t>(),
|
||||
beam_indices.as_span<const int32_t>()));
|
||||
}
|
||||
fetches.clear();
|
||||
}
|
||||
|
||||
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
|
||||
if (IsCuda()) {
|
||||
ORT_RETURN_IF_ERROR(device_copy_func_(cpu_state.final_beam_scores, final_beam_scores, nullptr, DeviceCopyDirection::deviceToHost));
|
||||
final_beam_scores = gsl::make_span<const float>(cpu_state.final_beam_scores.data(), cpu_state.final_beam_scores.size());
|
||||
}
|
||||
|
||||
beam_scorer_->Finalize(&(cpu_state.sequences),
|
||||
final_beam_scores,
|
||||
output_sequences,
|
||||
output_sequences_scores);
|
||||
|
||||
// Output per token scores
|
||||
if (output_scores != nullptr) {
|
||||
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
|
||||
gsl::span<const float> source = gsl::span<const float>(beam_state.scores.data(), beam_state.scores.size());
|
||||
assert(target.length() == source.length());
|
||||
ORT_RETURN_IF_ERROR(device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,12 +2,16 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
#include "beam_search_parameters.h"
|
||||
#include "gpt_subgraph.h"
|
||||
#include "beam_search_device_helper.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class FeedsFetchesManager;
|
||||
|
|
@ -20,7 +24,11 @@ using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
|
|||
class BeamSearch : public IControlFlowKernel {
|
||||
public:
|
||||
BeamSearch(const OpKernelInfo& info)
|
||||
: IControlFlowKernel(info), feeds_fetches_manager_(nullptr), cuda_stream_(nullptr), dumper_(nullptr) {
|
||||
: IControlFlowKernel(info),
|
||||
encoder_feeds_fetches_manager_(nullptr),
|
||||
decoder_feeds_fetches_manager_(nullptr),
|
||||
cuda_stream_(nullptr),
|
||||
dumper_(nullptr) {
|
||||
Init(info);
|
||||
}
|
||||
|
||||
|
|
@ -36,54 +44,76 @@ class BeamSearch : public IControlFlowKernel {
|
|||
void SetComputeStream(void* stream) { cuda_stream_ = stream; }
|
||||
void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; }
|
||||
|
||||
// device helpers that is same for both GPT and encoder-decoder models.
|
||||
void SetDeviceHelpers(
|
||||
// const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const BeamSearchDeviceHelper::TopkFunc& topk_func) {
|
||||
// create_inputs_func_ = create_inputs_func;
|
||||
const BeamSearchDeviceHelper::TopkFunc& topk_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<float>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
|
||||
add_to_feeds_func_ = add_to_feeds_func;
|
||||
topk_func_ = topk_func;
|
||||
}
|
||||
|
||||
// Type dependent helpers: float
|
||||
void SetDeviceHelpers(
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<float>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const BeamSearchDeviceHelper::UpdateFeedsFunc<float>& update_feeds_func) {
|
||||
process_logits_func_ = process_logits_func;
|
||||
init_beam_state_func_ = init_beam_state_func;
|
||||
device_copy_func_ = device_copy_func;
|
||||
update_feeds_func_ = update_feeds_func;
|
||||
device_copy_int32_func_ = device_copy_int32_func;
|
||||
process_logits_func_ = process_logits_func;
|
||||
process_logits_fp16_func_ = process_logits_fp16_func;
|
||||
init_beam_state_func_ = init_beam_state_func;
|
||||
init_beam_state_fp16_func_ = init_beam_state_fp16_func;
|
||||
}
|
||||
|
||||
// Type dependent helpers: MLFloat16
|
||||
void SetDeviceHelpers(
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_func,
|
||||
const BeamSearchDeviceHelper::UpdateFeedsFunc<MLFloat16>& update_feeds_func) {
|
||||
process_logits_fp16_func_ = process_logits_func;
|
||||
init_beam_state_fp16_func_ = init_beam_state_func;
|
||||
update_feeds_fp16_func_ = update_feeds_func;
|
||||
void SetDeviceHelpers_Gpt(
|
||||
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<float>& update_gpt_feeds_func,
|
||||
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<MLFloat16>& update_gpt_feeds_fp16_func) {
|
||||
update_gpt_feeds_func_ = update_gpt_feeds_func;
|
||||
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
|
||||
}
|
||||
|
||||
// device helpers for encoder-decoder model like T5
|
||||
void SetDeviceHelpers_EncoderDecoder(
|
||||
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float>& update_decoder_feeds_func,
|
||||
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16>& update_decoder_feeds_fp16_func) {
|
||||
update_decoder_feeds_func_ = update_decoder_feeds_func;
|
||||
update_decoder_feeds_fp16_func_ = update_decoder_feeds_fp16_func;
|
||||
}
|
||||
|
||||
private:
|
||||
// Device specific functions
|
||||
BeamSearchDeviceHelper::CreateInputsFunc create_inputs_func_;
|
||||
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
BeamSearchDeviceHelper::TopkFunc topk_func_;
|
||||
BeamSearchDeviceHelper::ProcessLogitsFunc<float> process_logits_func_;
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<float> init_beam_state_func_;
|
||||
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
|
||||
BeamSearchDeviceHelper::UpdateFeedsFunc<float> update_feeds_func_;
|
||||
BeamSearchDeviceHelper::DeviceCopyFunc<int32_t> device_copy_int32_func_;
|
||||
|
||||
BeamSearchDeviceHelper::ProcessLogitsFunc<float> process_logits_func_;
|
||||
BeamSearchDeviceHelper::ProcessLogitsFunc<MLFloat16> process_logits_fp16_func_;
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
|
||||
BeamSearchDeviceHelper::UpdateFeedsFunc<MLFloat16> update_feeds_fp16_func_;
|
||||
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<float> init_beam_state_func_;
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<MLFloat16> init_beam_state_fp16_func_;
|
||||
|
||||
//------------------------------------------------------------
|
||||
// Device specific functions for GPT
|
||||
//------------------------------------------------------------
|
||||
BeamSearchDeviceHelper::UpdateGptFeedsFunc<float> update_gpt_feeds_func_;
|
||||
BeamSearchDeviceHelper::UpdateGptFeedsFunc<MLFloat16> update_gpt_feeds_fp16_func_;
|
||||
|
||||
//------------------------------------------------------------
|
||||
// Device specific functions for encoder-decoder model like T5
|
||||
//------------------------------------------------------------
|
||||
BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
|
||||
|
||||
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<float> update_decoder_feeds_func_;
|
||||
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<MLFloat16> update_decoder_feeds_fp16_func_;
|
||||
|
||||
//------------------------------------------------------------
|
||||
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
|
||||
//------------------------------------------------------------
|
||||
std::unique_ptr<GptSubgraph> gpt_subgraph_;
|
||||
FeedsFetchesManager* feeds_fetches_manager_;
|
||||
std::unique_ptr<T5EncoderSubgraph> t5_encoder_subgraph_;
|
||||
std::unique_ptr<T5DecoderSubgraph> t5_decoder_subgraph_;
|
||||
FeedsFetchesManager* encoder_feeds_fetches_manager_;
|
||||
FeedsFetchesManager* decoder_feeds_fetches_manager_;
|
||||
|
||||
void* cuda_stream_;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <vector>
|
||||
#include <algorithm>
|
||||
#include "core/providers/cpu/math/top_k.h"
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "gsl/gsl"
|
||||
#include "sequences.h"
|
||||
#include "beam_search_scorer.h"
|
||||
#include "beam_search_device_helper.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -25,11 +32,10 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
input->DataType(), " is not supported yet");
|
||||
}
|
||||
|
||||
OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator) {
|
||||
// Input shape (batch_size, sequence_length)
|
||||
template <typename T>
|
||||
void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded) {
|
||||
// Input shape (batch_size, sequence_length). The input is required with data type T.
|
||||
// Output shape (batch_size * num_beams, sequence_length)
|
||||
if (num_beams == 1)
|
||||
return input;
|
||||
|
||||
const TensorShape& input_shape = input.Get<Tensor>().Shape();
|
||||
const int64_t& batch_size = input_shape[0];
|
||||
|
|
@ -38,31 +44,28 @@ OrtValue ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocat
|
|||
int64_t dims[] = {batch_size * num_beams, sequence_length};
|
||||
TensorShape expanded_shape(&dims[0], 2);
|
||||
|
||||
OrtValue expanded;
|
||||
MLDataType element_type = input.Get<Tensor>().DataType();
|
||||
ORT_ENFORCE(element_type == DataTypeImpl::GetType<int32_t>(), "input_ids, position_ids and attention_mask is required to be int32 data type");
|
||||
ORT_ENFORCE(element_type == DataTypeImpl::GetType<T>());
|
||||
|
||||
Tensor::InitOrtValue(element_type, expanded_shape, allocator, expanded);
|
||||
|
||||
const int32_t* input_data = input.Get<Tensor>().Data<int32_t>();
|
||||
int32_t* expanded_data = expanded.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
int32_t* target = expanded_data;
|
||||
const T* input_data = input.Get<Tensor>().Data<T>();
|
||||
T* expanded_data = expanded.GetMutable<Tensor>()->MutableData<T>();
|
||||
T* target = expanded_data;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
for (int j = 0; j < num_beams; j++) {
|
||||
memcpy(target, input_data + i * sequence_length, sizeof(int32_t) * sequence_length);
|
||||
memcpy(target, input_data + i * sequence_length, sizeof(T) * sequence_length);
|
||||
target += sequence_length;
|
||||
}
|
||||
}
|
||||
|
||||
return expanded;
|
||||
}
|
||||
|
||||
Status CreateInputs(
|
||||
Status CreateGptInputs(
|
||||
const Tensor* original_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
AllocatorPtr alloactor,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded_input_ids,
|
||||
OrtValue& expanded_position_ids,
|
||||
OrtValue& expanded_attention_mask) {
|
||||
|
|
@ -74,21 +77,22 @@ Status CreateInputs(
|
|||
// Allocate position_ids and attention_mask based on shape of input_ids
|
||||
auto element_type = DataTypeImpl::GetType<int32_t>();
|
||||
|
||||
const OrtMemoryInfo& location = alloactor->Info();
|
||||
const OrtMemoryInfo& location = allocator->Info();
|
||||
|
||||
// Use original input_ids. This requires the input_ids for subgraph is also int32.
|
||||
// Current shape is (batch_size, sequence_length)
|
||||
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
|
||||
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
|
||||
OrtValue input_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, const_cast<Tensor*>(original_input_ids)->MutableData<int32_t>(), location, input_ids);
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape,
|
||||
const_cast<Tensor*>(original_input_ids)->MutableData<int32_t>(), location, input_ids);
|
||||
|
||||
OrtValue position_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, alloactor, position_ids);
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, position_ids);
|
||||
|
||||
OrtValue attention_mask;
|
||||
auto mask_type = DataTypeImpl::GetType<int32_t>();
|
||||
Tensor::InitOrtValue(mask_type, input_ids_shape, alloactor, attention_mask);
|
||||
Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, attention_mask);
|
||||
|
||||
// Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
|
||||
// Set position id to be 0 for pad tokens, and accumulated sum of mask in a batch for other tokens
|
||||
|
|
@ -115,43 +119,45 @@ Status CreateInputs(
|
|||
}
|
||||
}
|
||||
|
||||
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) for input_ids, position_ids and attention_mask
|
||||
// TODO: Try expand outputs after first subgraph call instead. That may get better performance, but more complex to implement.
|
||||
expanded_input_ids = ExpandInputs(input_ids, num_beams, alloactor);
|
||||
expanded_position_ids = ExpandInputs(position_ids, num_beams, alloactor);
|
||||
expanded_attention_mask = ExpandInputs(attention_mask, num_beams, alloactor);
|
||||
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
|
||||
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
|
||||
ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
|
||||
ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
|
||||
ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status AddToFeeds(const IExecutionProvider* /*execution_provider*/,
|
||||
OrtValue& input_ids,
|
||||
OrtValue& position_ids,
|
||||
OrtValue& attention_mask,
|
||||
std::initializer_list<OrtValue> inputs,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& /*buffer*/) {
|
||||
feeds.push_back(input_ids);
|
||||
feeds.push_back(position_ids);
|
||||
feeds.push_back(attention_mask);
|
||||
for (auto& input : inputs) {
|
||||
if (input.IsAllocated()) {
|
||||
feeds.push_back(input);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* /*stream*/) {
|
||||
memset(beam_state->beam_scores.data(), 0, beam_state->beam_scores.size_bytes());
|
||||
memset(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes());
|
||||
memset(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes());
|
||||
memset(beam_state->next_tokens.data(), 0, beam_state->next_tokens.size_bytes());
|
||||
memset(beam_state->next_indices.data(), 0, beam_state->next_indices.size_bytes());
|
||||
memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
|
||||
|
||||
// T5 does not need position, so next_positions is empty for T5.
|
||||
if (!beam_state->next_positions.empty()) {
|
||||
memset(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes());
|
||||
gsl::copy(sequence_lengths, beam_state->next_positions);
|
||||
}
|
||||
|
||||
// Initialize score of first beam of each group with 0 and the rest with -1e9.
|
||||
// This ensures that the beams in the same group don't produce same tokens every time.
|
||||
|
|
@ -161,19 +167,6 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
|||
beam_scores[SafeInt<gsl::index>(i) * num_beams + j] = -1e9;
|
||||
}
|
||||
}
|
||||
|
||||
gsl::copy(sequence_lengths, beam_state->next_positions);
|
||||
|
||||
memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes());
|
||||
|
||||
// Copy input_ids to sequences[0].
|
||||
gsl::span<int32_t> sequences_0 = cpu_state->sequences_space;
|
||||
int batch_beam_size = batch_size * num_beams;
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
for (int j = 0; j < sequence_length; j++) {
|
||||
sequences_0[SafeInt<gsl::index>(i) * max_length + j] = static_cast<int32_t>(input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length + j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -216,7 +209,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
const T* current_logits = logits_data + (input_length - 1) * vocab_size;
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
gsl::span<const T> source(current_logits, vocab_size);
|
||||
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size, static_cast<gsl::index>(vocab_size));
|
||||
gsl::span<T> target = next_token_logits.subspan(SafeInt<gsl::index>(i) * vocab_size,
|
||||
static_cast<gsl::index>(vocab_size));
|
||||
gsl::copy(source, target);
|
||||
current_logits += input_length * vocab_size;
|
||||
}
|
||||
|
|
@ -224,7 +218,9 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("logits", logits);
|
||||
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
|
||||
if (input_length > 1) {
|
||||
dumper->Print("next_token_logits", next_token_logits.data(), batch_size, num_beams, vocab_size);
|
||||
}
|
||||
#endif
|
||||
|
||||
// Get scores for candidates of next token: next_token_scores = log_softmax(next_token_logits, dim=-1)
|
||||
|
|
@ -244,12 +240,12 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
logits_processors->Process(sequences, next_token_scores, step);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
// Add beam score to next token scores. Corresponding python code is like:
|
||||
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
||||
// TODO: use thread pool to parrellel
|
||||
// TODO(tianleiwu): use thread pool to parallel
|
||||
int offset = 0;
|
||||
int batch_beam_index = 0;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
|
|
@ -261,7 +257,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
}
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
if (output_scores) {
|
||||
|
|
@ -277,7 +273,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
OrtValue next_token_scores_value;
|
||||
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
|
||||
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(),
|
||||
next_token_scores_value);
|
||||
const Tensor& input = next_token_scores_value.Get<Tensor>();
|
||||
|
||||
constexpr int axis = 1;
|
||||
|
|
@ -287,7 +284,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices));
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
|
||||
topk_scores, topk_indices));
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
|
|
@ -331,32 +329,34 @@ Status DeviceCopy(gsl::span<T> target, gsl::span<const T> source, void* /*stream
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copy present state to past state for GPT model
|
||||
template <typename T>
|
||||
void PickPastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
gsl::span<const int32_t>& beam_indices,
|
||||
AllocatorPtr allocator,
|
||||
void* /*stream*/) {
|
||||
void PickGptPastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
gsl::span<const int32_t>& beam_indices,
|
||||
AllocatorPtr allocator) {
|
||||
int num_present_tensors = static_cast<int>(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex;
|
||||
for (int i = 0; i < num_present_tensors; ++i) {
|
||||
const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i];
|
||||
|
||||
for (size_t i = 1; i < last_outputs.size(); ++i) {
|
||||
const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
|
||||
// shape is like (2, batch_beam_size, 12, past_seq_len, 64)
|
||||
const TensorShape& past_shape = present.Get<Tensor>().Shape();
|
||||
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
|
||||
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
|
||||
|
||||
// Create a tensor with same shape.
|
||||
// TODO: allocate one buffer for all layers
|
||||
// TODO(tianleiwu): allocate one buffer for all layers
|
||||
OrtValue past;
|
||||
auto past_type = DataTypeImpl::GetType<T>();
|
||||
Tensor::InitOrtValue(past_type, past_shape, allocator, past);
|
||||
|
||||
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
|
||||
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
|
||||
|
||||
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
|
||||
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
|
||||
for (gsl::index j = 0; j < beam_indices.length(); j++) {
|
||||
int32_t beam_index = beam_indices[j];
|
||||
gsl::span<const T> present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam,
|
||||
block_size_per_beam);
|
||||
|
||||
gsl::span<T> past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<T> past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
|
||||
|
|
@ -364,12 +364,12 @@ void PickPastState(const std::vector<OrtValue>& last_outputs,
|
|||
gsl::copy(present_value, past_value);
|
||||
}
|
||||
|
||||
next_inputs[i + 2] = past;
|
||||
next_inputs[transformers::GptSubgraph::kFirstPastInputIndex + i] = past;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status UpdateFeeds(
|
||||
Status UpdateGptFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -382,6 +382,7 @@ Status UpdateFeeds(
|
|||
const transformers::IConsoleDumper* dumper) {
|
||||
// last_outputs: logits, present_0, present_1, ...
|
||||
// next_inputs: input_ids, position_id, attention_mask, past_0, past_1
|
||||
ORT_UNUSED_PARAMETER(stream);
|
||||
|
||||
// The following updates inputs for subgraph
|
||||
|
||||
|
|
@ -391,7 +392,7 @@ Status UpdateFeeds(
|
|||
TensorShape input_ids_shape(&dims[0], 2);
|
||||
auto int32_type = DataTypeImpl::GetType<int32_t>();
|
||||
OrtValue input_ids;
|
||||
// TODO: Reuse buffer for input_ids to reduce memory allocation.
|
||||
// TODO(tianleiwu): Reuse buffer for input_ids to reduce memory allocation.
|
||||
Tensor::InitOrtValue(int32_type, input_ids_shape, allocator, input_ids);
|
||||
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
|
|
@ -433,25 +434,187 @@ Status UpdateFeeds(
|
|||
// Update past state
|
||||
if (num_beams == 1) {
|
||||
// feed present_* output to past_* inputs one by one
|
||||
for (size_t i = 1; i < last_outputs.size(); ++i) {
|
||||
next_inputs[i + 2] = last_outputs[i];
|
||||
const int k = transformers::GptSubgraph::kFirstPastInputIndex - transformers::GptSubgraph::kFirstPresentOutputIndex;
|
||||
for (size_t i = transformers::GptSubgraph::kFirstPresentOutputIndex; i < last_outputs.size(); ++i) {
|
||||
next_inputs[i + k] = last_outputs[i];
|
||||
}
|
||||
} else {
|
||||
PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream);
|
||||
PickGptPastState<T>(last_outputs, next_inputs, beam_indices, allocator);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// The following functions are for encoder-decoder model like T5
|
||||
// ---------------------------------------------------------------
|
||||
Status CreateEncoderInputs(
|
||||
const Tensor* original_encoder_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
int start_token_id,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded_encoder_input_ids,
|
||||
OrtValue& expanded_encoder_attention_mask,
|
||||
OrtValue& expanded_decoder_input_ids) {
|
||||
const TensorShape& input_ids_shape = original_encoder_input_ids->Shape();
|
||||
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
|
||||
const int64_t& batch_size = input_ids_shape[0];
|
||||
const int64_t& sequence_length = input_ids_shape[1];
|
||||
|
||||
// Allocate attention_mask based on shape of input_ids
|
||||
auto element_type = DataTypeImpl::GetType<int32_t>();
|
||||
|
||||
// Use original encoder_input_ids. This requires the input_ids for subgraph is also int32.
|
||||
// Current shape is (batch_size, sequence_length)
|
||||
// Note that we will expand it to (batch_size * num_beams, sequence_length) later.
|
||||
// To avoid cloning input_ids, we use const_cast here since this function does not change its content.
|
||||
OrtValue encoder_input_ids;
|
||||
Tensor::InitOrtValue(element_type,
|
||||
input_ids_shape,
|
||||
const_cast<Tensor*>(original_encoder_input_ids)->MutableData<int32_t>(),
|
||||
allocator->Info(),
|
||||
encoder_input_ids);
|
||||
|
||||
OrtValue encoder_attention_mask;
|
||||
auto mask_type = DataTypeImpl::GetType<int32_t>();
|
||||
Tensor::InitOrtValue(mask_type, input_ids_shape, allocator, encoder_attention_mask);
|
||||
|
||||
// Set attention mask to be 0 for pad tokens, and 1 for all other tokens.
|
||||
int32_t* mask_data = encoder_attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
const int32_t* word_id = original_encoder_input_ids->Data<int32_t>();
|
||||
int32_t* mask = mask_data;
|
||||
for (int i = 0; i < batch_size; i++) {
|
||||
int32_t abs_position = 0;
|
||||
for (int j = 0; j < sequence_length; j++, word_id++, mask++) {
|
||||
// T5Tokenizer might add one EOS pad token at the end.
|
||||
// That EOS token shall have attention mask 1 even when EOS token is same as pad token.
|
||||
// Here we only set attention mask to be 0 for left padding only, so as to be parity with huggingface.
|
||||
if (*word_id == pad_token_id && abs_position == 0) {
|
||||
*mask = 0;
|
||||
} else {
|
||||
*mask = 1;
|
||||
abs_position++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
|
||||
// for encoder_input_ids and encoder_attention_mask
|
||||
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
|
||||
ExpandInputs<int32_t>(encoder_input_ids, num_beams, allocator, expanded_encoder_input_ids);
|
||||
ExpandInputs<int32_t>(encoder_attention_mask, num_beams, allocator, expanded_encoder_attention_mask);
|
||||
|
||||
// decoder_input_ids is optional.
|
||||
if (start_token_id >= 0) {
|
||||
// Expanded decoder_input_ids has shape (batch_size * num_beams, 1), and filled with start token ID
|
||||
int64_t dims[] = {batch_size * num_beams, 1};
|
||||
TensorShape decoder_input_ids_shape(&dims[0], 2);
|
||||
Tensor::InitOrtValue(element_type, decoder_input_ids_shape, allocator, expanded_decoder_input_ids);
|
||||
int32_t* data = expanded_decoder_input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
for (int i = 0; i < batch_size * num_beams; i++, data++) {
|
||||
*data = start_token_id;
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copy present state to past state for T5 model
|
||||
template <typename T>
|
||||
void PickT5PastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t>& beam_indices,
|
||||
AllocatorPtr allocator) {
|
||||
for (int i = 0; i < num_present_tensors; ++i) {
|
||||
const OrtValue& present = last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
|
||||
|
||||
// shape is like (batch_beam_size, 12, past_seq_len, 64)
|
||||
const TensorShape& past_shape = present.Get<Tensor>().Shape();
|
||||
auto block_size_per_beam = past_shape[1] * past_shape[2] * past_shape[3];
|
||||
|
||||
// Create a tensor with same shape.
|
||||
// TODO(tianleiwu): allocate one buffer for all layers
|
||||
OrtValue past;
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<T>(), past_shape, allocator, past);
|
||||
|
||||
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
|
||||
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
|
||||
for (gsl::index j = 0; j < beam_indices.length(); j++) {
|
||||
int32_t beam_index = beam_indices[j];
|
||||
gsl::span<const T> present_beam = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<T> past_beam = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
|
||||
gsl::copy(present_beam, past_beam);
|
||||
}
|
||||
|
||||
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = past;
|
||||
}
|
||||
}
|
||||
|
||||
// Update decoder inputs given decoder outputs of last iteration.
|
||||
template <typename T>
|
||||
Status UpdateDecoderFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper) {
|
||||
ORT_UNUSED_PARAMETER(stream);
|
||||
|
||||
// last_outputs: logits, present_key_self_0, present_value_self_0, ...
|
||||
// next_inputs: input_ids,
|
||||
// encoder_attention_mask, encoder_hidden_states,
|
||||
// past_key_self_0, past_value_self_0, ...
|
||||
// past_key_cross_0, past_value_cross_0, ...
|
||||
// Only need copy beam next tokens to input_ids, and copy present_*_self_* to past_*_self_*,
|
||||
|
||||
// Update input_ids with next tokens.
|
||||
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
|
||||
int64_t dims[] = {batch_beam_size, 1};
|
||||
TensorShape input_ids_shape(&dims[0], 2);
|
||||
|
||||
// TODO(tianleiwu): Reuse buffer for input_ids to reduce memory allocation.
|
||||
OrtValue input_ids;
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
|
||||
|
||||
gsl::copy(beam_next_tokens, input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>());
|
||||
|
||||
next_inputs[0] = input_ids;
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("input_ids", input_ids);
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
#endif
|
||||
|
||||
// Update past state
|
||||
ORT_ENFORCE(last_outputs.size() >= static_cast<size_t>(1 + num_present_tensors));
|
||||
// TODO(tianleiwu): remove num_beams==1 once GreedySearch operator is available.
|
||||
if (num_beams == 1) {
|
||||
// feed present_* output to past_* inputs one by one
|
||||
for (int i = 0; i < num_present_tensors; ++i) {
|
||||
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] =
|
||||
last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
|
||||
}
|
||||
} else {
|
||||
PickT5PastState<T>(last_outputs, next_inputs, num_present_tensors, beam_indices, allocator);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
//------------------------------------------------
|
||||
// Explicit template instantiations of functions
|
||||
//------------------------------------------------
|
||||
|
||||
template void InitBeamState<float>(
|
||||
transformers::IBeamSearchState<float>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream);
|
||||
|
||||
template Status ProcessLogits<float>(
|
||||
|
|
@ -472,9 +635,15 @@ template Status DeviceCopy<float>(
|
|||
gsl::span<float> target,
|
||||
gsl::span<const float> source,
|
||||
void* stream,
|
||||
int copyDirectionn);
|
||||
int copyDirection);
|
||||
|
||||
template Status UpdateFeeds<float>(
|
||||
template Status DeviceCopy<int32_t>(
|
||||
gsl::span<int32_t> target,
|
||||
gsl::span<const int32_t> source,
|
||||
void* stream,
|
||||
int copyDirection);
|
||||
|
||||
template Status UpdateGptFeeds<float>(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -486,6 +655,19 @@ template Status UpdateFeeds<float>(
|
|||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
template Status UpdateDecoderFeeds<float>(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
template void ExpandInputs<int32_t>(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
|
||||
|
||||
} // namespace BeamSearchCpuDeviceHelper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#ifndef SHARED_PROVIDER
|
||||
|
|
@ -6,9 +9,10 @@
|
|||
#include "core/framework/allocator.h"
|
||||
#endif
|
||||
|
||||
#include <vector>
|
||||
#include "gsl/gsl"
|
||||
#include "logits_processor.h"
|
||||
#include "beam_search_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/logits_processor.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class IExecutionProvider;
|
||||
|
|
@ -31,40 +35,34 @@ namespace BeamSearchDeviceHelper {
|
|||
using TopkFunc = std::function<Status(
|
||||
const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
void* stream, // cudaStream_t stream,
|
||||
void* stream, // cudaStream_t
|
||||
onnxruntime::concurrency::ThreadPool* threadpool,
|
||||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices)>;
|
||||
|
||||
// Create subgraph inputs: input_ids, position_ids and attention_mask
|
||||
using CreateInputsFunc = std::function<Status(
|
||||
// Create subgraph inputs: input_ids, position_ids and attention_mask (for GPT-2).
|
||||
using CreateGptInputsFunc = std::function<Status(
|
||||
const Tensor* original_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
AllocatorPtr alloactor,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded_input_ids,
|
||||
OrtValue& expanded_position_ids,
|
||||
OrtValue& expanded_attention_mask)>;
|
||||
|
||||
using AddToFeedsFunc = std::function<Status(
|
||||
const IExecutionProvider* provider,
|
||||
OrtValue& input_ids,
|
||||
OrtValue& position_ids,
|
||||
OrtValue& attention_mask,
|
||||
std::initializer_list<OrtValue> inputs,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer)>;
|
||||
|
||||
template <typename T>
|
||||
using InitBeamStateFunc = std::function<void(
|
||||
transformers::IBeamSearchState<T>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream)>;
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -89,8 +87,9 @@ using DeviceCopyFunc = std::function<Status(
|
|||
void* stream,
|
||||
int copyDirection)>;
|
||||
|
||||
// Update subgraph inputs given outputs of last iteration (for GPT-2).
|
||||
template <typename T>
|
||||
using UpdateFeedsFunc = std::function<Status(
|
||||
using UpdateGptFeedsFunc = std::function<Status(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -102,6 +101,29 @@ using UpdateFeedsFunc = std::function<Status(
|
|||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper)>;
|
||||
|
||||
// Create encoder inputs (for encoder-decoder model like T5).
|
||||
using CreateEncoderInputsFunc = std::function<Status(
|
||||
const Tensor* original_encoder_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
int start_token_id,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded_encoder_input_ids,
|
||||
OrtValue& expanded_encoder_attention_mask,
|
||||
OrtValue& expanded_decoder_input_ids)>;
|
||||
|
||||
// Update decoder inputs given decoder outputs of last iteration (for encoder-decoder model like T5).
|
||||
template <typename T>
|
||||
using UpdateDecoderFeedsFunc = std::function<Status(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper)>;
|
||||
} // namespace BeamSearchDeviceHelper
|
||||
|
||||
// These are CPU specific device helper implementations
|
||||
|
|
@ -114,33 +136,17 @@ Status TopK(
|
|||
std::unique_ptr<Tensor>& output_values,
|
||||
std::unique_ptr<Tensor>& output_indices);
|
||||
|
||||
Status CreateInputs(
|
||||
const Tensor* original_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
AllocatorPtr alloactor,
|
||||
OrtValue& expanded_input_ids,
|
||||
OrtValue& expanded_position_ids,
|
||||
OrtValue& expanded_attention_mask);
|
||||
|
||||
Status AddToFeeds(
|
||||
const IExecutionProvider* execution_provider,
|
||||
OrtValue& input_ids,
|
||||
OrtValue& position_ids,
|
||||
OrtValue& attention_mask,
|
||||
std::initializer_list<OrtValue> inputs,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer);
|
||||
|
||||
template <typename T>
|
||||
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream);
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -163,8 +169,22 @@ Status DeviceCopy(gsl::span<T> target,
|
|||
void* stream,
|
||||
int copyDirectionn);
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Functions for GPT model only
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
Status CreateGptInputs(
|
||||
const Tensor* original_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded_input_ids,
|
||||
OrtValue& expanded_position_ids,
|
||||
OrtValue& expanded_attention_mask);
|
||||
|
||||
template <typename T>
|
||||
Status UpdateFeeds(
|
||||
Status UpdateGptFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -176,6 +196,38 @@ Status UpdateFeeds(
|
|||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Functions for encoder-decoder model like T5
|
||||
// ---------------------------------------------------------------
|
||||
Status CreateEncoderInputs(
|
||||
const Tensor* original_encoder_input_ids,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
int start_token_id,
|
||||
AllocatorPtr allocator,
|
||||
OrtValue& expanded_encoder_input_ids,
|
||||
OrtValue& expanded_encoder_attention_mask,
|
||||
OrtValue& expanded_decoder_input_ids);
|
||||
|
||||
// Update decoder inputs given decoder outputs of last iteration.
|
||||
template <typename T>
|
||||
Status UpdateDecoderFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Utility Functions
|
||||
// ---------------------------------------------------------------
|
||||
template <typename T>
|
||||
void ExpandInputs(const OrtValue& input, int num_beams, AllocatorPtr allocator, OrtValue& expanded);
|
||||
|
||||
} // namespace BeamSearchCpuDeviceHelper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
364
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
Normal file
364
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
Normal file
|
|
@ -0,0 +1,364 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <utility>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
gsl::span<T> AllocateBuffer(AllocatorPtr allocator,
|
||||
BufferUniquePtr& buffer,
|
||||
size_t elements,
|
||||
bool fill = false,
|
||||
T fill_value = T{}) {
|
||||
size_t bytes = SafeInt<size_t>(sizeof(T)) * elements;
|
||||
void* data = allocator->Alloc(bytes);
|
||||
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
|
||||
buffer = std::move(temp_buffer);
|
||||
T* first = reinterpret_cast<T*>(buffer.get());
|
||||
auto span = gsl::make_span(first, elements);
|
||||
|
||||
if (fill) {
|
||||
std::fill_n(first, elements, fill_value);
|
||||
}
|
||||
|
||||
return span;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct BeamSearchState : public IBeamSearchState<T> {
|
||||
void Init(AllocatorPtr allocator,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
bool output_scores,
|
||||
bool use_position) {
|
||||
size_t batch_beam_size = SafeInt<size_t>(batch_size) * num_beams;
|
||||
|
||||
size_t next_token_size = SafeInt<size_t>(batch_beam_size) * vocab_size;
|
||||
this->next_token_logits = AllocateBuffer<T>(allocator, next_token_logits_buffer_, next_token_size);
|
||||
this->next_token_scores = AllocateBuffer<float>(allocator, next_token_scores_buffer_, next_token_size);
|
||||
|
||||
this->next_tokens = AllocateBuffer<int32_t>(allocator, next_tokens_buffer_, SafeInt<size_t>(2) * batch_beam_size);
|
||||
|
||||
this->next_indices = AllocateBuffer<int32_t>(allocator, next_indices_buffer_, SafeInt<size_t>(2) * batch_beam_size);
|
||||
|
||||
if (use_position) {
|
||||
this->next_positions = AllocateBuffer<int32_t>(allocator, next_positions_buffer_, batch_beam_size);
|
||||
}
|
||||
|
||||
this->beam_scores = AllocateBuffer<float>(allocator, beam_scores_buffer_, batch_beam_size);
|
||||
|
||||
if (output_scores) {
|
||||
size_t elements = SafeInt<size_t>(max_length - sequence_length) * batch_size * num_beams * vocab_size;
|
||||
this->scores = AllocateBuffer<float>(allocator, scores_buffer_, elements);
|
||||
this->remaining_scores = this->scores;
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr next_token_logits_buffer_;
|
||||
BufferUniquePtr next_token_scores_buffer_;
|
||||
BufferUniquePtr next_tokens_buffer_;
|
||||
BufferUniquePtr next_indices_buffer_;
|
||||
BufferUniquePtr next_positions_buffer_;
|
||||
BufferUniquePtr beam_scores_buffer_;
|
||||
BufferUniquePtr scores_buffer_;
|
||||
};
|
||||
|
||||
struct BeamSearchCpuState : public IBeamSearchCpuState {
|
||||
Sequences sequences;
|
||||
|
||||
void Init(AllocatorPtr allocator, size_t batch_beam_size, int max_length, int sequence_length, bool is_cuda) {
|
||||
this->sequence_lengths = AllocateBuffer<int32_t>(allocator, sequence_lengths_buffer_, batch_beam_size);
|
||||
|
||||
size_t sequences_bytes = SafeInt<size_t>(2) * batch_beam_size * max_length;
|
||||
this->sequences_space = AllocateBuffer<int32_t>(allocator, sequences_space_buffer_, sequences_bytes);
|
||||
memset(this->sequences_space.data(), 0, this->sequences_space.size_bytes());
|
||||
|
||||
if (is_cuda) {
|
||||
// buffers used by CUDA operator but not by CPU operator.
|
||||
this->topk_scores = AllocateBuffer<float>(allocator, topk_scores_buffer_, 2 * batch_beam_size);
|
||||
this->topk_tokens = AllocateBuffer<int32_t>(allocator, topk_tokens_buffer_, 2 * batch_beam_size);
|
||||
this->topk_indices = AllocateBuffer<int32_t>(allocator, topk_indices_buffer_, 2 * batch_beam_size);
|
||||
this->final_beam_scores = AllocateBuffer<float>(allocator, final_beam_scores_buffer_, batch_beam_size);
|
||||
}
|
||||
|
||||
this->sequences.Init(this->sequences_space, static_cast<int>(batch_beam_size), sequence_length, max_length);
|
||||
}
|
||||
|
||||
// Copy input_ids to sequences[0]
|
||||
void SetSequence(gsl::span<const int32_t> input_ids_in_cpu,
|
||||
size_t batch_beam_size,
|
||||
int max_length,
|
||||
int sequence_length) {
|
||||
gsl::span<int32_t> sequences_0 = sequences_space;
|
||||
for (size_t i = 0; i < batch_beam_size; i++) {
|
||||
for (int j = 0; j < sequence_length; j++) {
|
||||
const size_t index = SafeInt<gsl::index>(i) * max_length + j;
|
||||
sequences_0[index] = input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr final_beam_scores_buffer_;
|
||||
BufferUniquePtr sequence_lengths_buffer_;
|
||||
BufferUniquePtr topk_scores_buffer_;
|
||||
BufferUniquePtr topk_tokens_buffer_;
|
||||
BufferUniquePtr topk_indices_buffer_;
|
||||
BufferUniquePtr sequences_space_buffer_;
|
||||
};
|
||||
|
||||
// Base class of beam search implementation that is common for both GPT-2 and T5.
|
||||
template <typename T>
|
||||
class BeamSearchBase {
|
||||
public:
|
||||
BeamSearchBase(OpKernelContextInternal& context,
|
||||
const SessionState& decoder_session_state,
|
||||
concurrency::ThreadPool* thread_pool,
|
||||
void* cuda_stream,
|
||||
IConsoleDumper* cuda_dumper,
|
||||
BeamSearchParameters& params,
|
||||
const BeamSearchDeviceHelper::TopkFunc& topk_func,
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func)
|
||||
: context_(context),
|
||||
decoder_session_state_(decoder_session_state),
|
||||
thread_pool_(thread_pool),
|
||||
implicit_inputs_(context_.GetImplicitInputs()),
|
||||
cuda_stream_(cuda_stream),
|
||||
cuda_dumper_(cuda_dumper),
|
||||
parameters_(¶ms),
|
||||
cpu_allocator_(nullptr),
|
||||
temp_space_allocator_(nullptr),
|
||||
topk_func_(topk_func),
|
||||
process_logits_func_(process_logits_func),
|
||||
device_copy_func_(device_copy_func),
|
||||
device_copy_int32_func_(device_copy_int32_func) {
|
||||
parameters_->ParseFromInputs(&context);
|
||||
|
||||
cpu_allocator_ = decoder_session_state.GetExecutionProviders()
|
||||
.Get(onnxruntime::kCpuExecutionProvider)
|
||||
->GetAllocator(0, OrtMemTypeDefault);
|
||||
}
|
||||
|
||||
// Initialize by validating all the inputs, and allocating the output tensors.
|
||||
Status Initialize();
|
||||
|
||||
// Validate inputs.
|
||||
Status CheckInputs(const OpKernelContextInternal& context);
|
||||
|
||||
protected:
|
||||
// Process logits and append next tokens to sequences.
|
||||
Status GenerateNextToken(const OrtValue& logits,
|
||||
gsl::span<int32_t>& beam_next_tokens,
|
||||
gsl::span<int32_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
int counter);
|
||||
|
||||
// Calculate scores from logits, then apply filtering and select next token for each beam.
|
||||
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
AllocatorPtr& allocator,
|
||||
int counter);
|
||||
|
||||
bool IsCuda() const { return cuda_stream_ != nullptr; }
|
||||
|
||||
const IConsoleDumper* GetConsoleDumper() const { return IsCuda() ? cuda_dumper_ : &(cpu_dumper_); }
|
||||
|
||||
OpKernelContextInternal& context_;
|
||||
|
||||
const SessionState& decoder_session_state_;
|
||||
|
||||
concurrency::ThreadPool* thread_pool_;
|
||||
|
||||
const std::vector<const OrtValue*>& implicit_inputs_;
|
||||
|
||||
void* cuda_stream_;
|
||||
|
||||
IConsoleDumper* cuda_dumper_;
|
||||
CpuTensorConsoleDumper cpu_dumper_;
|
||||
|
||||
BeamSearchParameters* parameters_;
|
||||
|
||||
LogitsProcessorList logits_processors_;
|
||||
|
||||
std::unique_ptr<BeamSearchScorer> beam_scorer_;
|
||||
|
||||
AllocatorPtr cpu_allocator_;
|
||||
AllocatorPtr temp_space_allocator_;
|
||||
|
||||
// Device specific functions
|
||||
BeamSearchDeviceHelper::TopkFunc topk_func_;
|
||||
BeamSearchDeviceHelper::ProcessLogitsFunc<T> process_logits_func_;
|
||||
BeamSearchDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
|
||||
BeamSearchDeviceHelper::DeviceCopyFunc<int32_t> device_copy_int32_func_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
|
||||
// Input shapes:
|
||||
// input_ids : (batch_size, sequence_length)
|
||||
// vocab_mask : (vocab_size) or nullptr
|
||||
|
||||
const Tensor* input_ids = context.Input<Tensor>(0);
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
if (dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'input_ids' is expected to have 2 dimensions, got ", dims.size());
|
||||
}
|
||||
|
||||
const Tensor* vocab_mask = context.Input<Tensor>(8);
|
||||
if (vocab_mask != nullptr) { // vocab_mask is optional
|
||||
const auto& vocab_mask_dims = vocab_mask->Shape().GetDims();
|
||||
if (vocab_mask_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'vocab_mask' is expected to have 1 dimension, got ", vocab_mask_dims.size());
|
||||
}
|
||||
|
||||
// There is dependency on vocab_size parameter, which shall be set before calling this function.
|
||||
if (static_cast<int>(vocab_mask_dims[0]) != parameters_->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'vocab_mask' shape does not match with vocab_size, got ", vocab_mask_dims[0]);
|
||||
}
|
||||
|
||||
// store vocab mask in parameters.
|
||||
parameters_->vocab_mask = vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
const Tensor* prefix_vocab_mask = context.Input<Tensor>(9);
|
||||
if (prefix_vocab_mask != nullptr) {
|
||||
// prefix_vocab_mask is optional
|
||||
const auto& vocab_mask_dims = prefix_vocab_mask->Shape().GetDims();
|
||||
if (vocab_mask_dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'prefix_vocab_mask' is expected to be 2 dimensions, got ", vocab_mask_dims.size());
|
||||
}
|
||||
|
||||
// prefix_vocab_mask first dimension should be same as the first dimension of input_ids
|
||||
if (static_cast<int>(vocab_mask_dims[0]) != static_cast<int>(dims[0])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input_ids and prefix_vocab_mask must have the same batch_size");
|
||||
}
|
||||
|
||||
// There is dependency on vocab_size parameter, which shall be set before calling this function.
|
||||
if (static_cast<int>(vocab_mask_dims[1]) != parameters_->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'prefix_vocab_mask' shape[1] shall be vocab_size, got ", vocab_mask_dims[1]);
|
||||
}
|
||||
|
||||
// store prefix vocab mask in parameters.
|
||||
parameters_->prefix_vocab_mask = prefix_vocab_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchBase<T>::Initialize() {
|
||||
ORT_RETURN_IF_ERROR(context_.GetTempSpaceAllocator(&temp_space_allocator_));
|
||||
|
||||
#define CHECK_SCALAR_INPUT(name, index, required) \
|
||||
auto* name##_tensor = context_.Input<Tensor>(index); \
|
||||
if (name##_tensor) { \
|
||||
if (!name##_tensor->Shape().IsScalar()) { \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " should be a scalar. Got shape of ", \
|
||||
name##_tensor->Shape()); \
|
||||
} \
|
||||
} else if (required) { \
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "'BeamSearch' input " #name " is required"); \
|
||||
}
|
||||
|
||||
CHECK_SCALAR_INPUT(min_length, 1, false);
|
||||
|
||||
CHECK_SCALAR_INPUT(max_length, 2, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(num_beams, 3, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(num_return_sequences, 4, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(temperature, 5, true);
|
||||
|
||||
CHECK_SCALAR_INPUT(length_penalty, 6, true);
|
||||
|
||||
ORT_RETURN_IF(parameters_->num_return_sequences > parameters_->num_beams,
|
||||
"'num_return_sequences' has to be smaller or equal to 'num_beams'.");
|
||||
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(context_));
|
||||
|
||||
// This flag will be updated later when the scores output exists.
|
||||
parameters_->output_scores = false;
|
||||
|
||||
if (!IsCuda()) {
|
||||
// Logits processor is used in CPU only. In CUDA, cuda kernels are used instead.
|
||||
// Initialize processors after CheckInputs so that parameters_->vocab_mask is ready.
|
||||
logits_processors_.Init(*parameters_);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchBase<T>::ProcessLogits(
|
||||
const OrtValue& logits,
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
AllocatorPtr& allocator,
|
||||
int counter) {
|
||||
return process_logits_func_(logits, &beam_state, &cpu_state, &(cpu_state.sequences), allocator,
|
||||
thread_pool_, &logits_processors_, beam_scorer_.get(),
|
||||
parameters_, counter, cuda_stream_, GetConsoleDumper());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchBase<T>::GenerateNextToken(
|
||||
const OrtValue& logits,
|
||||
gsl::span<int32_t>& beam_next_tokens,
|
||||
gsl::span<int32_t>& beam_indices,
|
||||
BeamSearchState<T>& beam_state,
|
||||
BeamSearchCpuState& cpu_state,
|
||||
int counter) {
|
||||
// Process logits to get next token scores
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, beam_state, cpu_state, temp_space_allocator_, counter));
|
||||
|
||||
gsl::span<float>& beam_scores = beam_scorer_->GetNextScores();
|
||||
// It is optional to clone beam_scores. Change it to use same buffer also works for CPU:
|
||||
// beam_state.beam_scores = beam_scores
|
||||
// Here we make a copy to reduce the coupling with little cost (the buffer size is small).
|
||||
ORT_RETURN_IF_ERROR(device_copy_func_(beam_state.beam_scores,
|
||||
beam_scores,
|
||||
cuda_stream_,
|
||||
DeviceCopyDirection::hostToDevice));
|
||||
|
||||
beam_next_tokens = beam_scorer_->GetNextTokens();
|
||||
beam_indices = beam_scorer_->GetNextIndices();
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
cpu_dumper_.Print("beam_scores from scorer", beam_scores.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
cpu_dumper_.Print("beam_next_tokens", beam_next_tokens.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
cpu_dumper_.Print("beam_indices", beam_indices.data(), parameters_->batch_size, parameters_->num_beams);
|
||||
#endif
|
||||
|
||||
cpu_state.sequences.AppendNextTokenToSequences(beam_indices, beam_next_tokens);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
cpu_state.sequences.PrintSequences(&cpu_dumper_);
|
||||
#endif
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
278
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
Normal file
278
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_gpt.h
Normal file
|
|
@ -0,0 +1,278 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/transformers/beam_search_impl_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
namespace transformers {
|
||||
|
||||
// Beam search implementation for GPT-2 model.
|
||||
template <typename T>
|
||||
class BeamSearchGpt : public BeamSearchBase<T> {
|
||||
public:
|
||||
BeamSearchGpt(OpKernelContextInternal& context,
|
||||
const SessionState& decoder_session_state,
|
||||
GptSubgraph& gpt_subgraph,
|
||||
concurrency::ThreadPool* thread_pool,
|
||||
void* cuda_stream,
|
||||
IConsoleDumper* cuda_dumper,
|
||||
BeamSearchParameters& params,
|
||||
const BeamSearchDeviceHelper::CreateGptInputsFunc& create_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const BeamSearchDeviceHelper::TopkFunc& topk_func,
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
|
||||
const BeamSearchDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func)
|
||||
: BeamSearchBase<T>(context, decoder_session_state, thread_pool,
|
||||
cuda_stream, cuda_dumper, params,
|
||||
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
|
||||
gpt_subgraph_(gpt_subgraph),
|
||||
create_inputs_func_(create_inputs_func),
|
||||
add_to_feeds_func_(add_to_feeds_func),
|
||||
init_beam_state_func_(init_beam_state_func),
|
||||
update_feeds_func_(update_feeds_func) {
|
||||
}
|
||||
|
||||
// Execute beam search in iterations util stopping criteria is reached.
|
||||
// In each iteration, GPT subgraph is called, and next token for each sequence is generated.
|
||||
Status Execute(const FeedsFetchesManager& feeds_fetches_manager);
|
||||
|
||||
private:
|
||||
// Prepare the inputs for first inference of subgraph
|
||||
Status CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer);
|
||||
|
||||
// Update the input for next iteration.
|
||||
Status UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices);
|
||||
|
||||
GptSubgraph& gpt_subgraph_;
|
||||
|
||||
// Device specific functions
|
||||
BeamSearchDeviceHelper::CreateGptInputsFunc create_inputs_func_;
|
||||
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
|
||||
BeamSearchDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer) {
|
||||
const OrtValue* input_ids_value = this->context_.GetInputOrtValue(0);
|
||||
const Tensor& input_ids = input_ids_value->Get<Tensor>();
|
||||
return gpt_subgraph_.CreateInitialFeeds(input_ids,
|
||||
this->implicit_inputs_,
|
||||
this->parameters_->num_beams,
|
||||
this->parameters_->pad_token_id,
|
||||
sequence_lengths,
|
||||
expanded_input_ids,
|
||||
feeds,
|
||||
this->create_inputs_func_,
|
||||
this->add_to_feeds_func_,
|
||||
buffer);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchGpt<T>::UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
OrtValue& position_ids,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices) {
|
||||
return update_feeds_func_(this->temp_space_allocator_,
|
||||
this->cuda_stream_,
|
||||
last_outputs,
|
||||
next_inputs,
|
||||
current_length,
|
||||
position_ids,
|
||||
beam_next_tokens,
|
||||
beam_indices,
|
||||
this->parameters_->num_beams,
|
||||
this->GetConsoleDumper());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchGpt<T>::Execute(const FeedsFetchesManager& feeds_fetches_manager) {
|
||||
auto status = Status::OK();
|
||||
const BeamSearchParameters* parameters = this->parameters_;
|
||||
int64_t sequences_dims[] = {parameters->batch_size, parameters->num_return_sequences, parameters->max_length};
|
||||
TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
|
||||
Tensor* output_sequences = this->context_.Output(0, sequences_shape);
|
||||
|
||||
int64_t sequences_scores_dims[] = {parameters->batch_size, parameters->num_return_sequences};
|
||||
TensorShape sequences_scores_shape(&sequences_scores_dims[0], 2);
|
||||
Tensor* output_sequences_scores = this->context_.Output(1, sequences_scores_shape);
|
||||
|
||||
int64_t scores_dims[] = {
|
||||
static_cast<int64_t>(parameters->max_length) - static_cast<int64_t>(parameters->sequence_length),
|
||||
parameters->batch_size, parameters->num_beams, parameters->vocab_size};
|
||||
TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
|
||||
Tensor* output_scores = this->context_.Output(2, scores_shape);
|
||||
|
||||
// Update the flag to indicate whether scores exists in output
|
||||
this->parameters_->output_scores = (output_scores != nullptr);
|
||||
|
||||
std::vector<OrtValue> feeds;
|
||||
// TODO(tianleiwu): allocate fetches. use ping-pong buffers for past state.
|
||||
std::vector<OrtValue> fetches;
|
||||
|
||||
// Initialize resources
|
||||
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(this->cpu_allocator_);
|
||||
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(this->cpu_allocator_);
|
||||
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(static_cast<size_t>(parameters->batch_size),
|
||||
static_cast<size_t>(parameters->num_beams),
|
||||
static_cast<size_t>(parameters->max_length),
|
||||
parameters->length_penalty,
|
||||
parameters->early_stopping,
|
||||
static_cast<size_t>(parameters->num_return_sequences),
|
||||
parameters->pad_token_id,
|
||||
parameters->eos_token_id,
|
||||
hypothesis_score_allocator,
|
||||
beam_hyps_allocator);
|
||||
this->beam_scorer_->Initialize(this->cpu_allocator_, parameters->sequence_length);
|
||||
|
||||
BeamSearchCpuState cpu_state;
|
||||
cpu_state.Init(this->cpu_allocator_,
|
||||
static_cast<size_t>(parameters->BatchBeamSize()),
|
||||
parameters->max_length,
|
||||
parameters->sequence_length,
|
||||
this->IsCuda());
|
||||
|
||||
// buffer in GPU for input_ids, position_ids and attention_mask
|
||||
IAllocatorUniquePtr<char> buffer;
|
||||
OrtValue expanded_input_ids_in_cpu;
|
||||
ORT_RETURN_IF_ERROR(CreateInitialFeeds(cpu_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
|
||||
|
||||
BeamSearchState<T> beam_state;
|
||||
constexpr bool use_position = true;
|
||||
beam_state.Init(this->temp_space_allocator_,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
parameters->vocab_size,
|
||||
parameters->sequence_length,
|
||||
parameters->max_length,
|
||||
parameters->output_scores,
|
||||
use_position);
|
||||
|
||||
init_beam_state_func_(&beam_state,
|
||||
cpu_state.sequence_lengths,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
this->cuda_stream_);
|
||||
|
||||
gsl::span<const int32_t> input_ids = expanded_input_ids_in_cpu.Get<Tensor>().DataAsSpan<int32_t>();
|
||||
cpu_state.SetSequence(input_ids,
|
||||
static_cast<size_t>(parameters->BatchBeamSize()),
|
||||
parameters->max_length,
|
||||
parameters->sequence_length);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
const IConsoleDumper* dumper = this->GetConsoleDumper();
|
||||
dumper->Print("input_ids", feeds[0]);
|
||||
dumper->Print("position_ids", feeds[1]);
|
||||
dumper->Print("attention_mask", feeds[2]);
|
||||
#endif
|
||||
|
||||
// Position ids for all iterations except the first. It uses memory buffer owned by next_positions.
|
||||
OrtValue position_ids;
|
||||
int64_t dims[] = {parameters->BatchBeamSize(), 1};
|
||||
TensorShape shape(&dims[0], 2);
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(),
|
||||
shape,
|
||||
beam_state.next_positions.data(),
|
||||
this->temp_space_allocator_->Info(),
|
||||
position_ids);
|
||||
|
||||
int current_length = parameters->sequence_length;
|
||||
int iteration_counter = 0;
|
||||
while (current_length < parameters->max_length) {
|
||||
iteration_counter++;
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
auto cur_len = std::to_string(current_length);
|
||||
dumper->Print("***CurrentLength", cur_len, true);
|
||||
#endif
|
||||
|
||||
status = utils::ExecuteSubgraph(this->decoder_session_state_,
|
||||
feeds_fetches_manager,
|
||||
feeds,
|
||||
fetches,
|
||||
{},
|
||||
ExecutionMode::ORT_SEQUENTIAL,
|
||||
this->context_.GetTerminateFlag(),
|
||||
this->context_.Logger());
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
const OrtValue& logits = fetches[0];
|
||||
gsl::span<int32_t> beam_next_tokens;
|
||||
gsl::span<int32_t> beam_indices;
|
||||
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
|
||||
beam_next_tokens,
|
||||
beam_indices,
|
||||
beam_state,
|
||||
cpu_state,
|
||||
iteration_counter));
|
||||
|
||||
// When all batches are finished, stop earlier to avoid wasting computation.
|
||||
if (this->beam_scorer_->IsDone()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Increase sequence length after a new token is generated.
|
||||
++current_length;
|
||||
|
||||
// Prepare inputs for next round of subgraph call.
|
||||
if (current_length < parameters->max_length) {
|
||||
ORT_RETURN_IF_ERROR(UpdateFeeds(fetches, feeds, current_length,
|
||||
position_ids,
|
||||
beam_next_tokens.as_span<const int32_t>(),
|
||||
beam_indices.as_span<const int32_t>()));
|
||||
}
|
||||
fetches.clear();
|
||||
}
|
||||
|
||||
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
|
||||
if (this->IsCuda()) {
|
||||
ORT_RETURN_IF_ERROR(this->device_copy_func_(cpu_state.final_beam_scores,
|
||||
final_beam_scores,
|
||||
nullptr,
|
||||
DeviceCopyDirection::deviceToHost));
|
||||
final_beam_scores = gsl::make_span<const float>(cpu_state.final_beam_scores.data(),
|
||||
cpu_state.final_beam_scores.size());
|
||||
}
|
||||
|
||||
this->beam_scorer_->Finalize(&(cpu_state.sequences),
|
||||
final_beam_scores,
|
||||
output_sequences,
|
||||
output_sequences_scores);
|
||||
|
||||
// Output per token scores
|
||||
if (output_scores != nullptr) {
|
||||
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
|
||||
gsl::span<const float> source = gsl::span<const float>(beam_state.scores.data(), beam_state.scores.size());
|
||||
assert(target.length() == source.length());
|
||||
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
307
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
Normal file
307
onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h
Normal file
|
|
@ -0,0 +1,307 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h" // for DEBUG_BEAM_SEARCH
|
||||
#include "contrib_ops/cpu/transformers/beam_search_impl_base.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
namespace transformers {
|
||||
|
||||
// Beam search implementation for T5 model.
|
||||
template <typename T>
|
||||
class BeamSearchT5 : public BeamSearchBase<T> {
|
||||
public:
|
||||
BeamSearchT5(OpKernelContextInternal& context,
|
||||
const SessionState& encoder_session_state,
|
||||
const SessionState& decoder_session_state,
|
||||
T5EncoderSubgraph& encoder_subgraph,
|
||||
T5DecoderSubgraph& decoder_subgraph,
|
||||
concurrency::ThreadPool* thread_pool,
|
||||
void* cuda_stream,
|
||||
IConsoleDumper* cuda_dumper,
|
||||
BeamSearchParameters& params,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const BeamSearchDeviceHelper::TopkFunc& topk_func,
|
||||
const BeamSearchDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
|
||||
const BeamSearchDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
|
||||
const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
|
||||
const BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T>& update_decoder_feeds_func)
|
||||
: BeamSearchBase<T>(context, decoder_session_state, thread_pool,
|
||||
cuda_stream, cuda_dumper, params,
|
||||
topk_func, process_logits_func, device_copy_func, device_copy_int32_func),
|
||||
encoder_session_state_(encoder_session_state),
|
||||
encoder_subgraph_(encoder_subgraph),
|
||||
decoder_subgraph_(decoder_subgraph),
|
||||
add_to_feeds_func_(add_to_feeds_func),
|
||||
init_beam_state_func_(init_beam_state_func),
|
||||
create_encoder_inputs_func_(create_encoder_inputs_func),
|
||||
update_decoder_feeds_func_(update_decoder_feeds_func) {
|
||||
}
|
||||
|
||||
// Execute beam search in iterations util stopping criteria is reached.
|
||||
Status Execute(const FeedsFetchesManager& encoder_feeds_fetches_manager,
|
||||
const FeedsFetchesManager& decoder_feeds_fetches_manager);
|
||||
|
||||
private:
|
||||
const SessionState& encoder_session_state_;
|
||||
|
||||
T5EncoderSubgraph& encoder_subgraph_;
|
||||
T5DecoderSubgraph& decoder_subgraph_;
|
||||
|
||||
// Device specific functions
|
||||
BeamSearchDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
BeamSearchDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
|
||||
|
||||
BeamSearchDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
|
||||
BeamSearchDeviceHelper::UpdateDecoderFeedsFunc<T> update_decoder_feeds_func_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches_manager,
|
||||
const FeedsFetchesManager& decoder_feeds_fetches_manager) {
|
||||
auto status = Status::OK();
|
||||
|
||||
const BeamSearchParameters* parameters = this->parameters_;
|
||||
ORT_ENFORCE(parameters->sequence_length == 1);
|
||||
|
||||
// Allocate output tensors.
|
||||
int64_t sequences_dims[] = {parameters->batch_size, parameters->num_return_sequences, parameters->max_length};
|
||||
TensorShape sequences_shape(&sequences_dims[0], sizeof(sequences_dims) / sizeof(sequences_dims[0]));
|
||||
Tensor* output_sequences = this->context_.Output(0, sequences_shape);
|
||||
|
||||
int64_t sequences_scores_dims[] = {parameters->batch_size, parameters->num_return_sequences};
|
||||
constexpr int64_t dims = sizeof(sequences_scores_dims) / sizeof(sequences_scores_dims[0]);
|
||||
TensorShape sequences_scores_shape(&sequences_scores_dims[0], dims);
|
||||
Tensor* output_sequences_scores = this->context_.Output(1, sequences_scores_shape);
|
||||
|
||||
int64_t scores_dims[] = {
|
||||
static_cast<int64_t>(parameters->max_length) - static_cast<int64_t>(parameters->sequence_length),
|
||||
parameters->batch_size, parameters->num_beams, parameters->vocab_size};
|
||||
TensorShape scores_shape(&scores_dims[0], sizeof(scores_dims) / sizeof(scores_dims[0]));
|
||||
Tensor* output_scores = this->context_.Output(2, scores_shape);
|
||||
|
||||
// Update the flag to indicate whether scores exists in output
|
||||
this->parameters_->output_scores = (output_scores != nullptr);
|
||||
|
||||
// ------------------------------------
|
||||
// Call encoder subgraph.
|
||||
// ------------------------------------
|
||||
std::vector<OrtValue> encoder_feeds;
|
||||
std::vector<OrtValue> encoder_fetches;
|
||||
|
||||
const OrtValue* encoder_input_ids_value = this->context_.GetInputOrtValue(0);
|
||||
const Tensor& encoder_input_ids = encoder_input_ids_value->Get<Tensor>();
|
||||
|
||||
BeamSearchCpuState cpu_state;
|
||||
cpu_state.Init(this->cpu_allocator_,
|
||||
static_cast<size_t>(parameters->BatchBeamSize()),
|
||||
parameters->max_length,
|
||||
parameters->sequence_length,
|
||||
this->IsCuda());
|
||||
|
||||
IAllocatorUniquePtr<char> buffer;
|
||||
OrtValue expanded_decoder_input_ids; // Tensor in CPU, and it will be used to initialize sequence in cpu_state
|
||||
ORT_RETURN_IF_ERROR(this->encoder_subgraph_.CreateInitialFeeds(
|
||||
encoder_input_ids,
|
||||
this->implicit_inputs_,
|
||||
parameters->num_beams,
|
||||
parameters->pad_token_id,
|
||||
parameters->decoder_start_token_id,
|
||||
encoder_feeds,
|
||||
this->create_encoder_inputs_func_,
|
||||
this->add_to_feeds_func_,
|
||||
buffer,
|
||||
expanded_decoder_input_ids));
|
||||
|
||||
ORT_RETURN_IF_ERROR(utils::ExecuteSubgraph(this->encoder_session_state_,
|
||||
encoder_feeds_fetches_manager,
|
||||
encoder_feeds,
|
||||
encoder_fetches,
|
||||
{},
|
||||
ExecutionMode::ORT_SEQUENTIAL,
|
||||
this->context_.GetTerminateFlag(),
|
||||
this->context_.Logger()));
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
const IConsoleDumper* dumper = this->GetConsoleDumper();
|
||||
for (size_t i = 0; i < encoder_feeds.size(); i++) {
|
||||
dumper->Print("encoder_feeds", static_cast<int>(i), true);
|
||||
dumper->Print("", encoder_feeds[i]);
|
||||
}
|
||||
|
||||
for (int i = 0; i <= T5EncoderSubgraph::kFirstPresentOutputIndex; i++) {
|
||||
dumper->Print("encoder_fetches", i, true);
|
||||
dumper->Print("", encoder_fetches[i]);
|
||||
}
|
||||
#endif
|
||||
|
||||
// ------------------------------------
|
||||
// Initialize resources
|
||||
// ------------------------------------
|
||||
|
||||
// Copy expanded_decoder_input_ids (in CPU) to sequence. It contains decoder_start_token_id for each beam.
|
||||
cpu_state.SetSequence(expanded_decoder_input_ids.Get<Tensor>().DataAsSpan<int32_t>(),
|
||||
static_cast<size_t>(parameters->BatchBeamSize()),
|
||||
parameters->max_length,
|
||||
parameters->sequence_length);
|
||||
|
||||
onnxruntime::OrtStlAllocator<HypothesisScore> hypothesis_score_allocator(this->cpu_allocator_);
|
||||
onnxruntime::OrtStlAllocator<BeamHypotheses> beam_hyps_allocator(this->cpu_allocator_);
|
||||
this->beam_scorer_ = std::make_unique<BeamSearchScorer>(static_cast<size_t>(parameters->batch_size),
|
||||
static_cast<size_t>(parameters->num_beams),
|
||||
static_cast<size_t>(parameters->max_length),
|
||||
parameters->length_penalty,
|
||||
parameters->early_stopping,
|
||||
static_cast<size_t>(parameters->num_return_sequences),
|
||||
parameters->pad_token_id,
|
||||
parameters->eos_token_id,
|
||||
hypothesis_score_allocator,
|
||||
beam_hyps_allocator);
|
||||
this->beam_scorer_->Initialize(this->cpu_allocator_, parameters->sequence_length);
|
||||
|
||||
BeamSearchState<T> beam_state;
|
||||
constexpr bool use_position = false;
|
||||
beam_state.Init(this->temp_space_allocator_,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
parameters->vocab_size,
|
||||
parameters->sequence_length,
|
||||
parameters->max_length,
|
||||
parameters->output_scores,
|
||||
use_position);
|
||||
|
||||
init_beam_state_func_(&beam_state,
|
||||
cpu_state.sequence_lengths,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
this->cuda_stream_);
|
||||
|
||||
// ------------------------------------------------------------------------------
|
||||
// Generate next token from logits output from encoder, and initialize decoder inputs.
|
||||
// ------------------------------------------------------------------------------
|
||||
gsl::span<int32_t> beam_next_tokens;
|
||||
gsl::span<int32_t> beam_indices;
|
||||
|
||||
int iteration_counter = 0;
|
||||
std::vector<OrtValue> decoder_feeds;
|
||||
int current_length = parameters->sequence_length;
|
||||
if (current_length + 1 < parameters->max_length) {
|
||||
++iteration_counter;
|
||||
ORT_RETURN_IF_ERROR(this->GenerateNextToken(encoder_fetches[0],
|
||||
beam_next_tokens,
|
||||
beam_indices,
|
||||
beam_state,
|
||||
cpu_state,
|
||||
iteration_counter));
|
||||
++current_length; // Increase sequence length after a new token is generated.
|
||||
ORT_RETURN_IF_ERROR(decoder_subgraph_.CreateInitialFeeds(beam_next_tokens.as_span<const int32_t>(),
|
||||
this->implicit_inputs_,
|
||||
encoder_feeds,
|
||||
encoder_fetches,
|
||||
decoder_feeds,
|
||||
this->device_copy_int32_func_,
|
||||
this->cuda_stream_));
|
||||
}
|
||||
|
||||
// TODO(tianleiwu): allocate fetches. use ping-pong buffers for past state.
|
||||
std::vector<OrtValue> decoder_fetches;
|
||||
while (current_length < parameters->max_length) {
|
||||
iteration_counter++;
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
auto cur_len = std::to_string(current_length);
|
||||
dumper->Print("***CurrentLength", cur_len, true);
|
||||
|
||||
for (int i = 0; i <= T5DecoderSubgraph::kFirstPastInputIndex; i++) {
|
||||
dumper->Print("decoder_feeds", i, true);
|
||||
dumper->Print("", decoder_feeds[i]);
|
||||
}
|
||||
#endif
|
||||
|
||||
status = utils::ExecuteSubgraph(this->decoder_session_state_,
|
||||
decoder_feeds_fetches_manager,
|
||||
decoder_feeds,
|
||||
decoder_fetches,
|
||||
{},
|
||||
ExecutionMode::ORT_SEQUENTIAL,
|
||||
this->context_.GetTerminateFlag(),
|
||||
this->context_.Logger());
|
||||
|
||||
ORT_RETURN_IF_ERROR(status);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
for (int i = 0; i <= T5DecoderSubgraph::kFirstPresentOutputIndex; i++) {
|
||||
dumper->Print("decoder_fetches", i, true);
|
||||
dumper->Print("", decoder_fetches[i]);
|
||||
}
|
||||
#endif
|
||||
|
||||
const OrtValue& logits = decoder_fetches[0];
|
||||
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
|
||||
beam_next_tokens,
|
||||
beam_indices,
|
||||
beam_state,
|
||||
cpu_state,
|
||||
iteration_counter));
|
||||
|
||||
// When all batches are finished, stop earlier to avoid wasting computation.
|
||||
if (this->beam_scorer_->IsDone()) {
|
||||
break;
|
||||
}
|
||||
|
||||
// Increase sequence length after a new token is generated.
|
||||
++current_length;
|
||||
|
||||
// Prepare inputs for next round of subgraph call.
|
||||
if (current_length < parameters->max_length) {
|
||||
const int num_present_outputs = 2 * parameters->num_layers; // number of outputs with name like present_*
|
||||
ORT_RETURN_IF_ERROR(this->update_decoder_feeds_func_(
|
||||
this->temp_space_allocator_,
|
||||
this->cuda_stream_,
|
||||
decoder_fetches,
|
||||
decoder_feeds,
|
||||
num_present_outputs,
|
||||
beam_next_tokens.as_span<const int32_t>(),
|
||||
beam_indices.as_span<const int32_t>(),
|
||||
parameters->num_beams,
|
||||
this->GetConsoleDumper()));
|
||||
}
|
||||
decoder_fetches.clear();
|
||||
}
|
||||
|
||||
gsl::span<const float> final_beam_scores(beam_state.beam_scores.data(), beam_state.beam_scores.size());
|
||||
if (this->IsCuda()) {
|
||||
ORT_RETURN_IF_ERROR(this->device_copy_func_(cpu_state.final_beam_scores,
|
||||
final_beam_scores,
|
||||
nullptr,
|
||||
DeviceCopyDirection::deviceToHost));
|
||||
final_beam_scores = gsl::make_span<const float>(cpu_state.final_beam_scores.data(),
|
||||
cpu_state.final_beam_scores.size());
|
||||
}
|
||||
|
||||
this->beam_scorer_->Finalize(&(cpu_state.sequences),
|
||||
final_beam_scores,
|
||||
output_sequences,
|
||||
output_sequences_scores);
|
||||
|
||||
// Output per token scores
|
||||
if (output_scores != nullptr) {
|
||||
gsl::span<float> target = output_scores->MutableDataAsSpan<float>();
|
||||
gsl::span<const float> source = gsl::span<const float>(beam_state.scores.data(), beam_state.scores.size());
|
||||
assert(target.length() == source.length());
|
||||
ORT_RETURN_IF_ERROR(this->device_copy_func_(target, source, nullptr, DeviceCopyDirection::deviceToDevice));
|
||||
}
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "beam_search_parameters.h"
|
||||
|
||||
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -17,10 +18,11 @@ Status BeamSearchParameters::Validate() const {
|
|||
}
|
||||
|
||||
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
|
||||
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", 0));
|
||||
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IBeamSearchParameters::kModelTypeGpt));
|
||||
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
|
||||
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
|
||||
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
|
||||
decoder_start_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("decoder_start_token_id", -1));
|
||||
no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
|
||||
}
|
||||
|
||||
|
|
@ -30,25 +32,32 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
const auto& dims = input_ids->Shape().GetDims();
|
||||
ORT_ENFORCE(dims.size() == 2, "input_ids shall have 2 dimensions. Got ", dims.size());
|
||||
batch_size = static_cast<int>(dims[0]);
|
||||
sequence_length = static_cast<int>(dims[1]);
|
||||
|
||||
// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
|
||||
sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
|
||||
|
||||
auto* max_length_tensor = context->Input<Tensor>(1);
|
||||
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
|
||||
ORT_ENFORCE(max_length > sequence_length, "max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
|
||||
ORT_ENFORCE(max_length <= kMaxSequenceLength, "max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
|
||||
ORT_ENFORCE(max_length > sequence_length,
|
||||
"max_length (", max_length, ") shall be greater than input sequence length (", sequence_length, ")");
|
||||
ORT_ENFORCE(max_length <= kMaxSequenceLength,
|
||||
"max_length (", max_length, ") shall be no more than ", kMaxSequenceLength);
|
||||
|
||||
auto* min_length_tensor = context->Input<Tensor>(2);
|
||||
min_length = min_length_tensor ? static_cast<int>(*min_length_tensor->Data<int32_t>()) : 0;
|
||||
|
||||
auto* num_beams_tensor = context->Input<Tensor>(3);
|
||||
num_beams = num_beams_tensor ? static_cast<int>(*num_beams_tensor->Data<int32_t>()) : 1;
|
||||
// TODO: limit num_beams > 1 when we can have another operator for greedy search.
|
||||
ORT_ENFORCE(num_beams >= 1 && num_beams <= kMaxNumBeams, "num_beams shall be a positive integer no more than ", kMaxNumBeams, ", got ", num_beams);
|
||||
// TODO(tianleiwu): limit num_beams > 1 when we can have another operator for greedy search.
|
||||
ORT_ENFORCE(num_beams >= 1 && num_beams <= kMaxNumBeams,
|
||||
"num_beams shall be a positive integer no more than ", kMaxNumBeams, ", got ", num_beams);
|
||||
|
||||
auto* num_return_sequences_tensor = context->Input<Tensor>(4);
|
||||
num_return_sequences = num_return_sequences_tensor ? static_cast<int>(*num_return_sequences_tensor->Data<int32_t>()) : 1;
|
||||
ORT_ENFORCE(num_return_sequences >= 1, "num_return_sequences shall be a positive integer, got ", num_return_sequences);
|
||||
ORT_ENFORCE(num_beams >= num_return_sequences, "num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
|
||||
num_return_sequences = num_return_sequences_tensor ? *num_return_sequences_tensor->Data<int32_t>() : 1;
|
||||
ORT_ENFORCE(num_return_sequences >= 1,
|
||||
"num_return_sequences shall be a positive integer, got ", num_return_sequences);
|
||||
ORT_ENFORCE(num_beams >= num_return_sequences,
|
||||
"num_return_sequences (", num_return_sequences, ") shall be be no more than num_beams (", num_beams, ")");
|
||||
|
||||
auto* temperature_tensor = context->Input<Tensor>(5);
|
||||
temperature = temperature_tensor ? static_cast<float>(*temperature_tensor->Data<float>()) : 1;
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "beam_search_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/cpu/rnn/rnn_helpers.h"
|
||||
#include "beam_search_scorer.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -69,7 +69,7 @@ void BeamHypotheses::Output(
|
|||
}
|
||||
|
||||
// Since pop get the worst sequence, so output it in the reverse order.
|
||||
// The frist (worst) beam shall be put at the last position among top_k sequences.
|
||||
// The first (worst) beam shall be put at the last position among top_k sequences.
|
||||
int index = top_k - 1;
|
||||
while (!beams_.empty()) {
|
||||
auto item = beams_.top();
|
||||
|
|
@ -124,7 +124,7 @@ void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length)
|
|||
ORT_ENFORCE(next_beam_scores_.empty()); // Make sure this is called only once.
|
||||
|
||||
size_t batch_beam_size = batch_size_ * num_beams_;
|
||||
constexpr bool no_fill = false; // do not fill values after allocation
|
||||
constexpr bool no_fill = false; // Do not fill values after allocation
|
||||
|
||||
done_ = Allocate<bool>(allocator, batch_size_, done_ptr_, no_fill);
|
||||
std::fill_n(done_.data(), done_.size(), false);
|
||||
|
|
@ -134,8 +134,8 @@ void BeamSearchScorer::Initialize(AllocatorPtr& allocator, int sequence_length)
|
|||
next_beam_indices_ = Allocate<int32_t>(allocator, batch_beam_size, next_beam_indices_ptr_, no_fill);
|
||||
|
||||
// Space to store intermediate sequence with length sequence_length, sequence_length + 1, ..., max_sequence_length.
|
||||
size_t buffer_per_beam = (SafeInt<size_t>(max_length_) * (max_length_ + 1) - SafeInt<size_t>(sequence_length - 1) * sequence_length) / 2;
|
||||
hypothesis_buffer_length_ = batch_beam_size * buffer_per_beam;
|
||||
size_t per_beam = (SafeInt<size_t>(max_length_) * (max_length_ + 1) - (sequence_length - 1) * sequence_length) / 2;
|
||||
hypothesis_buffer_length_ = batch_beam_size * per_beam;
|
||||
hypothesis_buffer_ = Allocate<int32_t>(allocator, hypothesis_buffer_length_, hypothesis_buffer_ptr_, no_fill);
|
||||
}
|
||||
|
||||
|
|
@ -155,7 +155,8 @@ void BeamSearchScorer::Process(ISequences* sequences,
|
|||
for (size_t batch = 0; batch < batch_size_; batch++) {
|
||||
BeamHypotheses& beam_hyp = beam_hyps_[batch];
|
||||
if (done_[batch]) {
|
||||
ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast<int>(num_beams_), "Batch can only be done if all beams have been generated");
|
||||
ORT_ENFORCE(beam_hyp.Size() >= gsl::narrow_cast<int>(num_beams_),
|
||||
"Batch can only be done if all beams have been generated");
|
||||
|
||||
// Pad the batch.
|
||||
for (size_t j = 0; j < num_beams_; j++) {
|
||||
|
|
@ -203,7 +204,7 @@ void BeamSearchScorer::Process(ISequences* sequences,
|
|||
}
|
||||
|
||||
ORT_ENFORCE(beam_idx == num_beams_);
|
||||
ORT_ENFORCE(hypothesis_buffer_offset_ <= batch_size_ * num_beams_ * max_length_);
|
||||
ORT_ENFORCE(hypothesis_buffer_offset_ <= hypothesis_buffer_length_);
|
||||
|
||||
// Check if we are done so that we can save a pad step if all(done)
|
||||
if (!done_[batch]) {
|
||||
|
|
@ -258,7 +259,8 @@ void BeamSearchScorer::Finalize(ISequences* sequences,
|
|||
BeamHypotheses& beam_hyp = beam_hyps_[batch_index];
|
||||
|
||||
const size_t num_return_sequences = num_beam_hyps_to_keep_;
|
||||
auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_, num_return_sequences * max_length_);
|
||||
auto batch_output = output.subspan(batch_index * num_return_sequences * max_length_,
|
||||
num_return_sequences * max_length_);
|
||||
|
||||
if (output_sequence_scores != nullptr) {
|
||||
batch_sequence_score = sequence_scores.subspan(batch_index * num_return_sequences, num_return_sequences);
|
||||
|
|
@ -274,4 +276,4 @@ void BeamSearchScorer::Finalize(ISequences* sequences,
|
|||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@
|
|||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "core/providers/cpu/containers.h"
|
||||
#include "sequences.h"
|
||||
#include "beam_search_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -52,7 +52,7 @@ class BeamHypotheses {
|
|||
// Output results. Note that it will clear all beams.
|
||||
void Output(int top_k, // number of sequences to return
|
||||
int max_length, // max sequence length
|
||||
gsl::span<int32_t>& sequences, // buffer filled with pad token ID, with shape (num_return_sequences, max_length)
|
||||
gsl::span<int32_t>& sequences, // buffer with pad token, shape (num_return_sequences, max_length)
|
||||
gsl::span<float>& sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
|
||||
|
||||
private:
|
||||
|
|
@ -60,7 +60,9 @@ class BeamHypotheses {
|
|||
float length_penalty_;
|
||||
bool early_stopping_;
|
||||
float worst_score_;
|
||||
std::priority_queue<HypothesisScore, onnxruntime::FastAllocVector<HypothesisScore>, HypothesisScoreCompare> beams_; // min-heap for top k
|
||||
|
||||
// Min-heap for top k
|
||||
std::priority_queue<HypothesisScore, onnxruntime::FastAllocVector<HypothesisScore>, HypothesisScoreCompare> beams_;
|
||||
};
|
||||
|
||||
class BeamSearchScorer : public IBeamScorer {
|
||||
|
|
@ -103,7 +105,7 @@ class BeamSearchScorer : public IBeamScorer {
|
|||
int eos_token_id_;
|
||||
|
||||
IAllocatorUniquePtr<bool> done_ptr_; // Allocated buffer for done_
|
||||
gsl::span<bool> done_; // List of flags indicates whether each batch is finished or not. Its shape is (batch_size).
|
||||
gsl::span<bool> done_; // Flags indicates whether each batch is finished or not. Shape is (batch_size).
|
||||
|
||||
IAllocatorUniquePtr<float> next_beam_scores_ptr_;
|
||||
gsl::span<float> next_beam_scores_;
|
||||
|
|
@ -117,11 +119,11 @@ class BeamSearchScorer : public IBeamScorer {
|
|||
IAllocatorUniquePtr<int32_t> hypothesis_buffer_ptr_; // Allocated buffer to hold all hypotheses
|
||||
gsl::span<int32_t> hypothesis_buffer_; // Span of the allocated buffer
|
||||
size_t hypothesis_buffer_length_; // Total number of elements
|
||||
size_t hypothesis_buffer_offset_; // Offset of avaiable buffer, or length of used buffer.
|
||||
size_t hypothesis_buffer_offset_; // Offset of available buffer, or length of used buffer.
|
||||
|
||||
onnxruntime::FastAllocVector<BeamHypotheses> beam_hyps_;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
|
|
@ -23,10 +26,10 @@ struct IBeamSearchState {
|
|||
gsl::span<float> next_token_scores; // shape (batch_size, num_beams * vocab_size)
|
||||
gsl::span<int32_t> next_tokens; // shape (batch_size, 2 * num_beams)
|
||||
gsl::span<int32_t> next_indices; // shape (batch_size, 2 * num_beams)
|
||||
gsl::span<int32_t> next_positions; // shape (batch_size, num_beams). Next position value for position_ids.
|
||||
gsl::span<int32_t> next_positions; // shape (batch_size, num_beams), empty for T5. Next position for position_ids.
|
||||
gsl::span<float> beam_scores; // shape (batch_size, num_beams)
|
||||
gsl::span<float> scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
|
||||
gsl::span<float> remaining_scores; // portion of scores that is avaiable for appending next token scores.
|
||||
gsl::span<float> remaining_scores; // portion of scores that is available for appending next token scores.
|
||||
};
|
||||
|
||||
struct IBeamSearchCpuState {
|
||||
|
|
@ -72,10 +75,14 @@ class IBeamScorer {
|
|||
};
|
||||
|
||||
struct IBeamSearchParameters {
|
||||
static constexpr int kModelTypeGpt = 0;
|
||||
static constexpr int kModelTypeT5 = 1;
|
||||
|
||||
// Parameters from node attributes
|
||||
int model_type;
|
||||
int model_type; // 0 for GPT-2; 1 for encoder-decoder like T5
|
||||
int eos_token_id;
|
||||
int pad_token_id;
|
||||
int decoder_start_token_id;
|
||||
int no_repeat_ngram_size;
|
||||
bool early_stopping;
|
||||
|
||||
|
|
@ -88,7 +95,7 @@ struct IBeamSearchParameters {
|
|||
float length_penalty;
|
||||
float repetition_penalty;
|
||||
int batch_size; // deduce from first dimension of input_ids
|
||||
int sequence_length; // deduce from second dimension of input_ids
|
||||
int sequence_length; // deduce from second dimension of input_ids of GPT-2 or decoder_input_ids of T5
|
||||
|
||||
gsl::span<const int32_t> vocab_mask;
|
||||
gsl::span<const int32_t> prefix_vocab_mask;
|
||||
|
|
@ -128,4 +135,4 @@ class IConsoleDumper {
|
|||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -70,12 +70,17 @@ void DumpCpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
|
|||
void DumpCpuTensor(const char* name, const Tensor& tensor) {
|
||||
const auto& shape = tensor.Shape();
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
}
|
||||
std::cout << "Shape:" << shape << std::endl;
|
||||
|
||||
size_t num_dims = shape.NumDimensions();
|
||||
if (num_dims >= 3) {
|
||||
int dim0 = static_cast<int>(shape.SizeToDimension(num_dims - 2));
|
||||
int dim1 = static_cast<int>(shape[num_dims - 2]);
|
||||
int dim2 = static_cast<int>(shape[num_dims - 1]);
|
||||
DumpCpuTensor(name, tensor, dim0, dim1, dim2);
|
||||
DumpCpuTensor(nullptr, tensor, dim0, dim1, dim2);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -85,7 +90,7 @@ void DumpCpuTensor(const char* name, const Tensor& tensor) {
|
|||
num_rows = static_cast<size_t>(shape[0]);
|
||||
}
|
||||
size_t row_size = num_items / num_rows;
|
||||
DumpCpuTensor(name, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
|
||||
DumpCpuTensor(nullptr, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
|
||||
}
|
||||
|
||||
void CpuTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const {
|
||||
|
|
@ -209,4 +214,4 @@ void CpuTensorConsoleDumper::Print(const char*, const std::string&, bool) const
|
|||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
#pragma once
|
||||
#include <string>
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "beam_search_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -30,4 +30,4 @@ class CpuTensorConsoleDumper : public IConsoleDumper {
|
|||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,248 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "gsl/gsl"
|
||||
#include "gpt_subgraph.h"
|
||||
#include "dump_tensor.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
GptSubgraph::GptSubgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in)
|
||||
: node(node_in), attribute(attribute_name), subgraph(subgraph_in), allocator_(nullptr), is_output_float16_(false) {
|
||||
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
|
||||
|
||||
auto& subgraph_inputs = subgraph.GetInputs();
|
||||
auto& subgraph_outputs = subgraph.GetOutputs();
|
||||
|
||||
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
|
||||
// outputs: logits, present_0, present_1, ...
|
||||
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
|
||||
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
|
||||
|
||||
// CheckSubgraph will verify inputs and outputs later.
|
||||
subgraph_input_names.reserve(num_subgraph_inputs);
|
||||
for (int i = 0; i < num_subgraph_inputs; ++i) {
|
||||
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
|
||||
}
|
||||
|
||||
subgraph_output_names.reserve(num_subgraph_outputs);
|
||||
for (int i = 0; i < num_subgraph_outputs; ++i) {
|
||||
subgraph_output_names.push_back(subgraph_outputs[i]->Name());
|
||||
}
|
||||
}
|
||||
|
||||
Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) {
|
||||
ORT_RETURN_IF(num_subgraph_outputs <= 1,
|
||||
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in inputs and outputs).");
|
||||
|
||||
ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2,
|
||||
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2");
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids", "subgraph input 0 shall be named as input_ids, got: ",
|
||||
subgraph_inputs[0]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids", "subgraph input 1 shall be named as position_ids, got: ",
|
||||
subgraph_inputs[1]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask", "subgraph input 2 shall be named as attention_mask, got: ",
|
||||
subgraph_inputs[2]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0", "subgraph input 3 shall be named as past_0, got: ",
|
||||
subgraph_inputs[3]->Name());
|
||||
|
||||
// Past state shape is like (2, batch_size, 12, past_seq_len, 64). Here 12 and 64 are constants of num_heads and hidden_size/num_heads.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape();
|
||||
ORT_RETURN_IF(past_shape->dim_size() != 5, "subgraph past state is expected to have 5 dimension, got ",
|
||||
past_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
|
||||
"subgraph past state dimension 0 shall have length of 2");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for number of heads");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
|
||||
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
|
||||
|
||||
// check subgraph outputs
|
||||
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits", "subgraph output 0 shall be named as logits, got: ",
|
||||
subgraph_outputs[0]->Name());
|
||||
|
||||
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0", "subgraph input 1 shall be named as present_0, got: ",
|
||||
subgraph_outputs[1]->Name());
|
||||
|
||||
// Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
|
||||
ORT_RETURN_IF(logits_shape->dim_size() != 3, "subgraph logits output is expected to have 3 dimension, got ",
|
||||
logits_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
|
||||
|
||||
// Save parameters related to the subgraph.
|
||||
num_heads = static_cast<int>(past_shape->dim(2).dim_value());
|
||||
head_size = static_cast<int>(past_shape->dim(4).dim_value());
|
||||
vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
|
||||
num_layers = static_cast<int>(subgraph_outputs.size()) - 1;
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
|
||||
"subgraph input 0 (input_ids) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
|
||||
"subgraph input 1 (position_ids) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32,
|
||||
"subgraph input 2 (attention_mask) shall have int32 type");
|
||||
|
||||
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
ORT_RETURN_IF(output_type != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT && output_type != ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16,
|
||||
"subgraph output 0 (logits) shall be float or float16 data type");
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[3]->TypeAsProto()->tensor_type().elem_type() != output_type,
|
||||
"subgraph input 3 (past_0) shall shall have same data type of logits output");
|
||||
ORT_RETURN_IF(subgraph_outputs[1]->TypeAsProto()->tensor_type().elem_type() != output_type,
|
||||
"subgraph output 1 (present_0) shall shall have same data type of logits output");
|
||||
|
||||
is_output_float16_ = (output_type == ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GptSubgraph::Setup(const SessionState& session_state,
|
||||
const SessionState& subgraph_session_state) {
|
||||
session_state_ = &session_state;
|
||||
subgraph_session_state_ = &subgraph_session_state;
|
||||
|
||||
std::vector<std::string> feed_names;
|
||||
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
|
||||
// Currently, input_ids is in CPU even for CUDA operator, so we have to use logits location as default.
|
||||
const OrtMemoryInfo& default_location = utils::FindMemoryInfoForValue(subgraph_session_state, "logits");
|
||||
|
||||
// position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
|
||||
// as we skip them when we call FindDevicesForValues, and default them to be in the same device as input_ids
|
||||
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
|
||||
|
||||
for (auto& entry : node.ImplicitInputDefs()) {
|
||||
feed_names.push_back(entry->Name());
|
||||
}
|
||||
|
||||
std::vector<OrtDevice> feed_locations;
|
||||
feed_locations.resize(feed_names.size());
|
||||
|
||||
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
|
||||
if (i >= subgraph_input_names.size()) { // implicit inputs
|
||||
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
|
||||
feed_locations[i] = location.device;
|
||||
} else {
|
||||
feed_locations[i] = default_location.device;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<FeedsFetchesManager> ffm;
|
||||
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
|
||||
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
|
||||
|
||||
// setup the locations where we want the subgraph output to end up on
|
||||
std::vector<const OrtMemoryInfo*> fetch_locations;
|
||||
fetch_locations.reserve(num_subgraph_outputs);
|
||||
|
||||
// past state need to be where we can feed them in to the next iteration, so set the fetch location to match the feed location.
|
||||
for (int i = 0; i < num_subgraph_outputs; ++i) {
|
||||
fetch_locations.push_back(&default_location);
|
||||
}
|
||||
|
||||
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
|
||||
|
||||
feeds_fetches_manager_ = std::move(ffm);
|
||||
|
||||
// Check subgraph only need once so put in Setup function.
|
||||
auto& inputs = subgraph.GetInputs();
|
||||
auto& outputs = subgraph.GetOutputs();
|
||||
ORT_RETURN_IF_ERROR(Validate(inputs, outputs));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const IExecutionProvider* GptSubgraph::GetProvider() const {
|
||||
const ExecutionProviders& providers = session_state_->GetExecutionProviders();
|
||||
const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
|
||||
const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
|
||||
const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
|
||||
return provider;
|
||||
}
|
||||
|
||||
Status GptSubgraph::CreateInitialFeeds(
|
||||
const Tensor& input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
IAllocatorUniquePtr<char>& buffer) {
|
||||
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
|
||||
|
||||
const IExecutionProvider* provider = GetProvider();
|
||||
|
||||
const TensorShape& input_ids_shape = input_ids.Shape();
|
||||
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
|
||||
const int64_t& batch_size = input_ids_shape[0];
|
||||
|
||||
// Subgraph inputs:
|
||||
// input_ids: shape (B, S) wher B is batch size, and S is sequence length
|
||||
// position_ids: shape (B, S)
|
||||
// attention_mask: shape (B, P+S), where past_sequence_length (P) is 0
|
||||
// After expansion, their shapes will become (B, M*S), where M is num_beams.
|
||||
|
||||
// Allocate subgraph inputs to be same device as input_ids
|
||||
AllocatorPtr cpu_alloactor = session_state_->GetAllocator(input_ids.Location());
|
||||
|
||||
// Store allocator, which will be used in remaining feeds
|
||||
auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault);
|
||||
allocator_ = default_allocator;
|
||||
|
||||
// Initialize empty past state
|
||||
auto past_type = IsOutputFloat16() ? DataTypeImpl::GetType<MLFloat16>() : DataTypeImpl::GetType<float>();
|
||||
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size};
|
||||
TensorShape past_shape(&past_state_dims[0], 5);
|
||||
OrtValue empty_past;
|
||||
Tensor::InitOrtValue(past_type, past_shape, default_allocator, empty_past);
|
||||
|
||||
// The ordering is the same as used in Setup
|
||||
feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
|
||||
OrtValue expanded_position_ids;
|
||||
OrtValue expanded_attention_mask;
|
||||
ORT_RETURN_IF_ERROR(create_inputs_func(&input_ids, num_beams, pad_token_id, sequence_lengths, cpu_alloactor, expanded_input_ids, expanded_position_ids, expanded_attention_mask));
|
||||
|
||||
ORT_RETURN_IF_ERROR(add_to_feeds_func(provider, expanded_input_ids, expanded_position_ids, expanded_attention_mask, feeds, buffer));
|
||||
|
||||
// The remaing inputs are past state.
|
||||
for (int i = 3; i < num_subgraph_inputs; ++i) {
|
||||
feeds.push_back(empty_past);
|
||||
}
|
||||
|
||||
// pass in implicit inputs
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
feeds.push_back(*entry);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,7 +1,11 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <memory>
|
||||
#include <assert.h>
|
||||
#include "logits_processor.h"
|
||||
#include "dump_tensor.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "contrib_ops/cpu/transformers/logits_processor.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -102,7 +106,8 @@ void NoRepeatNGramLogitsProcessor<T>::Process(const ISequences* sequences,
|
|||
std::unordered_set<int32_t> blocked_word_ids;
|
||||
for (int j = 0; j <= static_cast<int>(sequence.length()) - ngram_size_; j++) {
|
||||
// Here we use naive algorithm for matching. The complexity is O(batch_beam_size * ngram_size * sequence_length)
|
||||
// TODO: build N-Gram index (hash table with prefix of length NGram - 1 as key, and list of last word of NGram as value) for fast matching.
|
||||
// TODO(tianleiwu): build N-Gram index (hash table with prefix of length NGram - 1 as key,
|
||||
// and list of last word of NGram as value) for fast matching.
|
||||
if (ngram_size_ == 1 || prefix == sequence.subspan(j, prefix_length)) {
|
||||
blocked_word_ids.insert(sequence[static_cast<gsl::index>(j) + prefix_length]);
|
||||
}
|
||||
|
|
@ -119,7 +124,8 @@ void NoRepeatNGramLogitsProcessor<T>::Process(const ISequences* sequences,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
VocabMaskLogitsProcessor<T>::VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask) : vocab_mask_(vocab_mask) {
|
||||
VocabMaskLogitsProcessor<T>::VocabMaskLogitsProcessor(const gsl::span<const int32_t>& vocab_mask)
|
||||
: vocab_mask_(vocab_mask) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -145,8 +151,10 @@ void VocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
PrefixVocabMaskLogitsProcessor<T>::PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask, int batch_size)
|
||||
: prefix_vocab_mask_(prefix_vocab_mask), batch_size_(batch_size) {
|
||||
PrefixVocabMaskLogitsProcessor<T>::PrefixVocabMaskLogitsProcessor(const gsl::span<const int32_t>& prefix_vocab_mask,
|
||||
int batch_size)
|
||||
: prefix_vocab_mask_(prefix_vocab_mask),
|
||||
batch_size_(batch_size) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -159,7 +167,7 @@ void PrefixVocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
|||
assert(num_beams * batch_size_ == next_token_scores.batch_beam_size);
|
||||
|
||||
// Process prefix vocabulary mask and set tokens with mask value 0 to -inf.
|
||||
// prefix_vocab_mask shape (batch_szie, vocab_size).
|
||||
// prefix_vocab_mask shape (batch_size, vocab_size).
|
||||
T* p = next_token_scores.scores.data();
|
||||
for (int i = 0; i < batch_size_; i++) {
|
||||
size_t prefix_vocab_mask_offset = SafeInt<size_t>(i) * next_token_scores.vocab_size;
|
||||
|
|
@ -181,7 +189,8 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
|
|||
processor_list_.clear();
|
||||
|
||||
if (parameters.repetition_penalty != 1.0f) { // 1.0 means no penalty
|
||||
repetition_penalty_processor_ = std::make_unique<RepetitionPenaltyLogitsProcessor<float>>(parameters.repetition_penalty);
|
||||
repetition_penalty_processor_ = std::make_unique<RepetitionPenaltyLogitsProcessor<float>>(
|
||||
parameters.repetition_penalty);
|
||||
processor_list_.push_back(repetition_penalty_processor_.get());
|
||||
}
|
||||
|
||||
|
|
@ -196,12 +205,14 @@ void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
|
|||
}
|
||||
|
||||
if (!parameters.prefix_vocab_mask.empty()) {
|
||||
prefix_vocab_mask_processor_ = std::make_unique<PrefixVocabMaskLogitsProcessor<float>>(parameters.prefix_vocab_mask, parameters.batch_size);
|
||||
prefix_vocab_mask_processor_ = std::make_unique<PrefixVocabMaskLogitsProcessor<float>>(parameters.prefix_vocab_mask,
|
||||
parameters.batch_size);
|
||||
processor_list_.push_back(prefix_vocab_mask_processor_.get());
|
||||
}
|
||||
|
||||
if (parameters.min_length > 0) {
|
||||
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<float>>(parameters.min_length, parameters.eos_token_id);
|
||||
min_length_processor_ = std::make_unique<MinLengthLogitsProcessor<float>>(parameters.min_length,
|
||||
parameters.eos_token_id);
|
||||
processor_list_.push_back(min_length_processor_.get());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,12 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "sequences.h"
|
||||
#include "beam_search_parameters.h"
|
||||
|
||||
#include "core/common/inlined_containers.h"
|
||||
#include "beam_search_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,8 @@
|
|||
#include "sequences.h"
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/common/safeint.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -20,8 +23,10 @@ void Sequences::Init(gsl::span<int32_t> buffer, int batch_beam_size, int sequenc
|
|||
}
|
||||
|
||||
gsl::span<const int32_t> Sequences::GetSequence(int beam_index) const {
|
||||
gsl::span<const int32_t> buffer(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
|
||||
gsl::span<const int32_t> sequence = buffer.subspan(SafeInt<size_t>(beam_index) * max_length_, static_cast<gsl::index>(current_length_));
|
||||
gsl::span<const int32_t> buffer(sequences[current_sequences_buffer].data(),
|
||||
sequences[current_sequences_buffer].size());
|
||||
gsl::span<const int32_t> sequence = buffer.subspan(SafeInt<size_t>(beam_index) * max_length_,
|
||||
static_cast<gsl::index>(current_length_));
|
||||
return sequence;
|
||||
}
|
||||
|
||||
|
|
@ -42,13 +47,16 @@ void Sequences::PrintSequences(const IConsoleDumper* dumper) const {
|
|||
void Sequences::AppendNextTokenToSequences(
|
||||
gsl::span<int32_t>& beam_indices,
|
||||
gsl::span<int32_t>& beam_next_tokens) {
|
||||
gsl::span<const int32_t> input(sequences[current_sequences_buffer].data(), sequences[current_sequences_buffer].size());
|
||||
gsl::span<const int32_t> input(sequences[current_sequences_buffer].data(),
|
||||
sequences[current_sequences_buffer].size());
|
||||
gsl::span<int32_t> output = sequences[1 - current_sequences_buffer];
|
||||
|
||||
for (int i = 0; i < batch_beam_size_; i++) {
|
||||
int beam_index = beam_indices[i];
|
||||
gsl::span<const int32_t> source = input.subspan(SafeInt<size_t>(beam_index) * max_length_, static_cast<gsl::index>(current_length_));
|
||||
gsl::span<int32_t> target = output.subspan(SafeInt<size_t>(i) * max_length_, static_cast<gsl::index>(current_length_));
|
||||
gsl::span<const int32_t> source = input.subspan(SafeInt<size_t>(beam_index) * max_length_,
|
||||
static_cast<gsl::index>(current_length_));
|
||||
gsl::span<int32_t> target = output.subspan(SafeInt<size_t>(i) * max_length_,
|
||||
static_cast<gsl::index>(current_length_));
|
||||
gsl::copy(source, target);
|
||||
}
|
||||
|
||||
|
|
@ -65,4 +73,4 @@ void Sequences::AppendNextTokenToSequences(
|
|||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "gsl/gsl"
|
||||
#include "beam_search_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -47,4 +50,4 @@ class Sequences : public ISequences {
|
|||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
163
onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
Normal file
163
onnxruntime/contrib_ops/cpu/transformers/subgraph_base.cc
Normal file
|
|
@ -0,0 +1,163 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <utility>
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "gsl/gsl"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_base.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
Subgraph::Subgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in)
|
||||
: node(node_in),
|
||||
attribute(attribute_name),
|
||||
subgraph(subgraph_in),
|
||||
num_heads(0),
|
||||
head_size(0),
|
||||
vocab_size(0),
|
||||
num_layers(0),
|
||||
allocator_(nullptr),
|
||||
is_output_float16_(false) {
|
||||
num_implicit_inputs = static_cast<int>(node.ImplicitInputDefs().size());
|
||||
|
||||
auto& subgraph_inputs = subgraph.GetInputs();
|
||||
auto& subgraph_outputs = subgraph.GetOutputs();
|
||||
|
||||
// inputs: input_ids, position_ids, attention_mask, past_0, past_1, ...
|
||||
// outputs: logits, present_0, present_1, ...
|
||||
num_subgraph_inputs = static_cast<int>(subgraph_inputs.size());
|
||||
num_subgraph_outputs = static_cast<int>(subgraph_outputs.size());
|
||||
|
||||
// CheckSubgraph will verify inputs and outputs later.
|
||||
subgraph_input_names.reserve(num_subgraph_inputs);
|
||||
for (int i = 0; i < num_subgraph_inputs; ++i) {
|
||||
subgraph_input_names.push_back(subgraph_inputs[i]->Name());
|
||||
}
|
||||
|
||||
subgraph_output_names.reserve(num_subgraph_outputs);
|
||||
for (int i = 0; i < num_subgraph_outputs; ++i) {
|
||||
subgraph_output_names.push_back(subgraph_outputs[i]->Name());
|
||||
}
|
||||
}
|
||||
|
||||
Status Subgraph::Setup(const SessionState& session_state,
|
||||
const SessionState& subgraph_session_state) {
|
||||
session_state_ = &session_state;
|
||||
subgraph_session_state_ = &subgraph_session_state;
|
||||
|
||||
std::vector<std::string> feed_names;
|
||||
feed_names.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
|
||||
// Use the first output (logits) to find device location.
|
||||
const OrtMemoryInfo& default_location = utils::FindMemoryInfoForValue(subgraph_session_state,
|
||||
subgraph_output_names[0]);
|
||||
|
||||
// The position_ids, attention_mask, past_0, ... are created by this operator so the name doesn't matter.
|
||||
feed_names.insert(feed_names.end(), subgraph_input_names.begin(), subgraph_input_names.end());
|
||||
|
||||
for (auto& entry : node.ImplicitInputDefs()) {
|
||||
feed_names.push_back(entry->Name());
|
||||
}
|
||||
|
||||
std::vector<OrtDevice> feed_locations;
|
||||
feed_locations.resize(feed_names.size());
|
||||
|
||||
for (size_t i = 0, end = feed_names.size(); i < end; ++i) {
|
||||
if (i >= subgraph_input_names.size()) { // Implicit inputs
|
||||
const auto& location = utils::FindMemoryInfoForValue(session_state, feed_names[i]);
|
||||
feed_locations[i] = location.device;
|
||||
} else {
|
||||
feed_locations[i] = default_location.device;
|
||||
}
|
||||
}
|
||||
|
||||
std::unique_ptr<FeedsFetchesManager> ffm;
|
||||
ORT_RETURN_IF_ERROR(FeedsFetchesManager::Create(feed_names, subgraph_output_names,
|
||||
subgraph_session_state.GetOrtValueNameIdxMap(), ffm));
|
||||
ORT_RETURN_IF_ERROR(utils::InitializeFeedFetchCopyInfo(subgraph_session_state, *ffm));
|
||||
|
||||
// Setup the locations where we want the subgraph output to end up on
|
||||
std::vector<const OrtMemoryInfo*> fetch_locations;
|
||||
fetch_locations.reserve(num_subgraph_outputs);
|
||||
|
||||
// Past state need to be where we can feed them in to the next iteration, so set the location to match the feed.
|
||||
for (int i = 0; i < num_subgraph_outputs; ++i) {
|
||||
fetch_locations.push_back(&default_location);
|
||||
}
|
||||
|
||||
utils::FinalizeFeedFetchCopyInfo(*ffm, feed_locations, fetch_locations);
|
||||
|
||||
feeds_fetches_manager_ = std::move(ffm);
|
||||
|
||||
// Check subgraph only need once so put in Setup function.
|
||||
auto& inputs = subgraph.GetInputs();
|
||||
auto& outputs = subgraph.GetOutputs();
|
||||
ORT_RETURN_IF_ERROR(Validate(inputs, outputs));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
const IExecutionProvider* Subgraph::GetProvider() const {
|
||||
const ExecutionProviders& providers = session_state_->GetExecutionProviders();
|
||||
const IExecutionProvider* cpu_provider = providers.Get(onnxruntime::kCpuExecutionProvider);
|
||||
const IExecutionProvider* cuda_provider = providers.Get(onnxruntime::kCudaExecutionProvider);
|
||||
const IExecutionProvider* provider = cuda_provider ? cuda_provider : cpu_provider;
|
||||
return provider;
|
||||
}
|
||||
|
||||
Status Subgraph::GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape,
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape,
|
||||
bool merged_past) {
|
||||
if (merged_past) {
|
||||
// Merged past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads)
|
||||
ORT_RETURN_IF(past_shape->dim_size() != 5,
|
||||
"subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size());
|
||||
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
|
||||
"subgraph past state dimension 0 shall have length of 2");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for number of heads");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
|
||||
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
|
||||
this->num_heads = static_cast<int>(past_shape->dim(2).dim_value());
|
||||
this->head_size = static_cast<int>(past_shape->dim(4).dim_value());
|
||||
} else {
|
||||
// Past state shape is like (batch_size, num_heads, past_seq_len, hidden_size/num_heads).
|
||||
ORT_RETURN_IF(past_shape->dim_size() != 4,
|
||||
"subgraph output present_key_self_0 is expected to have 4 dimension, got ", past_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(1).has_dim_value() || past_shape->dim(1).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for number of heads");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(3).has_dim_value() || past_shape->dim(3).dim_value() <= 0,
|
||||
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
|
||||
this->num_heads = static_cast<int>(past_shape->dim(1).dim_value());
|
||||
this->head_size = static_cast<int>(past_shape->dim(3).dim_value());
|
||||
}
|
||||
|
||||
// Logits shape is like (batch_size, seq_len, vocabulary_size)
|
||||
ORT_RETURN_IF(logits_shape->dim_size() != 3,
|
||||
"subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
|
||||
|
||||
this->vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -2,6 +2,9 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include "gsl/gsl"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/feeds_fetches_manager.h"
|
||||
|
|
@ -15,21 +18,22 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// A class for GPT-2 subgraph inputs and outputs preparation.
|
||||
struct GptSubgraph {
|
||||
GptSubgraph(
|
||||
class Subgraph {
|
||||
public:
|
||||
Subgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in);
|
||||
virtual ~Subgraph() {}
|
||||
|
||||
const onnxruntime::Node& node; // node that contains the subgraph
|
||||
const std::string& attribute; // attribute of th node that contains the subgraph. Not used yet.
|
||||
const GraphViewer& subgraph; // the subgraph
|
||||
const onnxruntime::Node& node; // Node that contains the subgraph
|
||||
const std::string& attribute; // Attribute of th node that contains the subgraph. Not used yet.
|
||||
const GraphViewer& subgraph; // The subgraph
|
||||
|
||||
int num_implicit_inputs;
|
||||
|
||||
int num_subgraph_inputs; // same as subgraph_input_names.size(), keep it for convenience.
|
||||
int num_subgraph_outputs; // same as subgraph_output_names.size()
|
||||
int num_subgraph_inputs; // Same as subgraph_input_names.size(), keep it for convenience.
|
||||
int num_subgraph_outputs; // Same as subgraph_output_names.size()
|
||||
|
||||
std::vector<std::string> subgraph_input_names;
|
||||
std::vector<std::string> subgraph_output_names;
|
||||
|
|
@ -40,32 +44,23 @@ struct GptSubgraph {
|
|||
int vocab_size;
|
||||
int num_layers;
|
||||
|
||||
// Setup exectuion
|
||||
// Setup execution
|
||||
Status Setup(const SessionState& session_state,
|
||||
const SessionState& subgraph_session_state);
|
||||
|
||||
// Create inputs for first inference of subgraph.
|
||||
Status CreateInitialFeeds(
|
||||
const Tensor& input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
const BeamSearchDeviceHelper::CreateInputsFunc& create_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
IAllocatorUniquePtr<char>& buffer);
|
||||
|
||||
FeedsFetchesManager* GetFeedsFetchesManager() const { return feeds_fetches_manager_.get(); }
|
||||
|
||||
const IExecutionProvider* GetProvider() const;
|
||||
|
||||
bool IsOutputFloat16() const { return is_output_float16_; }
|
||||
|
||||
virtual Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) = 0;
|
||||
|
||||
protected:
|
||||
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs);
|
||||
Status GetParameters(const ONNX_NAMESPACE::TensorShapeProto* past_shape,
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape,
|
||||
bool merged_past);
|
||||
|
||||
AllocatorPtr allocator_;
|
||||
const SessionState* session_state_;
|
||||
167
onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc
Normal file
167
onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.cc
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "gsl/gsl"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
Status GptSubgraph::CreateInitialFeeds(
|
||||
const Tensor& input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
const BeamSearchDeviceHelper::CreateGptInputsFunc& create_gpt_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
IAllocatorUniquePtr<char>& buffer) {
|
||||
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
|
||||
|
||||
const IExecutionProvider* provider = GetProvider();
|
||||
|
||||
const TensorShape& input_ids_shape = input_ids.Shape();
|
||||
ORT_ENFORCE(input_ids_shape.NumDimensions() == 2);
|
||||
const int64_t& batch_size = input_ids_shape[0];
|
||||
|
||||
// Subgraph inputs:
|
||||
// input_ids: shape (B, S) where B is batch size, and S is sequence length
|
||||
// position_ids: shape (B, S)
|
||||
// attention_mask: shape (B, P+S), where past_sequence_length (P) is 0
|
||||
// After expansion, their shapes will become (B, M*S), where M is num_beams.
|
||||
|
||||
// Allocate subgraph inputs to be same device as input_ids
|
||||
AllocatorPtr cpu_allocator = session_state_->GetAllocator(input_ids.Location());
|
||||
|
||||
// Store allocator, which will be used in remaining feeds
|
||||
auto default_allocator = provider->GetAllocator(0, OrtMemTypeDefault);
|
||||
allocator_ = default_allocator;
|
||||
|
||||
// Initialize empty past state
|
||||
auto past_type = IsOutputFloat16() ? DataTypeImpl::GetType<MLFloat16>() : DataTypeImpl::GetType<float>();
|
||||
int64_t past_state_dims[] = {2, batch_size * num_beams, num_heads, 0, head_size};
|
||||
TensorShape past_shape(&past_state_dims[0], 5);
|
||||
OrtValue empty_past;
|
||||
Tensor::InitOrtValue(past_type, past_shape, default_allocator, empty_past);
|
||||
|
||||
// The ordering is the same as used in Setup
|
||||
feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
|
||||
OrtValue expanded_position_ids;
|
||||
OrtValue expanded_attention_mask;
|
||||
ORT_RETURN_IF_ERROR(create_gpt_inputs_func(&input_ids,
|
||||
num_beams,
|
||||
pad_token_id,
|
||||
sequence_lengths,
|
||||
cpu_allocator,
|
||||
expanded_input_ids,
|
||||
expanded_position_ids,
|
||||
expanded_attention_mask));
|
||||
|
||||
ORT_RETURN_IF_ERROR(add_to_feeds_func(provider,
|
||||
{expanded_input_ids, expanded_position_ids, expanded_attention_mask},
|
||||
feeds,
|
||||
buffer));
|
||||
|
||||
// The remaining inputs are past state.
|
||||
for (int i = kFirstPastInputIndex; i < num_subgraph_inputs; ++i) {
|
||||
feeds.push_back(empty_past);
|
||||
}
|
||||
|
||||
// Pass in implicit inputs
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
feeds.push_back(*entry);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status GptSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) {
|
||||
ORT_RETURN_IF(num_subgraph_outputs <= kFirstPresentOutputIndex,
|
||||
"Invalid GPT-2 subgraph: number of outputs shall be larger than 1 (Need past state in outputs).");
|
||||
|
||||
ORT_RETURN_IF(num_subgraph_inputs != num_subgraph_outputs + 2,
|
||||
"Invalid GPT-2 subgraph: number of inputs shall be number of outputs plus 2");
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
|
||||
"subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "position_ids",
|
||||
"subgraph input 1 shall be named as position_ids, got: ", subgraph_inputs[1]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "attention_mask",
|
||||
"subgraph input 2 shall be named as attention_mask, got: ", subgraph_inputs[2]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[3]->Name() != "past_0",
|
||||
"subgraph input 3 shall be named as past_0, got: ", subgraph_inputs[3]->Name());
|
||||
|
||||
// Past state shape is like (2, batch_size, num_heads, past_seq_len, hidden_size/num_heads).
|
||||
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_inputs[3]->Shape();
|
||||
ORT_RETURN_IF(past_shape->dim_size() != 5,
|
||||
"subgraph past state is expected to have 5 dimension, got ", past_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(0).has_dim_value() || past_shape->dim(0).dim_value() != 2,
|
||||
"subgraph past state dimension 0 shall have length of 2");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(2).has_dim_value() || past_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for number of heads");
|
||||
|
||||
ORT_RETURN_IF(!past_shape->dim(4).has_dim_value() || past_shape->dim(4).dim_value() <= 0,
|
||||
"subgraph past state dimension 4 shall have a positive value for hidden size per head");
|
||||
|
||||
// check subgraph outputs
|
||||
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
|
||||
"subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
|
||||
|
||||
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "present_0",
|
||||
"subgraph input 1 shall be named as present_0, got: ", subgraph_outputs[1]->Name());
|
||||
|
||||
// Logits shape is like (batch_size, seq_len, 50257). Here 50257 is the vocabulary size.
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
|
||||
ORT_RETURN_IF(logits_shape->dim_size() != 3,
|
||||
"subgraph logits output is expected to have 3 dimension, got ", logits_shape->dim_size());
|
||||
|
||||
ORT_RETURN_IF(!logits_shape->dim(2).has_dim_value() || logits_shape->dim(2).dim_value() <= 0,
|
||||
"subgraph past state dimension 2 shall have a positive value for vocabulary size");
|
||||
|
||||
// Save parameters related to the subgraph.
|
||||
num_heads = static_cast<int>(past_shape->dim(2).dim_value());
|
||||
head_size = static_cast<int>(past_shape->dim(4).dim_value());
|
||||
vocab_size = static_cast<int>(logits_shape->dim(2).dim_value());
|
||||
num_layers = static_cast<int>(subgraph_outputs.size()) - 1;
|
||||
|
||||
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
|
||||
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
|
||||
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"subgraph input 0 (input_ids) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"subgraph input 1 (position_ids) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"subgraph input 2 (attention_mask) shall have int32 type");
|
||||
|
||||
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
|
||||
"subgraph output 0 (logits) shall be float or float16 data type");
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[kFirstPastInputIndex]->TypeAsProto()->tensor_type().elem_type() != output_type,
|
||||
"subgraph input 3 (past_0) shall shall have same data type of logits output");
|
||||
ORT_RETURN_IF(subgraph_outputs[kFirstPresentOutputIndex]->TypeAsProto()->tensor_type().elem_type() != output_type,
|
||||
"subgraph output 1 (present_0) shall shall have same data type of logits output");
|
||||
|
||||
is_output_float16_ = (output_type == float16_type);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
42
onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h
Normal file
42
onnxruntime/contrib_ops/cpu/transformers/subgraph_gpt.h
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/transformers/subgraph_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// A class for GPT-2 subgraph inputs and outputs preparation.
|
||||
class GptSubgraph : public Subgraph {
|
||||
public:
|
||||
GptSubgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {}
|
||||
|
||||
// Create inputs for first inference of subgraph.
|
||||
Status CreateInitialFeeds(
|
||||
const Tensor& input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
const BeamSearchDeviceHelper::CreateGptInputsFunc& create_gpt_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
IAllocatorUniquePtr<char>& buffer);
|
||||
|
||||
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) override;
|
||||
|
||||
constexpr static int kFirstPastInputIndex = 3;
|
||||
constexpr static int kFirstPresentOutputIndex = 1;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
154
onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Normal file
154
onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_decoder.cc
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "gsl/gsl"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_device_helper.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
/* T5 Decoder Subgraph.
|
||||
|
||||
Inputs:
|
||||
input_ids: int32 (B, 1)
|
||||
encoder_attention_mask: int32 (B, encode_sequence_length)
|
||||
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
|
||||
|
||||
past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
|
||||
past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
|
||||
... (for each self attention layer)
|
||||
|
||||
past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
... (for each cross attention layer)
|
||||
|
||||
Outputs:
|
||||
logits: (B, 1, vocab_size)
|
||||
|
||||
present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
|
||||
present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
|
||||
... (for each self attention layer)
|
||||
|
||||
Note:
|
||||
B = batch_size * num_beams
|
||||
Data type of input or output is float or float16 if not specified.
|
||||
*/
|
||||
|
||||
Status T5DecoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) {
|
||||
ORT_RETURN_IF(num_subgraph_inputs < 7 || (num_subgraph_inputs - kFirstPastInputIndex) % 4 != 0,
|
||||
"number of outputs expected to be 3 + 4 * layers, got:", num_subgraph_inputs);
|
||||
ORT_RETURN_IF(num_subgraph_outputs < 3 || (num_subgraph_outputs - kFirstPresentOutputIndex) % 2 != 0,
|
||||
"number of outputs expected to be 1 + 2 * layers, got:", num_subgraph_outputs);
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "input_ids",
|
||||
"decoder subgraph input 0 shall be named as input_ids, got: ", subgraph_inputs[0]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
|
||||
"decoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "encoder_hidden_states",
|
||||
"decoder subgraph input 2 shall be named as encoder_hidden_states, got: ", subgraph_inputs[2]->Name());
|
||||
|
||||
// check subgraph outputs
|
||||
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
|
||||
"decoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
|
||||
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[kFirstPresentOutputIndex]->Shape();
|
||||
|
||||
// Save parameters related to the subgraph.
|
||||
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
|
||||
num_layers = (static_cast<int>(subgraph_outputs.size()) - kFirstPresentOutputIndex) / 2;
|
||||
|
||||
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
|
||||
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
|
||||
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"decoder subgraph input 0 (input_ids) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"decoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
|
||||
|
||||
auto float_type = subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type();
|
||||
ORT_RETURN_IF(float_type != float32_type && float_type != float16_type,
|
||||
"decoder subgraph input 2 (encoder_hidden_states) shall have float or float16 type");
|
||||
|
||||
for (int i = kFirstPastInputIndex; i < num_subgraph_inputs; i++) {
|
||||
ORT_RETURN_IF(subgraph_inputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
|
||||
"decoder subgraph past inputs shall have same data type as that of encoder_hidden_states");
|
||||
}
|
||||
|
||||
for (int i = 0; i < num_subgraph_outputs; i++) {
|
||||
ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != float_type,
|
||||
"decoder subgraph output shall have same data type as that of encoder_hidden_states");
|
||||
}
|
||||
|
||||
is_output_float16_ = (subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type() == float16_type);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create inputs for decoder from the following data sources:
|
||||
// encoder feeds: encoder_input_ids, encoder_attention_mask, decoder_input_ids (with start tokens)
|
||||
// encoder fetches: logits,
|
||||
// encoder_hidden_states,
|
||||
// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
|
||||
// decoder_feeds: input_ids,
|
||||
// encoder_attention_mask,
|
||||
// encoder_hidden_states,
|
||||
// present_key_self_0, present_value_self_0, ..., present_key_cross_0, present_value_cross_0, ...
|
||||
Status T5DecoderSubgraph::CreateInitialFeeds(
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
const std::vector<OrtValue>& encoder_feeds,
|
||||
const std::vector<OrtValue>& encoder_fetches,
|
||||
std::vector<OrtValue>& decoder_feeds,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
|
||||
void* stream) {
|
||||
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
|
||||
|
||||
// Allocate subgraph inputs from same device as inputs of encoder subgraph.
|
||||
AllocatorPtr allocator = session_state_->GetAllocator(encoder_feeds[0].Get<Tensor>().Location());
|
||||
|
||||
// Copy beam next tokens in CPU to input_ids in provider device (CPU for CPU EP, or GPU for CUDA EP).
|
||||
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
|
||||
int64_t dims[] = {batch_beam_size, 1};
|
||||
TensorShape input_ids_shape(&dims[0], 2);
|
||||
OrtValue input_ids;
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), input_ids_shape, allocator, input_ids);
|
||||
ORT_RETURN_IF_ERROR(device_copy_int32_func(
|
||||
input_ids.GetMutable<Tensor>()->MutableDataAsSpan<int32_t>(),
|
||||
beam_next_tokens,
|
||||
stream,
|
||||
DeviceCopyDirection::hostToDevice));
|
||||
|
||||
// The ordering is the same as used in Setup.
|
||||
decoder_feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
decoder_feeds.push_back(input_ids);
|
||||
|
||||
// The encoder_attention_mask is copied from the second input of encoder.
|
||||
decoder_feeds.push_back(encoder_feeds[1]);
|
||||
|
||||
// The encoder_hidden_states and past states are copied from the second output of encoder.
|
||||
for (size_t j = 1; j < encoder_fetches.size(); j++) {
|
||||
decoder_feeds.push_back(encoder_fetches[j]);
|
||||
}
|
||||
|
||||
// Pass through implicit inputs.
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
decoder_feeds.push_back(*entry);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/transformers/subgraph_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// A class for T5 decoder subgraph inputs and outputs preparation.
|
||||
class T5DecoderSubgraph : public Subgraph {
|
||||
public:
|
||||
T5DecoderSubgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {}
|
||||
|
||||
// Create inputs for first inference of decoder subgraph.
|
||||
Status CreateInitialFeeds(
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
const std::vector<OrtValue>& encoder_feeds,
|
||||
const std::vector<OrtValue>& encoder_fetches,
|
||||
std::vector<OrtValue>& decoder_feeds,
|
||||
const BeamSearchDeviceHelper::DeviceCopyFunc<int32_t>& device_copy_int32_func,
|
||||
void* stream);
|
||||
|
||||
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) override;
|
||||
|
||||
constexpr static int kFirstPastInputIndex = 3;
|
||||
constexpr static int kFirstPresentOutputIndex = 1;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
145
onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
Normal file
145
onnxruntime/contrib_ops/cpu/transformers/subgraph_t5_encoder.cc
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/framework/framework_common.h"
|
||||
#include "core/framework/session_state.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "gsl/gsl"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_encoder.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
/* T5 Encoder Subgraph (It also contains decoder initialization where decoder_input_ids are filled with start token ID).
|
||||
|
||||
Inputs:
|
||||
encoder_input_ids: int32 (B, encode_sequence_length)
|
||||
encoder_attention_mask: int32 (B, encode_sequence_length)
|
||||
decoder_input_ids: int32 (B, 1)
|
||||
|
||||
Outputs:
|
||||
logits: (B, 1, vocab_size)
|
||||
encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
|
||||
|
||||
present_key_self_0: (B, num_heads, 1, head_size)
|
||||
present_value_self_0: (B, num_heads, 1, head_size)
|
||||
... (for each self attention layer)
|
||||
|
||||
present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
... (for each cross attention layer)
|
||||
|
||||
Note:
|
||||
Here, B = batch_size * num_beams since we expand the inputs.
|
||||
Ideally, we could use B=batch_size and expand the outputs with a factor of num_beams.
|
||||
Data type of input or output is float or float16 if not specified.
|
||||
*/
|
||||
|
||||
Status T5EncoderSubgraph::Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) {
|
||||
ORT_RETURN_IF(num_subgraph_inputs != 3, "expect 3 inputs, got:", num_subgraph_inputs);
|
||||
|
||||
ORT_RETURN_IF(num_subgraph_outputs < 6, "expect >=6 outputs, got:", num_subgraph_outputs);
|
||||
ORT_RETURN_IF((static_cast<int>(subgraph_outputs.size()) - kFirstPresentOutputIndex) % 4 != 0,
|
||||
"number of outputs expected to be 2 + 4 * layers, got:", num_subgraph_outputs);
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->Name() != "encoder_input_ids",
|
||||
"encoder subgraph input 0 shall be named as encoder_input_ids, got: ", subgraph_inputs[0]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->Name() != "encoder_attention_mask",
|
||||
"encoder subgraph input 1 shall be named as encoder_attention_mask, got: ", subgraph_inputs[1]->Name());
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->Name() != "decoder_input_ids",
|
||||
"encoder subgraph input 2 shall be named as decoder_input_ids, got: ", subgraph_inputs[2]->Name());
|
||||
|
||||
ORT_RETURN_IF(subgraph_outputs[0]->Name() != "logits",
|
||||
"encoder subgraph output 0 shall be named as logits, got: ", subgraph_outputs[0]->Name());
|
||||
ORT_RETURN_IF(subgraph_outputs[1]->Name() != "encoder_hidden_states",
|
||||
"encoder subgraph output 1 shall be named encoder_hidden_states, got: ", subgraph_outputs[1]->Name());
|
||||
ORT_RETURN_IF(subgraph_outputs[2]->Name() != "present_key_self_0",
|
||||
"encoder subgraph output 2 shall be named as present_key_self_0, got: ", subgraph_outputs[2]->Name());
|
||||
ORT_RETURN_IF(subgraph_outputs[3]->Name() != "present_value_self_0",
|
||||
"encoder subgraph output 3 shall be named as present_value_self_0, got: ", subgraph_outputs[3]->Name());
|
||||
|
||||
const ONNX_NAMESPACE::TensorShapeProto* past_shape = subgraph_outputs[2]->Shape();
|
||||
const ONNX_NAMESPACE::TensorShapeProto* logits_shape = subgraph_outputs[0]->Shape();
|
||||
|
||||
// Save parameters related to the subgraph.
|
||||
ORT_RETURN_IF_ERROR(GetParameters(past_shape, logits_shape, false));
|
||||
num_layers = (static_cast<int>(subgraph_outputs.size()) - kFirstPresentOutputIndex) / 4;
|
||||
|
||||
constexpr auto int32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_INT32;
|
||||
constexpr auto float32_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT;
|
||||
constexpr auto float16_type = ONNX_NAMESPACE::TensorProto_DataType::TensorProto_DataType_FLOAT16;
|
||||
|
||||
ORT_RETURN_IF(subgraph_inputs[0]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"encoder subgraph input 0 (encoder_input_ids) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[1]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"encoder subgraph input 1 (encoder_attention_mask) shall have int32 type");
|
||||
ORT_RETURN_IF(subgraph_inputs[2]->TypeAsProto()->tensor_type().elem_type() != int32_type,
|
||||
"encoder subgraph input 2 (decoder_input_ids) shall have int32 type");
|
||||
|
||||
auto output_type = subgraph_outputs[0]->TypeAsProto()->tensor_type().elem_type();
|
||||
ORT_RETURN_IF(output_type != float32_type && output_type != float16_type,
|
||||
"encoder subgraph output 0 (logits) shall be float or float16 data type");
|
||||
|
||||
for (int i = 1; i < num_subgraph_outputs; i++) {
|
||||
ORT_RETURN_IF(subgraph_outputs[i]->TypeAsProto()->tensor_type().elem_type() != output_type,
|
||||
"encoder subgraph outputs 1, 2, ... shall have same data type");
|
||||
}
|
||||
|
||||
is_output_float16_ = (output_type == float16_type);
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Create inputs for first inference of subgraph.
|
||||
Status T5EncoderSubgraph::CreateInitialFeeds(
|
||||
const Tensor& encoder_input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
int start_token_id,
|
||||
std::vector<OrtValue>& feeds,
|
||||
const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
IAllocatorUniquePtr<char>& buffer,
|
||||
OrtValue& expanded_decoder_input_ids) {
|
||||
ORT_ENFORCE(session_state_ != nullptr, "Setup must be called before CreateInitialFeeds");
|
||||
|
||||
// The ordering is the same as used in Setup.
|
||||
feeds.reserve(static_cast<size_t>(num_subgraph_inputs) + static_cast<size_t>(num_implicit_inputs));
|
||||
|
||||
// Allocate subgraph inputs to be same device as encoder_input_ids.
|
||||
AllocatorPtr cpu_allocator = session_state_->GetAllocator(encoder_input_ids.Location());
|
||||
|
||||
// TODO(tianleiwu): expand the outputs instead of inputs to save computation.
|
||||
OrtValue expanded_encoder_input_ids;
|
||||
OrtValue expanded_encoder_attention_mask;
|
||||
ORT_RETURN_IF_ERROR(create_encoder_inputs_func(&encoder_input_ids,
|
||||
num_beams,
|
||||
pad_token_id,
|
||||
start_token_id,
|
||||
cpu_allocator,
|
||||
expanded_encoder_input_ids,
|
||||
expanded_encoder_attention_mask,
|
||||
expanded_decoder_input_ids));
|
||||
|
||||
const IExecutionProvider* provider = GetProvider();
|
||||
ORT_RETURN_IF_ERROR(add_to_feeds_func(
|
||||
provider,
|
||||
{expanded_encoder_input_ids, expanded_encoder_attention_mask, expanded_decoder_input_ids},
|
||||
feeds,
|
||||
buffer));
|
||||
|
||||
for (const auto* entry : implicit_inputs) {
|
||||
feeds.push_back(*entry);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/transformers/subgraph_base.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// A class for T5 encoder subgraph inputs and outputs preparation.
|
||||
class T5EncoderSubgraph : public Subgraph {
|
||||
public:
|
||||
T5EncoderSubgraph(
|
||||
const onnxruntime::Node& node_in,
|
||||
const std::string& attribute_name,
|
||||
const GraphViewer& subgraph_in) : Subgraph(node_in, attribute_name, subgraph_in) {}
|
||||
|
||||
// Create inputs for first inference of subgraph.
|
||||
Status CreateInitialFeeds(
|
||||
const Tensor& encoder_input_ids,
|
||||
const std::vector<const OrtValue*>& implicit_inputs,
|
||||
int num_beams,
|
||||
int pad_token_id,
|
||||
int start_token_id,
|
||||
std::vector<OrtValue>& feeds,
|
||||
const BeamSearchDeviceHelper::CreateEncoderInputsFunc& create_encoder_inputs_func,
|
||||
const BeamSearchDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
IAllocatorUniquePtr<char>& buffer,
|
||||
OrtValue& expanded_decoder_input_ids);
|
||||
|
||||
Status Validate(const std::vector<const NodeArg*>& subgraph_inputs,
|
||||
const std::vector<const NodeArg*>& subgraph_outputs) override;
|
||||
|
||||
constexpr static int kFirstPresentOutputIndex = 2;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -4,8 +4,8 @@
|
|||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#include "contrib_ops/cuda/transformers/beam_search.h"
|
||||
#include "beam_search_device_helper.h"
|
||||
#include "dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/transformers/beam_search_device_helper.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -38,16 +38,19 @@ BeamSearch::BeamSearch(const OpKernelInfo& info)
|
|||
SetComputeStream(static_cast<void*>(info.GetExecutionProvider()->GetComputeStream()));
|
||||
|
||||
SetDeviceHelpers(BeamSearchCudaDeviceHelper::AddToFeeds,
|
||||
BeamSearchCudaDeviceHelper::TopK);
|
||||
|
||||
SetDeviceHelpers(BeamSearchCudaDeviceHelper::ProcessLogits<float>,
|
||||
BeamSearchCudaDeviceHelper::InitBeamState<float>,
|
||||
BeamSearchCudaDeviceHelper::TopK,
|
||||
BeamSearchCudaDeviceHelper::DeviceCopy<float>,
|
||||
BeamSearchCudaDeviceHelper::UpdateFeeds<float>);
|
||||
BeamSearchCudaDeviceHelper::DeviceCopy<int32_t>,
|
||||
BeamSearchCudaDeviceHelper::ProcessLogits<float>,
|
||||
BeamSearchCudaDeviceHelper::ProcessLogits<MLFloat16>,
|
||||
BeamSearchCudaDeviceHelper::InitBeamState<float>,
|
||||
BeamSearchCudaDeviceHelper::InitBeamState<MLFloat16>);
|
||||
|
||||
SetDeviceHelpers(BeamSearchCudaDeviceHelper::ProcessLogits<MLFloat16>,
|
||||
BeamSearchCudaDeviceHelper::InitBeamState<MLFloat16>,
|
||||
BeamSearchCudaDeviceHelper::UpdateFeeds<MLFloat16>);
|
||||
SetDeviceHelpers_Gpt(BeamSearchCudaDeviceHelper::UpdateGptFeeds<float>,
|
||||
BeamSearchCudaDeviceHelper::UpdateGptFeeds<MLFloat16>);
|
||||
|
||||
SetDeviceHelpers_EncoderDecoder(BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<float>,
|
||||
BeamSearchCudaDeviceHelper::UpdateDecoderFeeds<MLFloat16>);
|
||||
|
||||
SetConsoleDumper(&g_cuda_dumper);
|
||||
}
|
||||
|
|
@ -71,4 +74,4 @@ Status BeamSearch::Compute(OpKernelContext* context) const {
|
|||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,16 +1,17 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/cuda/math/topk_impl.h"
|
||||
#include "core/providers/cuda/math/softmax.h"
|
||||
#include "core/providers/cuda/shared_inc/accumulation_type.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
|
||||
#include "beam_search_impl.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include "dump_cuda_tensor.h"
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
using namespace onnxruntime::contrib::cuda::transformers;
|
||||
#endif
|
||||
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
|
|
@ -52,7 +53,7 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
output_indices = Tensor::Create(DataTypeImpl::GetType<int64_t>(), output_shape, allocator);
|
||||
|
||||
if (input->IsDataType<float>()) {
|
||||
return TopKImpl<float>(nullptr, // We limit number of beams in BeamSearchParameters, so that K <= 256 and kernel is not needed
|
||||
return TopKImpl<float>(nullptr, // We limit number of beams in BeamSearchParameters, so K <= 256 and use NULL here
|
||||
reinterpret_cast<cudaStream_t>(stream),
|
||||
input->Data<float>(),
|
||||
static_cast<float*>(output_values->MutableDataRaw()),
|
||||
|
|
@ -87,35 +88,53 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
}
|
||||
|
||||
Status AddToFeeds(const IExecutionProvider* execution_provider,
|
||||
OrtValue& input_ids,
|
||||
OrtValue& position_ids,
|
||||
OrtValue& attention_mask,
|
||||
std::initializer_list<OrtValue> inputs,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer) {
|
||||
// Copy tensors to GPU, then add to feeds
|
||||
const CUDAExecutionProvider* provider = reinterpret_cast<const CUDAExecutionProvider*>(execution_provider);
|
||||
const TensorShape& shape = input_ids.Get<Tensor>().Shape();
|
||||
ORT_ENFORCE(shape.NumDimensions() == 2);
|
||||
const int64_t elements = shape[0] * shape[1];
|
||||
size_t total_bytes = 0;
|
||||
for (auto& input : inputs) {
|
||||
if (input.IsAllocated()) {
|
||||
total_bytes += input.Get<Tensor>().Shape().Size() * input.Type()->Size();
|
||||
}
|
||||
}
|
||||
|
||||
ORT_ENFORCE(total_bytes > 0);
|
||||
|
||||
AllocatorPtr pinned_allocator = provider->GetAllocator(DEFAULT_CPU_ALLOCATOR_DEVICE_ID, OrtMemTypeCPU);
|
||||
cudaStream_t stream = static_cast<cudaStream_t>(provider->GetComputeStream());
|
||||
|
||||
size_t bytes = (sizeof(int32_t) + sizeof(int32_t) + sizeof(int32_t)) * elements;
|
||||
auto pinned_buffer = IAllocator::MakeUniquePtr<void>(pinned_allocator, bytes);
|
||||
auto pinned_buffer = IAllocator::MakeUniquePtr<void>(pinned_allocator, total_bytes);
|
||||
char* pinned_data = static_cast<char*>(pinned_buffer.get());
|
||||
|
||||
// Copy tensors to one pinned memory buffer (so that we only need copy to GPU once)
|
||||
memcpy(pinned_data, input_ids.Get<Tensor>().Data<int32_t>(), sizeof(int32_t) * elements);
|
||||
memcpy(pinned_data + sizeof(int32_t) * elements, position_ids.Get<Tensor>().Data<int32_t>(), sizeof(int32_t) * elements);
|
||||
memcpy(pinned_data + 2 * sizeof(int32_t) * elements, attention_mask.Get<Tensor>().Data<int32_t>(), sizeof(int32_t) * elements);
|
||||
char* destination = pinned_data;
|
||||
for (auto& input : inputs) {
|
||||
if (input.IsAllocated()) {
|
||||
const Tensor& tensor = input.Get<Tensor>();
|
||||
const size_t bytes = input.Type()->Size() * tensor.Shape().Size();
|
||||
MLDataType dataType = tensor.DataType();
|
||||
if (dataType == DataTypeImpl::GetType<int32_t>()) {
|
||||
memcpy(destination, input.Get<Tensor>().Data<int32_t>(), bytes);
|
||||
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
|
||||
memcpy(destination, input.Get<Tensor>().Data<int64_t>(), bytes);
|
||||
} else {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED,
|
||||
"AddToFeeds: An implementation for the input type ",
|
||||
dataType, " is not supported yet");
|
||||
}
|
||||
|
||||
// Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs.
|
||||
destination += bytes;
|
||||
}
|
||||
}
|
||||
|
||||
if (!buffer) {
|
||||
buffer = provider->GetScratchBuffer<char>(bytes);
|
||||
buffer = provider->GetScratchBuffer<char>(total_bytes);
|
||||
}
|
||||
|
||||
char* gpu_data = buffer.get();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, bytes, cudaMemcpyHostToDevice, stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream));
|
||||
|
||||
// Create an event to make sure the async copy is finished before reading the data.
|
||||
onnxruntime::contrib::cuda::AutoDestoryCudaEvent new_event;
|
||||
|
|
@ -124,33 +143,32 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
|
|||
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone));
|
||||
|
||||
// TODO: allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call.
|
||||
OrtValue device_input_ids;
|
||||
OrtValue device_position_ids;
|
||||
OrtValue device_attention_mask;
|
||||
// TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call.
|
||||
const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, gpu_data, location, device_input_ids);
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, gpu_data + sizeof(int32_t) * elements, location, device_position_ids);
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int32_t>(), shape, gpu_data + 2 * sizeof(int32_t) * elements, location, device_attention_mask);
|
||||
for (auto& input : inputs) {
|
||||
if (input.IsAllocated()) {
|
||||
const Tensor& tensor = input.Get<Tensor>();
|
||||
const TensorShape& shape = tensor.Shape();
|
||||
const size_t bytes = input.Type()->Size() * shape.Size();
|
||||
MLDataType dataType = tensor.DataType();
|
||||
|
||||
feeds.push_back(device_input_ids);
|
||||
feeds.push_back(device_position_ids);
|
||||
feeds.push_back(device_attention_mask);
|
||||
OrtValue device_input;
|
||||
Tensor::InitOrtValue(dataType, shape, gpu_data, location, device_input);
|
||||
gpu_data += bytes;
|
||||
feeds.push_back(device_input);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream) {
|
||||
// TODO: we can use another stream to avoid blocking subgraph execution.
|
||||
// TODO(tianleiwu): we can use another stream to avoid blocking subgraph execution.
|
||||
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
|
||||
cudaMemsetAsync(beam_state->next_token_logits.data(), 0, beam_state->next_token_logits.size_bytes(), cuda_stream);
|
||||
cudaMemsetAsync(beam_state->next_token_scores.data(), 0, beam_state->next_token_scores.size_bytes(), cuda_stream);
|
||||
|
|
@ -162,18 +180,9 @@ void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
|||
|
||||
// copy sequence lengths to GPU
|
||||
// since next_positions is only needed to update feeds after subgraph execution, so it is fine to use Async here.
|
||||
// cudaMemsetAsync(beam_state->next_positions.data(), 0, beam_state->next_positions.size_bytes(), cuda_stream);
|
||||
cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(), cudaMemcpyHostToDevice, cuda_stream);
|
||||
|
||||
memset(cpu_state->sequences_space.data(), 0, cpu_state->sequences_space.size_bytes());
|
||||
|
||||
// Copy input_ids to sequences[0]
|
||||
gsl::span<int32_t> sequences_0 = cpu_state->sequences_space;
|
||||
int batch_beam_size = batch_size * num_beams;
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
for (int j = 0; j < sequence_length; j++) {
|
||||
sequences_0[SafeInt<gsl::index>(i) * max_length + j] = static_cast<int32_t>(input_ids_in_cpu[SafeInt<gsl::index>(i) * sequence_length + j]);
|
||||
}
|
||||
if (!beam_state->next_positions.empty()) { // next_positions is empty for T5
|
||||
cudaMemcpyAsync(beam_state->next_positions.data(), sequence_lengths.data(), sequence_lengths.size_bytes(),
|
||||
cudaMemcpyHostToDevice, cuda_stream);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -220,12 +229,13 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
// When input_length == 1, use logits directly in SoftmaxCPU below so it only need for input_length > 1.
|
||||
gsl::span<T>& next_token_logits = beam_state->next_token_logits;
|
||||
if (input_length > 1) {
|
||||
// TODO: use one kernel to replace a loop of memory copy.
|
||||
// TODO(tianleiwu): use one kernel to replace a loop of memory copy.
|
||||
const CudaT* current_logits = logits_data + (input_length - 1) * vocab_size;
|
||||
for (int i = 0; i < batch_beam_size; i++) {
|
||||
gsl::span<const T> source(reinterpret_cast<const T*>(current_logits), vocab_size);
|
||||
gsl::span<T> target = next_token_logits.subspan(i * vocab_size, vocab_size);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), sizeof(T) * vocab_size,
|
||||
cudaMemcpyDeviceToDevice, cuda_stream));
|
||||
current_logits += input_length * vocab_size;
|
||||
}
|
||||
}
|
||||
|
|
@ -253,12 +263,14 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
// Copy sequences to device only when repetition penalty or no repeat ngram is used in kernel
|
||||
BufferUniquePtr sequences_buffer;
|
||||
int current_sequence_length = sequences->GetSequenceLength();
|
||||
if (parameters->repetition_penalty != 1.0f || (parameters->no_repeat_ngram_size > 0 && current_sequence_length >= parameters->no_repeat_ngram_size)) {
|
||||
bool run_ngram = parameters->no_repeat_ngram_size > 0 && current_sequence_length >= parameters->no_repeat_ngram_size;
|
||||
if (parameters->repetition_penalty != 1.0f || run_ngram) {
|
||||
size_t bytes = SafeInt<size_t>(sizeof(int32_t)) * batch_beam_size * parameters->max_length;
|
||||
void* data = allocator->Alloc(bytes);
|
||||
BufferUniquePtr temp_buffer(data, BufferDeleter(allocator));
|
||||
sequences_buffer = std::move(temp_buffer);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes, cudaMemcpyHostToDevice, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sequences_buffer.get(), sequences->GetSequence(0).data(), bytes,
|
||||
cudaMemcpyHostToDevice, cuda_stream));
|
||||
}
|
||||
|
||||
cuda::LaunchLogitsProcessKernel<float>(
|
||||
|
|
@ -277,20 +289,25 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
cuda_stream);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
dumper->Print("next_token_scores after logits process", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
// Add beam score to next token scores. Corresponding python code is like:
|
||||
// next_token_scores = next_token_scores + beam_scores[:, None].expand_as(next_token_scores)
|
||||
cuda::LaunchAddProbsKernel(next_token_scores.data(), beam_state->beam_scores.data(), batch_size, num_beams, vocab_size, cuda_stream);
|
||||
cuda::LaunchAddProbsKernel(next_token_scores.data(), beam_state->beam_scores.data(),
|
||||
batch_size, num_beams, vocab_size, cuda_stream);
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("next_token_scores after adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
dumper->Print("next_token_scores adding beam_scores", next_token_scores.data(), batch_size, num_beams, vocab_size);
|
||||
#endif
|
||||
|
||||
if (output_scores) {
|
||||
// Append next token scores to the scores output.
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state->remaining_scores.data(), next_token_scores.data(), next_token_scores.size_bytes(), cudaMemcpyDeviceToDevice, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(beam_state->remaining_scores.data(),
|
||||
next_token_scores.data(),
|
||||
next_token_scores.size_bytes(),
|
||||
cudaMemcpyDeviceToDevice,
|
||||
cuda_stream));
|
||||
beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
|
||||
}
|
||||
|
||||
|
|
@ -301,7 +318,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
|
||||
auto element_type = DataTypeImpl::GetType<float>();
|
||||
OrtValue next_token_scores_value;
|
||||
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(), next_token_scores_value);
|
||||
Tensor::InitOrtValue(element_type, next_token_scores_shape, next_token_scores.data(), allocator->Info(),
|
||||
next_token_scores_value);
|
||||
const Tensor& input = next_token_scores_value.Get<Tensor>();
|
||||
|
||||
constexpr int axis = 1;
|
||||
|
|
@ -311,7 +329,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
|
||||
std::unique_ptr<Tensor> topk_scores;
|
||||
std::unique_ptr<Tensor> topk_indices;
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool, topk_scores, topk_indices));
|
||||
ORT_RETURN_IF_ERROR(TopK(&input, axis, top_k, largest, sorted, allocator, stream, thread_pool,
|
||||
topk_scores, topk_indices));
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
|
|
@ -322,7 +341,8 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
// next_indices = (next_tokens / vocab_size).long()
|
||||
// next_tokens = next_tokens % vocab_size
|
||||
const int64_t* next_token_indices = topk_indices->Data<int64_t>();
|
||||
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(), batch_size, top_k, vocab_size, cuda_stream);
|
||||
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
|
||||
batch_size, top_k, vocab_size, cuda_stream);
|
||||
|
||||
const float* data = topk_scores->Data<float>();
|
||||
|
||||
|
|
@ -332,12 +352,26 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
|
||||
#endif
|
||||
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(), data, topk_scores->Shape().Size() * sizeof(float), cudaMemcpyDeviceToHost, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(), beam_state->next_tokens.data(), beam_state->next_tokens.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_indices.data(), beam_state->next_indices.data(), beam_state->next_indices.size_bytes(), cudaMemcpyDeviceToHost, cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_scores.data(),
|
||||
data,
|
||||
topk_scores->Shape().Size() * sizeof(float),
|
||||
cudaMemcpyDeviceToHost,
|
||||
cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_tokens.data(),
|
||||
beam_state->next_tokens.data(),
|
||||
beam_state->next_tokens.size_bytes(),
|
||||
cudaMemcpyDeviceToHost,
|
||||
cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(cpu_state->topk_indices.data(),
|
||||
beam_state->next_indices.data(),
|
||||
beam_state->next_indices.size_bytes(),
|
||||
cudaMemcpyDeviceToHost,
|
||||
cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
|
||||
|
||||
gsl::span<const float> next_scores = gsl::make_span(cpu_state->topk_scores.data(), static_cast<typename gsl::span<float>::index_type>(topk_scores->Shape().Size()));
|
||||
gsl::span<const float> next_scores = gsl::make_span(
|
||||
cpu_state->topk_scores.data(),
|
||||
static_cast<typename gsl::span<float>::index_type>(topk_scores->Shape().Size()));
|
||||
gsl::span<const int32_t> next_tokens(cpu_state->topk_tokens.data(), beam_state->next_tokens.size());
|
||||
gsl::span<const int32_t> next_indices(cpu_state->topk_indices.data(), beam_state->next_indices.size());
|
||||
|
||||
|
|
@ -354,55 +388,98 @@ template <typename T>
|
|||
Status DeviceCopy(gsl::span<T> target, gsl::span<const T> source, void* stream, int copyDirection) {
|
||||
assert(copyDirection >= 0 && copyDirection <= 3);
|
||||
if (stream == nullptr) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(target.data(), source.data(), source.size_bytes(), static_cast<cudaMemcpyKind>(copyDirection)));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpy(target.data(), source.data(), source.size_bytes(),
|
||||
static_cast<cudaMemcpyKind>(copyDirection)));
|
||||
} else {
|
||||
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(), static_cast<cudaMemcpyKind>(copyDirection), cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(target.data(), source.data(), source.size_bytes(),
|
||||
static_cast<cudaMemcpyKind>(copyDirection), cuda_stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status PickPastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
gsl::span<const int32_t>& beam_indices,
|
||||
AllocatorPtr allocator,
|
||||
void* stream) {
|
||||
for (size_t i = 1; i < last_outputs.size(); ++i) {
|
||||
const OrtValue& present = last_outputs[i]; // shape is like (2, batch_beam_size, 12, past_seq_len, 64)
|
||||
Status PickGptPastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
gsl::span<const int32_t>& beam_indices,
|
||||
AllocatorPtr allocator,
|
||||
void* stream) {
|
||||
int num_present_tensors = static_cast<int>(last_outputs.size()) - transformers::GptSubgraph::kFirstPresentOutputIndex;
|
||||
for (int i = 0; i < num_present_tensors; ++i) {
|
||||
const OrtValue& present = last_outputs[transformers::GptSubgraph::kFirstPresentOutputIndex + i];
|
||||
|
||||
// shape is like (2, batch_beam_size, 12, past_seq_len, 64)
|
||||
const TensorShape& past_shape = present.Get<Tensor>().Shape();
|
||||
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
|
||||
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
|
||||
|
||||
// Create a tensor with same shape.
|
||||
// TODO: allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data.
|
||||
// TODO(tianleiwu): allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data.
|
||||
OrtValue past;
|
||||
auto past_type = DataTypeImpl::GetType<T>();
|
||||
Tensor::InitOrtValue(past_type, past_shape, allocator, past);
|
||||
|
||||
auto block_size_per_beam = past_shape[2] * past_shape[3] * past_shape[4];
|
||||
auto past_key_size = past_shape[1] * past_shape[2] * past_shape[3] * past_shape[4];
|
||||
|
||||
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
|
||||
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
|
||||
for (gsl::index j = 0; j < beam_indices.length(); j++) {
|
||||
int32_t beam_index = beam_indices[j];
|
||||
gsl::span<const T> present_key = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<const T> present_value = present_span.subspan(past_key_size + beam_index * block_size_per_beam,
|
||||
block_size_per_beam);
|
||||
|
||||
gsl::span<T> past_key = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<T> past_value = past_span.subspan(past_key_size + j * block_size_per_beam, block_size_per_beam);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(), cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(), cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_key.data(), present_key.data(), present_key.size_bytes(),
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_value.data(), present_value.data(), present_value.size_bytes(),
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
}
|
||||
|
||||
next_inputs[i + 2] = past;
|
||||
next_inputs[transformers::GptSubgraph::kFirstPastInputIndex + i] = past;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Copy present state to past state for T5 model
|
||||
template <typename T>
|
||||
Status PickT5PastState(const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t>& beam_indices,
|
||||
AllocatorPtr allocator,
|
||||
void* stream) {
|
||||
for (int i = 0; i < num_present_tensors; ++i) {
|
||||
const OrtValue& present = last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
|
||||
|
||||
// shape is like (batch_beam_size, 12, past_seq_len, 64)
|
||||
const TensorShape& past_shape = present.Get<Tensor>().Shape();
|
||||
auto block_size_per_beam = past_shape[1] * past_shape[2] * past_shape[3];
|
||||
|
||||
// Create a tensor with same shape.
|
||||
// TODO(tianleiwu): allocate one buffer for all layers, and use a CUDA kernel to copy key/value cache data.
|
||||
OrtValue past;
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<T>(), past_shape, allocator, past);
|
||||
|
||||
gsl::span<T> past_span = gsl::make_span<T>(past.GetMutable<Tensor>()->MutableData<T>(), past_shape.Size());
|
||||
gsl::span<const T> present_span = gsl::make_span<const T>(present.Get<Tensor>().Data<T>(), past_shape.Size());
|
||||
for (gsl::index j = 0; j < beam_indices.length(); j++) {
|
||||
int32_t beam_index = beam_indices[j];
|
||||
gsl::span<const T> present_beam = present_span.subspan(beam_index * block_size_per_beam, block_size_per_beam);
|
||||
gsl::span<T> past_beam = past_span.subspan(j * block_size_per_beam, block_size_per_beam);
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(past_beam.data(), present_beam.data(), present_beam.size_bytes(),
|
||||
cudaMemcpyDeviceToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
}
|
||||
|
||||
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] = past;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status UpdateFeeds(
|
||||
Status UpdateGptFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -421,7 +498,8 @@ Status UpdateFeeds(
|
|||
OrtValue input_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
|
||||
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(), cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
next_inputs[0] = input_ids;
|
||||
|
||||
// Update position IDs
|
||||
|
|
@ -439,7 +517,8 @@ Status UpdateFeeds(
|
|||
int32_t* mask_data = attention_mask.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
|
||||
// Launch kernel to update position_ids and attention_mask for next iteration
|
||||
cuda::LaunchUpdateKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length, reinterpret_cast<cudaStream_t>(stream));
|
||||
cuda::LaunchUpdateGptKernel(old_mask_data, mask_data, position_data, batch_beam_size, current_length,
|
||||
reinterpret_cast<cudaStream_t>(stream));
|
||||
|
||||
next_inputs[2] = attention_mask;
|
||||
|
||||
|
|
@ -453,12 +532,13 @@ Status UpdateFeeds(
|
|||
|
||||
// Update past state
|
||||
if (num_beams == 1) {
|
||||
const int k = transformers::GptSubgraph::kFirstPastInputIndex - transformers::GptSubgraph::kFirstPresentOutputIndex;
|
||||
// feed present_* output to past_* inputs one by one
|
||||
for (size_t i = 1; i < last_outputs.size(); ++i) {
|
||||
next_inputs[i + 2] = last_outputs[i];
|
||||
for (size_t i = transformers::GptSubgraph::kFirstPresentOutputIndex; i < last_outputs.size(); ++i) {
|
||||
next_inputs[i + k] = last_outputs[i];
|
||||
}
|
||||
} else {
|
||||
ORT_RETURN_IF_ERROR(PickPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream));
|
||||
ORT_RETURN_IF_ERROR(PickGptPastState<T>(last_outputs, next_inputs, beam_indices, allocator, stream));
|
||||
}
|
||||
|
||||
// Make sure data is ready before next subgraph execution.
|
||||
|
|
@ -466,15 +546,63 @@ Status UpdateFeeds(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// Update decoder inputs given decoder outputs of last iteration.
|
||||
template <typename T>
|
||||
Status UpdateDecoderFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper) {
|
||||
// last_outputs: logits, present_key_self_0, present_value_self_0, ...
|
||||
// next_inputs: input_ids,
|
||||
// encoder_attention_mask, encoder_hidden_states,
|
||||
// past_key_self_0, past_value_self_0, ...
|
||||
// past_key_cross_0, past_value_cross_0, ...
|
||||
// Only need copy beam next tokens to input_ids, and copy present_*_self_* to past_*_self_*,
|
||||
|
||||
// Update input_ids with next tokens.
|
||||
int batch_beam_size = static_cast<int>(beam_next_tokens.length());
|
||||
int64_t dims[] = {batch_beam_size, 1};
|
||||
TensorShape input_ids_shape(&dims[0], 2);
|
||||
auto element_type = DataTypeImpl::GetType<int32_t>();
|
||||
OrtValue input_ids;
|
||||
Tensor::InitOrtValue(element_type, input_ids_shape, allocator, input_ids);
|
||||
int32_t* input_ids_data = input_ids.GetMutable<Tensor>()->MutableData<int32_t>();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(input_ids_data, beam_next_tokens.data(), beam_next_tokens.size_bytes(),
|
||||
cudaMemcpyHostToDevice, reinterpret_cast<cudaStream_t>(stream)));
|
||||
next_inputs[0] = input_ids;
|
||||
|
||||
#ifdef DEBUG_BEAM_SEARCH
|
||||
dumper->Print("input_ids", input_ids);
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
#endif
|
||||
|
||||
// Update past state
|
||||
ORT_ENFORCE(last_outputs.size() >= static_cast<size_t>(1 + num_present_tensors));
|
||||
// TODO(tianleiwu): remove num_beams==1 once GreedySearch operator is available.
|
||||
if (num_beams == 1) {
|
||||
// feed present_* output to past_* inputs one by one
|
||||
for (int i = 0; i < num_present_tensors; ++i) {
|
||||
next_inputs[transformers::T5DecoderSubgraph::kFirstPastInputIndex + i] =
|
||||
last_outputs[transformers::T5DecoderSubgraph::kFirstPresentOutputIndex + i];
|
||||
return Status::OK();
|
||||
}
|
||||
}
|
||||
|
||||
return PickT5PastState<T>(last_outputs, next_inputs, num_present_tensors, beam_indices, allocator, stream);
|
||||
}
|
||||
|
||||
// Explicit template instantiations of functions
|
||||
template void InitBeamState<float>(transformers::IBeamSearchState<float>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream);
|
||||
|
||||
template Status ProcessLogits<float>(const OrtValue& logits,
|
||||
|
|
@ -494,9 +622,15 @@ template Status DeviceCopy<float>(
|
|||
gsl::span<float> target,
|
||||
gsl::span<const float> source,
|
||||
void* stream,
|
||||
int copyDirectionn);
|
||||
int copyDirection);
|
||||
|
||||
template Status UpdateFeeds<float>(
|
||||
template Status DeviceCopy<int32_t>(
|
||||
gsl::span<int32_t> target,
|
||||
gsl::span<const int32_t> source,
|
||||
void* stream,
|
||||
int copyDirection);
|
||||
|
||||
template Status UpdateGptFeeds<float>(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -510,13 +644,9 @@ template Status UpdateFeeds<float>(
|
|||
|
||||
// Float16
|
||||
template void InitBeamState<MLFloat16>(transformers::IBeamSearchState<MLFloat16>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream);
|
||||
|
||||
template Status ProcessLogits<MLFloat16>(const OrtValue& logits,
|
||||
|
|
@ -532,7 +662,7 @@ template Status ProcessLogits<MLFloat16>(const OrtValue& logits,
|
|||
void* stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
template Status UpdateFeeds<MLFloat16>(
|
||||
template Status UpdateGptFeeds<MLFloat16>(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -544,6 +674,28 @@ template Status UpdateFeeds<MLFloat16>(
|
|||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
template Status UpdateDecoderFeeds<float>(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
template Status UpdateDecoderFeeds<MLFloat16>(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
} // namespace BeamSearchCudaDeviceHelper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,3 +1,6 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
|
|
@ -26,21 +29,15 @@ Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest,
|
|||
std::unique_ptr<Tensor>& output_indices);
|
||||
|
||||
Status AddToFeeds(const IExecutionProvider* execution_provider,
|
||||
OrtValue& input_ids,
|
||||
OrtValue& position_ids,
|
||||
OrtValue& attention_mask,
|
||||
std::initializer_list<OrtValue> inputs,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer);
|
||||
|
||||
template <typename T>
|
||||
void InitBeamState(transformers::IBeamSearchState<T>* beam_state,
|
||||
transformers::IBeamSearchCpuState* cpu_state,
|
||||
gsl::span<int32_t>& sequence_lengths,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
gsl::span<const int32_t> input_ids_in_cpu,
|
||||
int sequence_length,
|
||||
int max_length,
|
||||
void* stream);
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -64,7 +61,7 @@ Status DeviceCopy(gsl::span<T> target,
|
|||
int copyDirection);
|
||||
|
||||
template <typename T>
|
||||
Status UpdateFeeds(
|
||||
Status UpdateGptFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
|
|
@ -76,6 +73,23 @@ Status UpdateFeeds(
|
|||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
// ---------------------------------------------------------------
|
||||
// Functions for encoder-decoder model like T5
|
||||
// ---------------------------------------------------------------
|
||||
|
||||
// Update decoder inputs given decoder outputs of last iteration.
|
||||
template <typename T>
|
||||
Status UpdateDecoderFeeds(
|
||||
AllocatorPtr allocator,
|
||||
void* stream,
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int num_present_tensors,
|
||||
gsl::span<const int32_t> beam_next_tokens,
|
||||
gsl::span<const int32_t> beam_indices,
|
||||
int num_beams,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
||||
} // namespace BeamSearchCudaDeviceHelper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "beam_search_impl.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "cub/util_type.cuh"
|
||||
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -52,7 +52,11 @@ void LaunchNextTokenKernel(const int64_t* next_token_indices,
|
|||
int total_elements = batch_size * top_k;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
NextTokenKernel<<<gridSize, blockSize, 0, stream>>>(next_token_indices, next_indices, next_tokens, vocab_size, total_elements);
|
||||
NextTokenKernel<<<gridSize, blockSize, 0, stream>>>(next_token_indices,
|
||||
next_indices,
|
||||
next_tokens,
|
||||
vocab_size,
|
||||
total_elements);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -153,8 +157,19 @@ void LaunchLogitsProcessKernel(
|
|||
int total_elements = batch_size * num_beams * vocab_size;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(next_token_scores, vocab_mask, prefix_vocab_mask, num_beams, vocab_size, total_elements, demote_token_id,
|
||||
sequences, max_sequence_length, current_sequence_length, repetition_penalty, no_repeat_ngram_size);
|
||||
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(
|
||||
next_token_scores,
|
||||
vocab_mask,
|
||||
prefix_vocab_mask,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
total_elements,
|
||||
demote_token_id,
|
||||
sequences,
|
||||
max_sequence_length,
|
||||
current_sequence_length,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size);
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
|
|
@ -221,11 +236,11 @@ template void LaunchAddProbsKernel(
|
|||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
__global__ void UpdateInputsKernel(const T* old_mask_data,
|
||||
T* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length) {
|
||||
__global__ void UpdateGptInputsKernel(const T* old_mask_data,
|
||||
T* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (index < batch_beam_size * current_length) {
|
||||
// Update attention mask.
|
||||
|
|
@ -240,19 +255,20 @@ __global__ void UpdateInputsKernel(const T* old_mask_data,
|
|||
}
|
||||
}
|
||||
|
||||
void LaunchUpdateKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream) {
|
||||
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream) {
|
||||
assert(current_length > 0);
|
||||
int total_elements = batch_beam_size * current_length;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
UpdateInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
|
||||
UpdateGptInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(
|
||||
old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,8 +2,10 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
|
@ -46,12 +48,12 @@ void LaunchNextTokenKernel(const int64_t* next_token_indices,
|
|||
int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
void LaunchUpdateKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream);
|
||||
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cuda_runtime_api.h>
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "dump_cuda_tensor.h"
|
||||
#include "core/framework/print_tensor_utils.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -38,11 +39,13 @@ class PinnedHostBuffer {
|
|||
};
|
||||
|
||||
template <typename T>
|
||||
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1) {
|
||||
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool is_gpu_tensor) {
|
||||
// Occasionally, user will need dump CPU tensor in CUDA EP.
|
||||
// In that case, we copy tensor data as well. It is not needed, but it keeps code simple.
|
||||
int num_items = dim0 * dim1;
|
||||
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
|
||||
cudaDeviceSynchronize();
|
||||
cudaMemcpy(*data, tensor, num_items * sizeof(T), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
|
|
@ -56,11 +59,11 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1) {
|
|||
}
|
||||
|
||||
template <typename T>
|
||||
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2) {
|
||||
void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int dim2, bool is_gpu_tensor) {
|
||||
int num_items = dim0 * dim1 * dim2;
|
||||
auto data = std::make_shared<PinnedHostBuffer<T>>(num_items);
|
||||
cudaDeviceSynchronize();
|
||||
cudaMemcpy(*data, tensor, num_items * sizeof(T), cudaMemcpyDeviceToHost);
|
||||
cudaMemcpy(*data, tensor, num_items * sizeof(T), is_gpu_tensor ? cudaMemcpyDeviceToHost : cudaMemcpyHostToHost);
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
|
|
@ -75,14 +78,15 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di
|
|||
|
||||
void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) {
|
||||
MLDataType dataType = tensor.DataType();
|
||||
bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU);
|
||||
if (dataType == DataTypeImpl::GetType<float>()) {
|
||||
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, dim2);
|
||||
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, dim2, is_gpu_tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
|
||||
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, dim2);
|
||||
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, dim2, is_gpu_tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<int32_t>()) {
|
||||
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, dim2);
|
||||
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, dim2, is_gpu_tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
|
||||
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, dim2);
|
||||
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, dim2, is_gpu_tensor);
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
|
|
@ -90,14 +94,15 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, i
|
|||
|
||||
void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
|
||||
MLDataType dataType = tensor.DataType();
|
||||
bool is_gpu_tensor = (tensor.Location().device.Type() == OrtDevice::GPU);
|
||||
if (dataType == DataTypeImpl::GetType<float>()) {
|
||||
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1);
|
||||
DumpGpuTensor<float>(name, tensor.Data<float>(), dim0, dim1, is_gpu_tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<MLFloat16>()) {
|
||||
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1);
|
||||
DumpGpuTensor<MLFloat16>(name, tensor.Data<MLFloat16>(), dim0, dim1, is_gpu_tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<int32_t>()) {
|
||||
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1);
|
||||
DumpGpuTensor<int32_t>(name, tensor.Data<int32_t>(), dim0, dim1, is_gpu_tensor);
|
||||
} else if (dataType == DataTypeImpl::GetType<int64_t>()) {
|
||||
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1);
|
||||
DumpGpuTensor<int64_t>(name, tensor.Data<int64_t>(), dim0, dim1, is_gpu_tensor);
|
||||
} else {
|
||||
assert(0);
|
||||
}
|
||||
|
|
@ -106,12 +111,18 @@ void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1) {
|
|||
void DumpGpuTensor(const char* name, const Tensor& tensor) {
|
||||
const auto& shape = tensor.Shape();
|
||||
|
||||
if (nullptr != name) {
|
||||
std::cout << std::string(name) << std::endl;
|
||||
}
|
||||
std::cout << "Shape:" << shape << std::endl;
|
||||
std::cout << tensor.Location().ToString() << std::endl;
|
||||
|
||||
size_t num_dims = shape.NumDimensions();
|
||||
if (num_dims >= 3) {
|
||||
int dim0 = static_cast<int>(shape.SizeToDimension(num_dims - 2));
|
||||
int dim1 = static_cast<int>(shape[num_dims - 2]);
|
||||
int dim2 = static_cast<int>(shape[num_dims - 1]);
|
||||
DumpGpuTensor(name, tensor, dim0, dim1, dim2);
|
||||
DumpGpuTensor(nullptr, tensor, dim0, dim1, dim2);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -121,47 +132,47 @@ void DumpGpuTensor(const char* name, const Tensor& tensor) {
|
|||
num_rows = static_cast<size_t>(shape[0]);
|
||||
}
|
||||
size_t row_size = num_items / num_rows;
|
||||
DumpGpuTensor(name, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
|
||||
DumpGpuTensor(nullptr, tensor, static_cast<int>(num_rows), static_cast<int>(row_size));
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<float>(name, tensor, dim0, dim1);
|
||||
DumpGpuTensor<float>(name, tensor, dim0, dim1, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1);
|
||||
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1);
|
||||
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1);
|
||||
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const float* tensor, int dim0, int dim1, int dim2) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<float>(name, tensor, dim0, dim1, dim2);
|
||||
DumpGpuTensor<float>(name, tensor, dim0, dim1, dim2, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const MLFloat16* tensor, int dim0, int dim1, int dim2) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, dim2);
|
||||
DumpGpuTensor<MLFloat16>(name, tensor, dim0, dim1, dim2, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const int64_t* tensor, int dim0, int dim1, int dim2) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, dim2);
|
||||
DumpGpuTensor<int64_t>(name, tensor, dim0, dim1, dim2, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const int32_t* tensor, int dim0, int dim1, int dim2) const {
|
||||
if (is_enabled_)
|
||||
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1, dim2);
|
||||
DumpGpuTensor<int32_t>(name, tensor, dim0, dim1, dim2, true);
|
||||
}
|
||||
|
||||
void CudaTensorConsoleDumper::Print(const char* name, const Tensor& tensor) const {
|
||||
|
|
@ -235,4 +246,4 @@ void CudaTensorConsoleDumper::Print(const char*, const std::string&, bool) const
|
|||
} // namespace transformers
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
|
|
@ -33,4 +34,4 @@ class CudaTensorConsoleDumper : public onnxruntime::contrib::transformers::ICons
|
|||
} // namespace transformers
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -957,10 +957,11 @@ ONNX_MS_OPERATOR_SET_SCHEMA(BeamSearch, 1,
|
|||
.SetDoc("Beam Search for text generation. Supports GPT-2 decoder.")
|
||||
.Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
|
||||
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
|
||||
.Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast<int64_t>(-1))
|
||||
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("early_stopping", "early stop or not", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("model_type", "model type: 0 for GPT-2; 1 for encoder decoder like T5", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("encoder_decoder_init", "subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
|
||||
.Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
|
||||
.Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
|
||||
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
|
||||
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
|
||||
|
|
|
|||
|
|
@ -12,3 +12,5 @@ import convert_to_onnx
|
|||
|
||||
# added for backward compatible
|
||||
import gpt2_helper
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5"))
|
||||
|
|
|
|||
|
|
@ -9,9 +9,8 @@
|
|||
import argparse
|
||||
import os
|
||||
import random
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Tuple, Union
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
from onnx import ModelProto, TensorProto, numpy_helper
|
||||
|
|
@ -27,7 +26,7 @@ def fake_input_ids_data(
|
|||
input_ids (TensorProto): graph input of the input_ids input tensor
|
||||
batch_size (int): batch size
|
||||
sequence_length (int): sequence length
|
||||
dictionary_size (int): vacaburary size of dictionary
|
||||
dictionary_size (int): vocabulary size of dictionary
|
||||
|
||||
Returns:
|
||||
np.ndarray: the input tensor created
|
||||
|
|
@ -115,28 +114,28 @@ def fake_input_mask_data(
|
|||
return data
|
||||
|
||||
|
||||
def output_test_data(dir: str, inputs: np.ndarray):
|
||||
def output_test_data(directory: str, inputs: Dict[str, np.ndarray]):
|
||||
"""Output input tensors of test data to a directory
|
||||
|
||||
Args:
|
||||
dir (str): path of a directory
|
||||
inputs (numpy.ndarray): numpy array
|
||||
directory (str): path of a directory
|
||||
inputs (Dict[str, np.ndarray]): map from input name to value
|
||||
"""
|
||||
if not os.path.exists(dir):
|
||||
if not os.path.exists(directory):
|
||||
try:
|
||||
os.mkdir(dir)
|
||||
os.mkdir(directory)
|
||||
except OSError:
|
||||
print("Creation of the directory %s failed" % dir)
|
||||
print("Creation of the directory %s failed" % directory)
|
||||
else:
|
||||
print("Successfully created the directory %s " % dir)
|
||||
print("Successfully created the directory %s " % directory)
|
||||
else:
|
||||
print("Warning: directory %s existed. Files will be overwritten." % dir)
|
||||
print("Warning: directory %s existed. Files will be overwritten." % directory)
|
||||
|
||||
index = 0
|
||||
for name, data in inputs.items():
|
||||
tensor = numpy_helper.from_array(data, name)
|
||||
with open(os.path.join(dir, "input_{}.pb".format(index)), "wb") as f:
|
||||
f.write(tensor.SerializeToString())
|
||||
with open(os.path.join(directory, "input_{}.pb".format(index)), "wb") as file:
|
||||
file.write(tensor.SerializeToString())
|
||||
index += 1
|
||||
|
||||
|
||||
|
|
@ -158,7 +157,7 @@ def fake_test_data(
|
|||
batch_size (int): batch size
|
||||
sequence_length (int): sequence length
|
||||
test_cases (int): number of test cases
|
||||
dictionary_size (int): vocaburary size of dictionary for input_ids
|
||||
dictionary_size (int): vocabulary size of dictionary for input_ids
|
||||
verbose (bool): print more information or not
|
||||
random_seed (int): random seed
|
||||
input_ids (TensorProto): graph input of input IDs
|
||||
|
|
@ -167,7 +166,8 @@ def fake_test_data(
|
|||
random_mask_length (bool): whether mask random number of words at the end
|
||||
|
||||
Returns:
|
||||
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value
|
||||
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
|
||||
with input name as key and a tensor as value
|
||||
"""
|
||||
assert input_ids is not None
|
||||
|
||||
|
|
@ -202,7 +202,7 @@ def generate_test_data(
|
|||
input_mask: TensorProto,
|
||||
random_mask_length: bool,
|
||||
):
|
||||
"""Create given number of minput data for testing
|
||||
"""Create given number of input data for testing
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size
|
||||
|
|
@ -216,7 +216,8 @@ def generate_test_data(
|
|||
random_mask_length (bool): whether mask random number of words at the end
|
||||
|
||||
Returns:
|
||||
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictonary with input name as key and a tensor as value
|
||||
List[Dict[str,numpy.ndarray]]: list of test cases, where each test case is a dictionary
|
||||
with input name as key and a tensor as value
|
||||
"""
|
||||
dictionary_size = 10000
|
||||
all_inputs = fake_test_data(
|
||||
|
|
@ -251,12 +252,13 @@ def get_graph_input_from_embed_node(onnx_model, embed_node, input_index):
|
|||
|
||||
def find_bert_inputs(
|
||||
onnx_model: OnnxModel,
|
||||
input_ids_name: str = None,
|
||||
segment_ids_name: str = None,
|
||||
input_mask_name: str = None,
|
||||
) -> Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]:
|
||||
input_ids_name: Optional[str] = None,
|
||||
segment_ids_name: Optional[str] = None,
|
||||
input_mask_name: Optional[str] = None,
|
||||
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""Find graph inputs for BERT model.
|
||||
First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming.
|
||||
First, we will deduce inputs from EmbedLayerNormalization node.
|
||||
If not found, we will guess the meaning of graph inputs based on naming.
|
||||
|
||||
Args:
|
||||
onnx_model (OnnxModel): onnx model object
|
||||
|
|
@ -266,10 +268,12 @@ def find_bert_inputs(
|
|||
|
||||
Raises:
|
||||
ValueError: Graph does not have input named of input_ids_name or segment_ids_name or input_mask_name
|
||||
ValueError: Exptected graph input number does not match with specifeid input_ids_name, segment_ids_name and input_mask_name
|
||||
ValueError: Expected graph input number does not match with specified input_ids_name, segment_ids_name
|
||||
and input_mask_name
|
||||
|
||||
Returns:
|
||||
Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: input tensors of input_ids, segment_ids and input_mask
|
||||
Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
|
||||
segment_ids and input_mask
|
||||
"""
|
||||
|
||||
graph_inputs = onnx_model.get_graph_inputs_excluding_initializers()
|
||||
|
|
@ -340,12 +344,13 @@ def find_bert_inputs(
|
|||
|
||||
def get_bert_inputs(
|
||||
onnx_file: str,
|
||||
input_ids_name: str = None,
|
||||
segment_ids_name: str = None,
|
||||
input_mask_name: str = None,
|
||||
):
|
||||
input_ids_name: Optional[str] = None,
|
||||
segment_ids_name: Optional[str] = None,
|
||||
input_mask_name: Optional[str] = None,
|
||||
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
|
||||
"""Find graph inputs for BERT model.
|
||||
First, we will deduce inputs from EmbedLayerNormalization node. If not found, we will guess the meaning of graph inputs based on naming.
|
||||
First, we will deduce inputs from EmbedLayerNormalization node.
|
||||
If not found, we will guess the meaning of graph inputs based on naming.
|
||||
|
||||
Args:
|
||||
onnx_file (str): onnx model path
|
||||
|
|
@ -354,11 +359,12 @@ def get_bert_inputs(
|
|||
input_mask_name (str, optional): Name of graph input for attention mask. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Tuple[Union[None, np.ndarray], Union[None, np.ndarray], Union[None, np.ndarray]]: input tensors of input_ids, segment_ids and input_mask
|
||||
Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]: input tensors of input_ids,
|
||||
segment_ids and input_mask
|
||||
"""
|
||||
model = ModelProto()
|
||||
with open(onnx_file, "rb") as f:
|
||||
model.ParseFromString(f.read())
|
||||
with open(onnx_file, "rb") as file:
|
||||
model.ParseFromString(file.read())
|
||||
|
||||
onnx_model = OnnxModel(model)
|
||||
return find_bert_inputs(onnx_model, input_ids_name, segment_ids_name, input_mask_name)
|
||||
|
|
@ -447,9 +453,9 @@ def create_and_save_test_data(
|
|||
test_cases: int,
|
||||
seed: int,
|
||||
verbose: bool,
|
||||
input_ids_name: str,
|
||||
segment_ids_name: str,
|
||||
input_mask_name: str,
|
||||
input_ids_name: Optional[str],
|
||||
segment_ids_name: Optional[str],
|
||||
input_mask_name: Optional[str],
|
||||
only_input_tensors: bool,
|
||||
):
|
||||
"""Create test data for a model, and save test data to a directory.
|
||||
|
|
@ -482,24 +488,24 @@ def create_and_save_test_data(
|
|||
)
|
||||
|
||||
for i, inputs in enumerate(all_inputs):
|
||||
dir = os.path.join(output_dir, "test_data_set_" + str(i))
|
||||
output_test_data(dir, inputs)
|
||||
directory = os.path.join(output_dir, "test_data_set_" + str(i))
|
||||
output_test_data(directory, inputs)
|
||||
|
||||
if only_input_tensors:
|
||||
return
|
||||
|
||||
import onnxruntime
|
||||
|
||||
sess = onnxruntime.InferenceSession(model)
|
||||
output_names = [output.name for output in sess.get_outputs()]
|
||||
session = onnxruntime.InferenceSession(model)
|
||||
output_names = [output.name for output in session.get_outputs()]
|
||||
|
||||
for i, inputs in enumerate(all_inputs):
|
||||
dir = os.path.join(output_dir, "test_data_set_" + str(i))
|
||||
result = sess.run(output_names, inputs)
|
||||
directory = os.path.join(output_dir, "test_data_set_" + str(i))
|
||||
result = session.run(output_names, inputs)
|
||||
for i, output_name in enumerate(output_names):
|
||||
tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_names[i])
|
||||
with open(os.path.join(dir, "output_{}.pb".format(i)), "wb") as f:
|
||||
f.write(tensor_result.SerializeToString())
|
||||
tensor_result = numpy_helper.from_array(np.asarray(result[i]), output_name)
|
||||
with open(os.path.join(directory, "output_{}.pb".format(i)), "wb") as file:
|
||||
file.write(tensor_result.SerializeToString())
|
||||
|
||||
|
||||
def main():
|
||||
|
|
|
|||
|
|
@ -6,11 +6,20 @@
|
|||
This converts GPT2 or T5 model to onnx with beam search operator.
|
||||
|
||||
Example 1: convert gpt2 model with beam search:
|
||||
python convert_beam_search.py -m gpt2 --decoder_onnx .\onnx_models\gpt2_past_fp32.onnx --output .\onnx_models\gpt2_beam_search.onnx --output_sequences_scores
|
||||
|
||||
Example 2: convert T5 model with beam search:
|
||||
python ./models/t5/convert_to_onnx.py -m t5-small -s
|
||||
python convert_beam_search.py -m t5-small --model_type t5 --decoder_onnx ./onnx_models/t5-small_decoder.onnx --encoder_decoder_init_onnx ./onnx_models/t5-small_encoder_decoder_init.onnx --output ./onnx_models/t5_small_beam_search.onnx
|
||||
python convert_beam_search.py -m gpt2 --decoder_onnx ./onnx_models/gpt2_past_fp32.onnx \
|
||||
--output ./onnx_models/gpt2_beam_search.onnx --output_sequences_scores
|
||||
|
||||
Example 2: convert T5 model with beam search in two steps:
|
||||
cd ./models/t5
|
||||
python convert_to_onnx.py -m t5-small
|
||||
cd ../..
|
||||
python convert_beam_search.py -m t5-small --model_type t5 \
|
||||
--decoder_onnx ./models/t5/onnx_models/t5-small_decoder.onnx \
|
||||
--encoder_decoder_init_onnx ./models/t5/onnx_models/t5-small_encoder_decoder_init.onnx \
|
||||
--output ./models/t5/onnx_models/t5_small_beam_search.onnx
|
||||
|
||||
Example 3: convert T5 model with beam search. All in one step:
|
||||
python convert_beam_search.py -m t5-small --model_type t5 --output ./models/t5/onnx_models/t5_small_beam_search.onnx
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -19,27 +28,36 @@ import os
|
|||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import List, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import torch
|
||||
from benchmark_helper import Precision
|
||||
from onnx import helper
|
||||
from onnx import onnx_pb as onnx_proto
|
||||
from packaging import version
|
||||
from transformers import GPT2Config, T5Config
|
||||
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer, T5Config, T5ForConditionalGeneration, T5Tokenizer
|
||||
|
||||
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions, get_available_providers
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "gpt2"))
|
||||
from convert_to_onnx import main as convert_gpt2_to_onnx
|
||||
from gpt2_helper import PRETRAINED_GPT2_MODELS
|
||||
from gpt2_helper import PRETRAINED_GPT2_MODELS # noqa: E402
|
||||
from models.gpt2.convert_to_onnx import main as convert_gpt2_to_onnx # noqa: E402
|
||||
|
||||
config: Union[GPT2Config, T5Config] = None
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "models", "t5"))
|
||||
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models # noqa: E402
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
||||
def parse_arguments(argv=None):
|
||||
def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
||||
"""Parse arguments
|
||||
|
||||
Args:
|
||||
argv (Optional[List[str]], optional): _description_. Defaults to None.
|
||||
|
||||
Returns:
|
||||
argparse.Namespace: Parsed arguments.
|
||||
"""
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -69,9 +87,10 @@ def parse_arguments(argv=None):
|
|||
|
||||
parser.add_argument(
|
||||
"--decoder_onnx",
|
||||
required=True,
|
||||
required=False,
|
||||
type=str,
|
||||
help="Output directory for decoder onnx model, or model path ends with .onnx",
|
||||
default="",
|
||||
help="Path of onnx model for decoder. Required for gpt2 model type.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -79,14 +98,14 @@ def parse_arguments(argv=None):
|
|||
required=False,
|
||||
type=str,
|
||||
default="",
|
||||
help="path of ONNX model for encoder and decoder initialization. Required for t5 model type.",
|
||||
help="Path of ONNX model for encoder and decoder initialization. For t5 model type.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--output",
|
||||
required=False,
|
||||
type=str,
|
||||
help="Output directory for beam search model, or model path ends with .onnx",
|
||||
help="Output path for onnx model with beam search.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
|
|
@ -113,6 +132,14 @@ def parse_arguments(argv=None):
|
|||
)
|
||||
parser.set_defaults(disable_parity=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Print more information",
|
||||
)
|
||||
parser.set_defaults(verbose=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--torch_performance",
|
||||
required=False,
|
||||
|
|
@ -185,7 +212,7 @@ def parse_arguments(argv=None):
|
|||
type=float,
|
||||
required=False,
|
||||
default=1,
|
||||
help="Positive. >1 to penalize and <1 to encorage short sentence.",
|
||||
help="Positive. >1 to penalize and <1 to encourage short sentence.",
|
||||
)
|
||||
|
||||
beam_search_group.add_argument(
|
||||
|
|
@ -193,7 +220,7 @@ def parse_arguments(argv=None):
|
|||
type=float,
|
||||
required=False,
|
||||
default=1,
|
||||
help="Positive. >1 to penalize and <1 to encorage.",
|
||||
help="Positive. >1 to penalize and <1 to encourage.",
|
||||
)
|
||||
|
||||
beam_search_group.add_argument(
|
||||
|
|
@ -217,10 +244,14 @@ def parse_arguments(argv=None):
|
|||
return args
|
||||
|
||||
|
||||
def gpt2_to_onnx(args):
|
||||
def gpt2_to_onnx(args: argparse.Namespace):
|
||||
"""Convert GPT-2 model to onnx
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): arguments parsed from command line
|
||||
"""
|
||||
model_name = args.model_name_or_path
|
||||
|
||||
print(f"use convert_to_onnx.py to convert model {model_name} to onnx {args.decoder_onnx} ...")
|
||||
arguments = [
|
||||
"--model_name_or_path",
|
||||
model_name,
|
||||
|
|
@ -233,7 +264,7 @@ def gpt2_to_onnx(args):
|
|||
"1",
|
||||
"--test_cases",
|
||||
"10",
|
||||
"--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, postion_ids and attention_mask
|
||||
"--use_int32_inputs", # BeamSearch requires to use int32 for input_ids, position_ids and attention_mask
|
||||
]
|
||||
if args.use_gpu:
|
||||
arguments.append("--use_gpu")
|
||||
|
|
@ -242,39 +273,78 @@ def gpt2_to_onnx(args):
|
|||
|
||||
if args.precision == Precision.FLOAT16:
|
||||
assert args.use_gpu, "fp16 or mixed precision model cannot run in CPU. Please add --use_gpu"
|
||||
# TODO: Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
|
||||
# TODO(tianleiwu): Use auto mixed precision for fp16 conversion: arguments.append('--auto_mixed_precision')
|
||||
# Need change cuda kernel to support a combination of fp32 logits and fp16 past state.
|
||||
# Currently logits and past state shall be same data type.
|
||||
arguments.extend(["--op_block_list", "Add", "LayerNormalization", "FastGelu"])
|
||||
convert_gpt2_to_onnx(arguments)
|
||||
|
||||
if args.verbose:
|
||||
print(f"arguments for convert_to_onnx:{arguments}")
|
||||
|
||||
convert_gpt2_to_onnx(argv=arguments)
|
||||
|
||||
|
||||
def shape_inference(decoder_onnx_path):
|
||||
if version.parse(onnx.__version__) >= version.parse("1.11.0"):
|
||||
logger.warn("SymbolicShapeInference might fail using onnx version 1.11. Please install 1.10.0 for now.")
|
||||
def t5_to_onnx(args: argparse.Namespace):
|
||||
"""Convert T5 model to onnx
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): arguments parsed from command line
|
||||
"""
|
||||
paths = export_t5_onnx_models(
|
||||
args.model_name_or_path,
|
||||
args.cache_dir,
|
||||
Path(args.output).parent,
|
||||
use_gpu=args.use_gpu,
|
||||
use_external_data_format=args.use_external_data_format,
|
||||
optimize_onnx=False,
|
||||
precision=args.precision,
|
||||
verbose=False,
|
||||
use_decoder_start_token=False,
|
||||
merge_encoder_and_decoder_init=True,
|
||||
overwrite=True,
|
||||
disable_auto_mixed_precision=False,
|
||||
use_int32_inputs=True,
|
||||
)
|
||||
args.encoder_decoder_init_onnx = paths[0]
|
||||
args.decoder_onnx = paths[1]
|
||||
|
||||
|
||||
def shape_inference(onnx_path: str):
|
||||
"""Shape inference on an onnx file, which will be overwritten.
|
||||
|
||||
Args:
|
||||
onnx_path (str): Path of onnx model
|
||||
"""
|
||||
# Run symbolic shape inference to walk around ORT shape inference issue for subgraph.
|
||||
from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference
|
||||
|
||||
out = SymbolicShapeInference.infer_shapes(onnx.load(decoder_onnx_path), auto_merge=True, guess_output_rank=False)
|
||||
out = SymbolicShapeInference.infer_shapes(onnx.load(onnx_path), auto_merge=True, guess_output_rank=False)
|
||||
if out:
|
||||
# TODO: Use external format if input has extra data.
|
||||
onnx.save(out, decoder_onnx_path)
|
||||
# TODO(tianleiwu): Use external format if input has extra data.
|
||||
onnx.save(out, onnx_path)
|
||||
else:
|
||||
print("Failed to run symbolic shape inference on the model.")
|
||||
|
||||
|
||||
def create_ort_session(model_path, use_gpu):
|
||||
from onnxruntime import GraphOptimizationLevel, InferenceSession, SessionOptions
|
||||
from onnxruntime import __version__ as ort_version
|
||||
from onnxruntime import get_available_providers
|
||||
def create_ort_session(model_path: str, use_gpu: bool) -> InferenceSession:
|
||||
"""Create OnnxRuntime session.
|
||||
|
||||
Args:
|
||||
model_path (str): onnx model path
|
||||
use_gpu (bool): use GPU or not
|
||||
|
||||
Raises:
|
||||
RuntimeError: CUDAExecutionProvider is not available when --use_gpu is specified.
|
||||
|
||||
Returns:
|
||||
onnxruntime.InferenceSession: The created session.
|
||||
"""
|
||||
sess_options = SessionOptions()
|
||||
sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
|
||||
execution_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"]
|
||||
if use_gpu:
|
||||
if "CUDAExecutionProvider" not in get_available_providers():
|
||||
raise RuntimeError("CUDAExecutionProvider is not avaiable for --use_gpu!")
|
||||
raise RuntimeError("CUDAExecutionProvider is not available for --use_gpu!")
|
||||
else:
|
||||
print("use CUDAExecutionProvider")
|
||||
|
||||
|
|
@ -282,11 +352,26 @@ def create_ort_session(model_path, use_gpu):
|
|||
return ort_session
|
||||
|
||||
|
||||
def verify_gpt2_subgraph(graph, precision):
|
||||
def verify_gpt2_subgraph(graph: onnx.GraphProto, precision: Precision):
|
||||
"""Verify GPT-2 subgraph
|
||||
|
||||
Args:
|
||||
graph (onnx.GraphProto): onnx graph of GPT-2
|
||||
precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
|
||||
|
||||
Raises:
|
||||
ValueError: Number of inputs not expected.
|
||||
ValueError: Input name is not expected.
|
||||
ValueError: Input data type is not expected.
|
||||
ValueError: Number of outputs not expected.
|
||||
ValueError: Output name is not expected.
|
||||
ValueError: Output data type is not expected.
|
||||
"""
|
||||
is_float16 = Precision.FLOAT16 == precision
|
||||
|
||||
input_count = len(graph.input)
|
||||
layer_count = input_count - 3
|
||||
assert layer_count >= 1
|
||||
|
||||
expected_inputs = ["input_ids", "position_ids", "attention_mask"] + [f"past_{i}" for i in range(layer_count)]
|
||||
if len(graph.input) != len(expected_inputs):
|
||||
|
|
@ -300,10 +385,9 @@ def verify_gpt2_subgraph(graph, precision):
|
|||
if i >= 3:
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
|
||||
|
||||
if graph.input[i].type.tensor_type.elem_type != expected_type:
|
||||
raise ValueError(
|
||||
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.input[i].type.tensor_type.elem_type}"
|
||||
)
|
||||
input_type = graph.input[i].type.tensor_type.elem_type
|
||||
if input_type != expected_type:
|
||||
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
|
||||
print("Verifying GPT-2 graph inputs: name and data type are good.")
|
||||
|
||||
expected_outputs = ["logits"] + [f"present_{i}" for i in range(layer_count)]
|
||||
|
|
@ -315,49 +399,197 @@ def verify_gpt2_subgraph(graph, precision):
|
|||
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
|
||||
if graph.output[i].type.tensor_type.elem_type != expected_type:
|
||||
raise ValueError(
|
||||
f"Input {i} is expected to have onnx data type {expected_type}. Got {graph.output[i].type.tensor_type.elem_type}"
|
||||
)
|
||||
output_type = graph.output[i].type.tensor_type.elem_type
|
||||
if output_type != expected_type:
|
||||
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {output_type}")
|
||||
print("Verifying GPT-2 graph outputs: name and data type are good.")
|
||||
|
||||
# TODO: verify shapes of inputs and outputs.
|
||||
# TODO(tianleiwu): verify shapes of inputs and outputs.
|
||||
return
|
||||
|
||||
|
||||
def verify_t5_decoder_subgraph(graph, precision):
|
||||
# TODO: implement it
|
||||
pass
|
||||
def verify_t5_decoder_subgraph(graph: onnx.GraphProto, precision: Precision):
|
||||
"""Verify T5 decoder subgraph
|
||||
|
||||
Args:
|
||||
graph (onnx.GraphProto): onnx graph of T5 decoder
|
||||
precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
|
||||
|
||||
Raises:
|
||||
ValueError: Number of inputs not expected.
|
||||
ValueError: Input name is not expected.
|
||||
ValueError: Input data type is not expected.
|
||||
ValueError: Number of outputs not expected.
|
||||
ValueError: Output name is not expected.
|
||||
ValueError: Output data type is not expected.
|
||||
"""
|
||||
is_float16 = Precision.FLOAT16 == precision
|
||||
float_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
|
||||
|
||||
input_count = len(graph.input)
|
||||
layer_count = (input_count - 3) // 4
|
||||
assert layer_count >= 1
|
||||
|
||||
# Expect inputs:
|
||||
# input_ids: int32 (B, 1)
|
||||
# encoder_attention_mask: int32 (B, encode_sequence_length)
|
||||
# encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
|
||||
|
||||
# past_key_self_0: (B, num_heads, past_decode_sequence_length, head_size)
|
||||
# past_value_self_0: (B, num_heads, past_decode_sequence_length, head_size)
|
||||
# ... (for each self attention layer)
|
||||
|
||||
# past_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
# past_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
# ... (for each cross attention layer)
|
||||
expected_inputs = ["input_ids", "encoder_attention_mask", "encoder_hidden_states"]
|
||||
for i in range(layer_count):
|
||||
expected_inputs.append(f"past_key_self_{i}")
|
||||
expected_inputs.append(f"past_value_self_{i}")
|
||||
for i in range(layer_count):
|
||||
expected_inputs.append(f"past_key_cross_{i}")
|
||||
expected_inputs.append(f"past_value_cross_{i}")
|
||||
|
||||
if len(graph.input) != len(expected_inputs):
|
||||
raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
|
||||
|
||||
for i, expected_input in enumerate(expected_inputs):
|
||||
if graph.input[i].name != expected_input:
|
||||
raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.INT32 if i < 2 else float_type
|
||||
input_type = graph.input[i].type.tensor_type.elem_type
|
||||
if input_type != expected_type:
|
||||
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
|
||||
|
||||
# Expect outputs:
|
||||
# logits: (B, 1, vocab_size)
|
||||
# present_key_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
|
||||
# present_value_self_0: (B, num_heads, past_decode_sequence_length + 1, head_size)
|
||||
# ... (for each self attention layer)
|
||||
expected_outputs = ["logits"]
|
||||
for i in range(layer_count):
|
||||
expected_outputs.append(f"present_key_self_{i}")
|
||||
expected_outputs.append(f"present_value_self_{i}")
|
||||
|
||||
if len(graph.output) != len(expected_outputs):
|
||||
raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
|
||||
|
||||
for i, expected_output in enumerate(expected_outputs):
|
||||
if graph.output[i].name != expected_output:
|
||||
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
|
||||
output_type = graph.output[i].type.tensor_type.elem_type
|
||||
if output_type != float_type:
|
||||
raise ValueError(f"Output {i} is expected to have onnx data type {float_type}. Got {output_type}")
|
||||
|
||||
|
||||
def verify_t5_encoder_decoder_init_subgraph(graph, precision):
|
||||
# TODO: implement it
|
||||
pass
|
||||
def verify_t5_encoder_decoder_init_subgraph(graph: onnx.GraphProto, precision: Precision):
|
||||
"""Verify T5 decoder subgraph
|
||||
|
||||
Args:
|
||||
graph (onnx.GraphProto): onnx graph of T5 decoder
|
||||
precision (Precision): Precision (FLOAT16 or FLOAT32) of the model.
|
||||
|
||||
Raises:
|
||||
ValueError: Number of inputs not expected.
|
||||
ValueError: Input name is not expected.
|
||||
ValueError: Input data type is not expected.
|
||||
ValueError: Number of outputs not expected.
|
||||
ValueError: Output name is not expected.
|
||||
ValueError: Output data type is not expected.
|
||||
"""
|
||||
is_float16 = Precision.FLOAT16 == precision
|
||||
layer_count = (len(graph.output) - 2) // 4
|
||||
assert layer_count >= 1
|
||||
|
||||
# Expect 3 inputs:
|
||||
# encoder_input_ids: int32 (B, encode_sequence_length)
|
||||
# encoder_attention_mask: int32 (B, encode_sequence_length)
|
||||
# decoder_input_ids: int32 (B, 1)
|
||||
expected_inputs = ["encoder_input_ids", "encoder_attention_mask", "decoder_input_ids"]
|
||||
if len(graph.input) != len(expected_inputs):
|
||||
raise ValueError(f"Number of inputs expected to be {len(expected_inputs)}. Got {len(graph.input)}")
|
||||
|
||||
for i, expected_input in enumerate(expected_inputs):
|
||||
if graph.input[i].name != expected_input:
|
||||
raise ValueError(f"Input {i} is expected to be {expected_input}. Got {graph.input[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.INT32
|
||||
input_type = graph.input[i].type.tensor_type.elem_type
|
||||
if input_type != expected_type:
|
||||
raise ValueError(f"Input {i} is expected to have onnx data type {expected_type}. Got {input_type}")
|
||||
|
||||
# Expected outputs:
|
||||
# logits: (B, 1, vocab_size)
|
||||
# encoder_hidden_states: (B, encode_sequence_length, encoder_hidden_size)
|
||||
# present_key_self_0: (B, num_heads, 1, head_size)
|
||||
# present_value_self_0: (B, num_heads, 1, head_size)
|
||||
# ... (for each self attention layer)
|
||||
# present_key_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
# present_value_cross_0: (B, num_heads, encode_sequence_length, head_size)
|
||||
# ... (for each cross attention layer)
|
||||
expected_outputs = ["logits", "encoder_hidden_states"]
|
||||
for i in range(layer_count):
|
||||
expected_outputs.append(f"present_key_self_{i}")
|
||||
expected_outputs.append(f"present_value_self_{i}")
|
||||
for i in range(layer_count):
|
||||
expected_outputs.append(f"present_key_cross_{i}")
|
||||
expected_outputs.append(f"present_value_cross_{i}")
|
||||
|
||||
if len(graph.output) != len(expected_outputs):
|
||||
raise ValueError(f"Number of outputs expected to be {len(expected_outputs)}. Got {len(graph.output)}")
|
||||
|
||||
for i, expected_output in enumerate(expected_outputs):
|
||||
if graph.output[i].name != expected_output:
|
||||
raise ValueError(f"Output {i} is expected to be {expected_output}. Got {graph.output[i].name}")
|
||||
|
||||
expected_type = onnx_proto.TensorProto.FLOAT16 if is_float16 else onnx_proto.TensorProto.FLOAT
|
||||
output_type = graph.output[i].type.tensor_type.elem_type
|
||||
if output_type != expected_type:
|
||||
raise ValueError(f"Output {i} is expected to have onnx data type {expected_type}. Got {output_type}")
|
||||
|
||||
print("T5 encoder graph verified: name and data type of inputs and outputs are good.")
|
||||
|
||||
|
||||
def convert_model(args):
|
||||
if os.path.exists(args.decoder_onnx):
|
||||
print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
|
||||
else:
|
||||
assert args.model_type == "gpt2", "please have onnx model ready for model type that is not gpt2"
|
||||
gpt2_to_onnx(args)
|
||||
def convert_model(args: argparse.Namespace):
|
||||
"""Convert model according to command line arguments.
|
||||
|
||||
# TODO: fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
|
||||
Args:
|
||||
args (argparse.Namespace): arguments parsed from command line
|
||||
"""
|
||||
is_gpt2: bool = args.model_type == "gpt2"
|
||||
if is_gpt2:
|
||||
if os.path.exists(args.decoder_onnx):
|
||||
print(f"skip convert_to_onnx since path existed: {args.decoder_onnx}")
|
||||
else:
|
||||
print(f"Convert GPT model {args.model_name_or_path} to onnx {args.decoder_onnx} ...")
|
||||
gpt2_to_onnx(args)
|
||||
else: # t5
|
||||
if args.decoder_onnx and args.encoder_decoder_init_onnx:
|
||||
print(
|
||||
f"skip convert_to_onnx since paths specified: {args.decoder_onnx} and {args.encoder_decoder_init_onnx}"
|
||||
)
|
||||
else:
|
||||
print(f"Convert T5 model {args.model_name_or_path} to onnx ...")
|
||||
t5_to_onnx(args)
|
||||
|
||||
# TODO(tianleiwu): fix shape inference for T5. Currently symbolic shape inference on T5 is broken.
|
||||
enable_shape_inference = args.model_type == "gpt2"
|
||||
|
||||
if enable_shape_inference:
|
||||
print(f"Run symbolic shape inference on {args.decoder_onnx}. The file will be overwritten.")
|
||||
shape_inference(args.decoder_onnx)
|
||||
|
||||
global config
|
||||
if args.model_type == "gpt2":
|
||||
if is_gpt2:
|
||||
config = GPT2Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
else:
|
||||
config = T5Config.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
print(config)
|
||||
|
||||
if args.verbose:
|
||||
print(config)
|
||||
|
||||
eos_token_id = config.eos_token_id
|
||||
pad_token_id = config.eos_token_id
|
||||
pad_token_id = config.eos_token_id if is_gpt2 else config.pad_token_id
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
# if vocab_size is given in parameters use that.
|
||||
|
|
@ -394,7 +626,7 @@ def convert_model(args):
|
|||
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
|
||||
outputs.append("scores")
|
||||
|
||||
node = helper.make_node(
|
||||
node = onnx.helper.make_node(
|
||||
"BeamSearch",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
|
|
@ -403,12 +635,12 @@ def convert_model(args):
|
|||
node.domain = "com.microsoft"
|
||||
node.attribute.extend(
|
||||
[
|
||||
helper.make_attribute("eos_token_id", eos_token_id),
|
||||
helper.make_attribute("pad_token_id", pad_token_id),
|
||||
helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
|
||||
helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
|
||||
helper.make_attribute("decoder", model.graph),
|
||||
onnx.helper.make_attribute("eos_token_id", eos_token_id),
|
||||
onnx.helper.make_attribute("pad_token_id", pad_token_id),
|
||||
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
|
||||
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
|
||||
onnx.helper.make_attribute("decoder", model.graph),
|
||||
]
|
||||
)
|
||||
|
||||
|
|
@ -421,22 +653,25 @@ def convert_model(args):
|
|||
verify_t5_encoder_decoder_init_subgraph(init_model.graph, args.precision)
|
||||
node.attribute.extend(
|
||||
[
|
||||
helper.make_attribute("encoder_decoder_init", init_model.graph),
|
||||
onnx.helper.make_attribute("encoder", init_model.graph),
|
||||
onnx.helper.make_attribute(
|
||||
"decoder_start_token_id", config.decoder_start_token_id if len(init_model.graph.input) == 3 else -1
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
from onnx import TensorProto
|
||||
|
||||
# graph inputs
|
||||
input_ids = helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
|
||||
max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
|
||||
min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
|
||||
num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
|
||||
num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
|
||||
temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
|
||||
length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
|
||||
repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
|
||||
vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
|
||||
input_ids = onnx.helper.make_tensor_value_info("input_ids", TensorProto.INT32, ["batch_size", "sequence_length"])
|
||||
max_length = onnx.helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
|
||||
min_length = onnx.helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
|
||||
num_beams = onnx.helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
|
||||
num_return_sequences = onnx.helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
|
||||
temperature = onnx.helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
|
||||
length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
|
||||
repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
|
||||
vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
|
||||
|
||||
graph_inputs = [
|
||||
input_ids,
|
||||
|
|
@ -451,23 +686,23 @@ def convert_model(args):
|
|||
]
|
||||
|
||||
if args.prefix_vocab_mask:
|
||||
prefix_vocab_mask = helper.make_tensor_value_info(
|
||||
prefix_vocab_mask = onnx.helper.make_tensor_value_info(
|
||||
"prefix_vocab_mask", TensorProto.INT32, ["batch_size", vocab_size]
|
||||
)
|
||||
graph_inputs.append(prefix_vocab_mask)
|
||||
|
||||
# graph outputs
|
||||
sequences = helper.make_tensor_value_info(
|
||||
sequences = onnx.helper.make_tensor_value_info(
|
||||
"sequences",
|
||||
TensorProto.INT32,
|
||||
["batch_size", "num_return_sequences", "max_length"],
|
||||
)
|
||||
|
||||
sequences_scores = helper.make_tensor_value_info(
|
||||
sequences_scores = onnx.helper.make_tensor_value_info(
|
||||
"sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
|
||||
)
|
||||
|
||||
scores = helper.make_tensor_value_info(
|
||||
scores = onnx.helper.make_tensor_value_info(
|
||||
"scores",
|
||||
TensorProto.FLOAT,
|
||||
["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
|
||||
|
|
@ -483,7 +718,7 @@ def convert_model(args):
|
|||
if args.output_token_scores:
|
||||
graph_outputs.append(scores)
|
||||
|
||||
new_graph = helper.make_graph(
|
||||
new_graph = onnx.helper.make_graph(
|
||||
[node],
|
||||
f"{args.model_type}-beam-search",
|
||||
graph_inputs,
|
||||
|
|
@ -492,18 +727,44 @@ def convert_model(args):
|
|||
)
|
||||
|
||||
# Create the model
|
||||
new_model = helper.make_model(
|
||||
new_model = onnx.helper.make_model(
|
||||
new_graph,
|
||||
producer_name="onnxruntime.transformers",
|
||||
opset_imports=model.opset_import,
|
||||
)
|
||||
|
||||
# TODO(tianleiwu): move shared initializers from T5 encoder and decoder subgraphs to parent graph to save memory.
|
||||
onnx.save(new_model, args.output)
|
||||
|
||||
|
||||
def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id, pad_token_id, bad_words_ids):
|
||||
def test_torch_performance(
|
||||
args: argparse.Namespace,
|
||||
model: Union[GPT2LMHeadModel, T5ForConditionalGeneration],
|
||||
input_ids: torch.Tensor,
|
||||
attention_mask: torch.Tensor,
|
||||
eos_token_id: int,
|
||||
pad_token_id: int,
|
||||
bad_words_ids: List[List[int]],
|
||||
) -> Dict[str, Any]:
|
||||
"""Test PyTorch performance of text generation.
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): arguments parsed from command line
|
||||
model (Union[GPT2LMHeadModel, T5ForConditionalGeneration]): PyTorch model
|
||||
input_ids (torch.Tensor): input_ids
|
||||
attention_mask (torch.Tensor): Attention mask
|
||||
eos_token_id (int): EOS token ID
|
||||
pad_token_id (int): Padding token ID
|
||||
bad_words_ids (List[List[int]]): Words shall not be generated.
|
||||
|
||||
Raises:
|
||||
RuntimeError: PyTorch with CUDA is not available for --use_gpu
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A dictionary with string with metric name, and value can be integer or string.
|
||||
"""
|
||||
if args.use_gpu and not torch.cuda.is_available():
|
||||
logger.error("Please install PyTorch with Cuda, and use a machine with GPU for testing gpu performance.")
|
||||
return None
|
||||
raise RuntimeError("Please install PyTorch with Cuda for testing gpu performance.")
|
||||
|
||||
if args.precision == Precision.FLOAT16:
|
||||
model.half()
|
||||
|
|
@ -543,23 +804,27 @@ def test_torch_performance(args, model, input_ids, attention_mask, eos_token_id,
|
|||
return get_latency_result(torch_latency, batch_size)
|
||||
|
||||
|
||||
def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
||||
if args.model_type != "gpt2":
|
||||
print(
|
||||
f"Skipping parity test since the support for model type {args.model_type} is not implemented in OnnxRuntime"
|
||||
)
|
||||
return True
|
||||
def test_gpt_model(args: argparse.Namespace, use_vocab_mask: bool = False, sentences: Optional[List[str]] = None):
|
||||
"""Test GPT-2 model
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): arguments parsed from command line
|
||||
use_vocab_mask (bool, optional): use vocabulary mask. Defaults to False.
|
||||
sentences (Optional[List[str]], optional): input text. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
|
||||
"""
|
||||
assert args.model_type == "gpt2"
|
||||
|
||||
if args.temperature != 1.0:
|
||||
# TODO: implement temperature in BeamSearch operator.
|
||||
# TODO(tianleiwu): implement temperature in BeamSearch operator.
|
||||
print("Skipping parity test as temperature is not implemented in BeamSearch operator")
|
||||
return True
|
||||
return None
|
||||
|
||||
if args.prefix_vocab_mask:
|
||||
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
|
||||
return True
|
||||
|
||||
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
||||
return None
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
tokenizer.padding_side = "left"
|
||||
|
|
@ -589,14 +854,200 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
|||
if use_vocab_mask:
|
||||
print("bad_words_ids", bad_words_ids)
|
||||
else:
|
||||
bad_words_ids = None
|
||||
bad_words_ids = []
|
||||
|
||||
global config
|
||||
config = model.config
|
||||
eos_token_id = config.eos_token_id
|
||||
pad_token_id = config.eos_token_id
|
||||
vocab_size = config.vocab_size
|
||||
|
||||
torch_decoded_sequences = []
|
||||
beam_outputs = None
|
||||
if not args.disable_parity:
|
||||
print("-" * 50)
|
||||
print("Test PyTorch model and beam search with huggingface transformers...")
|
||||
beam_outputs = model.generate(
|
||||
input_ids=input_ids,
|
||||
attention_mask=attention_mask,
|
||||
max_length=args.max_length,
|
||||
min_length=args.min_length,
|
||||
num_beams=args.num_beams,
|
||||
early_stopping=args.early_stopping,
|
||||
no_repeat_ngram_size=args.no_repeat_ngram_size,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
num_return_sequences=args.num_return_sequences,
|
||||
temperature=args.temperature,
|
||||
length_penalty=args.length_penalty,
|
||||
repetition_penalty=args.repetition_penalty,
|
||||
bad_words_ids=bad_words_ids,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=args.output_sequences_scores or args.output_token_scores,
|
||||
)
|
||||
print("input_ids", input_ids)
|
||||
print("huggingface transformers outputs:")
|
||||
print("sequences", beam_outputs.sequences)
|
||||
if args.output_sequences_scores:
|
||||
print("sequences_scores", beam_outputs.sequences_scores)
|
||||
if args.output_token_scores:
|
||||
print("scores", beam_outputs.scores)
|
||||
for i, sequence in enumerate(beam_outputs.sequences):
|
||||
decoded_sequence = tokenizer.decode(sequence, skip_special_tokens=True)
|
||||
torch_decoded_sequences.append(decoded_sequence)
|
||||
print(f"{i}: {decoded_sequence}")
|
||||
|
||||
print("-" * 50)
|
||||
print("Test ONNX model and bream search with onnxruntime...")
|
||||
|
||||
ort_session = create_ort_session(args.output, args.use_gpu)
|
||||
|
||||
vocab_mask = np.ones((vocab_size), dtype=np.int32)
|
||||
if use_vocab_mask:
|
||||
for bad_word_id in bad_words_ids:
|
||||
vocab_mask[bad_word_id] = 0
|
||||
|
||||
inputs = {
|
||||
"input_ids": input_ids.cpu().numpy().astype(np.int32),
|
||||
"max_length": np.array([args.max_length], dtype=np.int32),
|
||||
"min_length": np.array([args.min_length], dtype=np.int32),
|
||||
"num_beams": np.array([args.num_beams], dtype=np.int32),
|
||||
"num_return_sequences": np.array([args.num_return_sequences], dtype=np.int32),
|
||||
"temperature": np.array([args.temperature], dtype=np.float32),
|
||||
"length_penalty": np.array([args.length_penalty], dtype=np.float32),
|
||||
"repetition_penalty": np.array([args.repetition_penalty], dtype=np.float32),
|
||||
"vocab_mask": vocab_mask,
|
||||
}
|
||||
|
||||
print("inputs", inputs)
|
||||
result = ort_session.run(None, inputs)
|
||||
|
||||
test_data_dir = Path(args.output).parent.as_posix()
|
||||
print("test_data_dir", test_data_dir)
|
||||
from bert_test_data import output_test_data
|
||||
|
||||
all_inputs = [inputs]
|
||||
for i, inputs in enumerate(all_inputs):
|
||||
dir = os.path.join(test_data_dir, "test_data_set_" + str(i))
|
||||
output_test_data(dir, inputs)
|
||||
|
||||
# Test performance
|
||||
latency = []
|
||||
for _ in range(args.total_runs):
|
||||
start = time.time()
|
||||
_ = ort_session.run(None, inputs)
|
||||
latency.append(time.time() - start)
|
||||
batch_size = input_ids.shape[0]
|
||||
from benchmark_helper import get_latency_result
|
||||
|
||||
output = get_latency_result(latency, batch_size)
|
||||
|
||||
print("ORT outputs:")
|
||||
sequences = result[0]
|
||||
print("sequences", sequences)
|
||||
if args.output_sequences_scores:
|
||||
print("sequences_scores", result[1])
|
||||
if args.output_token_scores:
|
||||
print("scores", result[2])
|
||||
|
||||
(batch_size, num_sequences, max_length) = sequences.shape
|
||||
ort_decoded_sequences = []
|
||||
for i in range(batch_size):
|
||||
for j in range(num_sequences):
|
||||
decoded_sequence = tokenizer.decode(sequences[i][j], skip_special_tokens=True)
|
||||
ort_decoded_sequences.append(decoded_sequence)
|
||||
print(f"batch {i} sequence {j}: {decoded_sequence}")
|
||||
|
||||
if beam_outputs:
|
||||
torch_sequences = beam_outputs.sequences.reshape(batch_size, args.num_return_sequences, -1)
|
||||
ort_sequences = torch.LongTensor(sequences)
|
||||
print("-" * 50)
|
||||
print("Torch Sequences:")
|
||||
print(torch_sequences)
|
||||
print(torch_decoded_sequences)
|
||||
print("-" * 50)
|
||||
print("ORT Sequences:")
|
||||
print(ort_sequences)
|
||||
print(ort_decoded_sequences)
|
||||
print("-" * 50)
|
||||
# Compare the generated text instead of word IDs since ORT pads to max sequence length but Torch not.
|
||||
is_same = torch_decoded_sequences == ort_decoded_sequences
|
||||
print("Torch and ORT result is ", "same" if is_same else "different")
|
||||
output["parity"] = is_same
|
||||
|
||||
if args.torch_performance:
|
||||
torch_latency_output = test_torch_performance(
|
||||
args,
|
||||
model,
|
||||
input_ids,
|
||||
attention_mask,
|
||||
eos_token_id,
|
||||
pad_token_id,
|
||||
bad_words_ids,
|
||||
)
|
||||
print("Torch Latency", torch_latency_output)
|
||||
|
||||
print("ORT", output)
|
||||
return output
|
||||
|
||||
|
||||
def test_t5_model(args: argparse.Namespace, use_vocab_mask: bool = False, sentences: Optional[List[str]] = None):
|
||||
"""Test T5 model
|
||||
|
||||
Args:
|
||||
args (argparse.Namespace): arguments parsed from command line
|
||||
use_vocab_mask (bool, optional): use vocabulary mask. Defaults to False.
|
||||
sentences (Optional[List[str]], optional): input text. Defaults to None.
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
|
||||
"""
|
||||
assert args.model_type == "t5"
|
||||
|
||||
if args.temperature != 1.0:
|
||||
# TODO(tianleiwu): implement temperature in BeamSearch operator.
|
||||
print("Skipping parity test as temperature is not implemented in BeamSearch operator")
|
||||
return None
|
||||
|
||||
if args.prefix_vocab_mask:
|
||||
print("Skipping parity test as prefix vocab mask is not implemented by Hugging Face")
|
||||
return None
|
||||
|
||||
tokenizer = T5Tokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
|
||||
tokenizer.padding_side = "left"
|
||||
|
||||
model = T5ForConditionalGeneration.from_pretrained(
|
||||
args.model_name_or_path,
|
||||
cache_dir=args.cache_dir,
|
||||
)
|
||||
|
||||
# Use different length sentences to test batching
|
||||
if sentences is None:
|
||||
sentences = [
|
||||
"translate English to French: The product is released",
|
||||
"summarize: research continues to show that pets bring real health benefits to their owners."
|
||||
+ "Having a dog around can lead to lower levels of stress for both adults and kids.",
|
||||
# "summarize: I enjoy walking in the park. It makes my mind feel calm and refreshed. "
|
||||
# + "I enjoy looking at the trees, flowers, and wildlife around me, and listening to sound from natural.",
|
||||
]
|
||||
|
||||
inputs = tokenizer(sentences, return_tensors="pt", padding=True)
|
||||
input_ids = inputs["input_ids"]
|
||||
attention_mask = inputs["attention_mask"]
|
||||
|
||||
bad_words = "walk in park"
|
||||
bad_words_ids = tokenizer.encode(bad_words)[:-1] # exclude the last token (EOS)
|
||||
bad_words_ids = [[word_id] for word_id in bad_words_ids] # Convert to list of list
|
||||
if use_vocab_mask:
|
||||
print("bad_words_ids", bad_words_ids)
|
||||
else:
|
||||
bad_words_ids = []
|
||||
|
||||
config = model.config
|
||||
eos_token_id = config.eos_token_id
|
||||
pad_token_id = config.pad_token_id
|
||||
vocab_size = config.vocab_size
|
||||
print(f"eos_token_id:{eos_token_id}, pad_token_id:{pad_token_id}, vocab_size:{vocab_size}")
|
||||
|
||||
torch_decoded_sequences = []
|
||||
if not args.disable_parity:
|
||||
print("-" * 50)
|
||||
|
|
@ -619,6 +1070,7 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
|||
return_dict_in_generate=True,
|
||||
output_scores=args.output_sequences_scores or args.output_token_scores,
|
||||
)
|
||||
|
||||
print("input_ids", input_ids)
|
||||
print("huggingface transformers outputs:")
|
||||
print("sequences", beam_outputs.sequences)
|
||||
|
|
@ -724,17 +1176,44 @@ def test_model(args, use_vocab_mask: bool = False, sentences: List[str] = None):
|
|||
return output
|
||||
|
||||
|
||||
def main(argv=None, sentences=None):
|
||||
def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None):
|
||||
"""Main entry function
|
||||
|
||||
Args:
|
||||
argv (Optional[List[str]], optional): _description_. Defaults to None.
|
||||
sentences (Optional[List[str]], optional): input text. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: --decoder_onnx is not specified for GPT2 model
|
||||
ValueError: Path does not exist: --encoder_decoder_init_onnx
|
||||
ValueError: Path does not exist: --decoder_onnx
|
||||
ValueError: --decoder_onnx and --encoder_decoder_init_onnx are not used together for T5
|
||||
|
||||
Returns:
|
||||
Union[Dict[str, Any], None]: A dictionary with string with metric name, and value can be integer or string.
|
||||
"""
|
||||
|
||||
args = parse_arguments(argv)
|
||||
|
||||
if args.model_type == "gpt2":
|
||||
if not args.decoder_onnx:
|
||||
raise ValueError("--decoder_onnx shall be specified for gpt2 model")
|
||||
elif args.model_type == "t5":
|
||||
if args.encoder_decoder_init_onnx and not os.path.exists(args.encoder_decoder_init_onnx):
|
||||
raise ValueError(f"Path does not exist: --encoder_decoder_init_onnx {args.encoder_decoder_init_onnx}")
|
||||
if args.decoder_onnx and not os.path.exists(args.decoder_onnx):
|
||||
raise ValueError(f"Path does not exist: --decoder_onnx {args.decoder_onnx}")
|
||||
if (args.encoder_decoder_init_onnx and not args.decoder_onnx) or (
|
||||
args.decoder_onnx and not args.encoder_decoder_init_onnx
|
||||
):
|
||||
raise ValueError("--decoder_onnx shall use together with --encoder_decoder_init_onnx")
|
||||
|
||||
convert_model(args)
|
||||
|
||||
if args.model_type == "t5":
|
||||
assert args.encoder_decoder_init_onnx, "please export t5 to onnx models before using this tool"
|
||||
|
||||
if os.path.exists(args.output):
|
||||
print(f"skip conversion since path existed: {args.output}")
|
||||
return test_t5_model(args, use_vocab_mask=True, sentences=sentences)
|
||||
else:
|
||||
convert_model(args)
|
||||
|
||||
return test_model(args, use_vocab_mask=True, sentences=sentences)
|
||||
return test_gpt_model(args, use_vocab_mask=True, sentences=sentences)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
# -------------------------------------------------------------------------
|
||||
# -------------------------------------------------------------------------
|
||||
# Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
# Licensed under the MIT License. See License.txt in the project root for
|
||||
# license information.
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import torch
|
|||
from t5_helper import PRETRAINED_T5_MODELS, T5Helper
|
||||
|
||||
sys.path.append(os.path.join(os.path.dirname(__file__), "..", ".."))
|
||||
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
|
||||
from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger # noqa: E402
|
||||
|
||||
logger = logging.getLogger("")
|
||||
|
||||
|
|
@ -80,7 +80,7 @@ def parse_arguments():
|
|||
"--use_decoder_start_token",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use config.decoder_start_token_id in decoding. Otherwise, add an extra graph input for decoder_input_ids.",
|
||||
help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
|
||||
)
|
||||
parser.set_defaults(use_decoder_start_token=False)
|
||||
|
||||
|
|
@ -101,6 +101,22 @@ def parse_arguments():
|
|||
)
|
||||
parser.set_defaults(disable_auto_mixed_precision=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--separate_encoder_and_decoder_init",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
|
||||
)
|
||||
parser.set_defaults(separate_encoder_and_decoder_init=False)
|
||||
|
||||
parser.add_argument(
|
||||
"--use_int64_inputs",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
|
||||
)
|
||||
parser.set_defaults(use_int64_inputs=False)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
return args
|
||||
|
|
@ -115,10 +131,11 @@ def export_onnx_models(
|
|||
optimize_onnx,
|
||||
precision,
|
||||
verbose,
|
||||
use_decoder_start_token: bool = True,
|
||||
use_decoder_start_token: bool = False,
|
||||
merge_encoder_and_decoder_init: bool = True,
|
||||
overwrite: bool = False,
|
||||
disable_auto_mixed_precision: bool = False,
|
||||
use_int32_inputs: bool = True,
|
||||
):
|
||||
device = torch.device("cuda:0" if use_gpu else "cpu")
|
||||
|
||||
|
|
@ -126,7 +143,7 @@ def export_onnx_models(
|
|||
config = models["decoder"].config
|
||||
|
||||
if (not use_external_data_format) and (config.num_layers > 24):
|
||||
logger.info(f"Try use_external_data_format when model size > 2GB")
|
||||
logger.info("Try use_external_data_format when model size > 2GB")
|
||||
|
||||
output_paths = []
|
||||
for name, model in models.items():
|
||||
|
|
@ -151,6 +168,7 @@ def export_onnx_models(
|
|||
verbose,
|
||||
use_external_data_format,
|
||||
use_decoder_input_ids=not use_decoder_start_token,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
else:
|
||||
logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
|
||||
|
|
@ -185,10 +203,12 @@ def export_onnx_models(
|
|||
use_gpu=use_gpu,
|
||||
provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"],
|
||||
)
|
||||
max_diff = T5Helper.verify_onnx(model, ort_session, device)
|
||||
|
||||
with torch.no_grad():
|
||||
max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
|
||||
if max_diff > 1e-4:
|
||||
logger.warn(f"PyTorch and OnnxRuntime results are NOT close")
|
||||
logger.warning("PyTorch and OnnxRuntime results are NOT close")
|
||||
|
||||
output_paths.append(output_path)
|
||||
|
||||
|
|
@ -212,24 +232,23 @@ def main():
|
|||
assert args.use_gpu, "fp16 requires --use_gpu"
|
||||
|
||||
if args.optimize_onnx:
|
||||
logger.warn(f"Graph optimization for T5 is not implemented yet.")
|
||||
logger.warning("Graph optimization for T5 is not implemented yet.")
|
||||
|
||||
with torch.no_grad():
|
||||
merge_encoder_and_decoder_init = True # Merge encoder and decoder initialization into one model is recommended.
|
||||
output_paths = export_onnx_models(
|
||||
args.model_name_or_path,
|
||||
cache_dir,
|
||||
output_dir,
|
||||
args.use_gpu,
|
||||
args.use_external_data_format,
|
||||
args.optimize_onnx,
|
||||
args.precision,
|
||||
args.verbose,
|
||||
args.use_decoder_start_token,
|
||||
merge_encoder_and_decoder_init,
|
||||
args.overwrite,
|
||||
args.disable_auto_mixed_precision,
|
||||
)
|
||||
output_paths = export_onnx_models(
|
||||
args.model_name_or_path,
|
||||
cache_dir,
|
||||
output_dir,
|
||||
args.use_gpu,
|
||||
args.use_external_data_format,
|
||||
args.optimize_onnx,
|
||||
args.precision,
|
||||
args.verbose,
|
||||
args.use_decoder_start_token,
|
||||
not args.separate_encoder_and_decoder_init,
|
||||
args.overwrite,
|
||||
args.disable_auto_mixed_precision,
|
||||
not args.use_int64_inputs,
|
||||
)
|
||||
|
||||
logger.info(f"Done! Outputs: {output_paths}")
|
||||
|
||||
|
|
|
|||
|
|
@ -135,6 +135,7 @@ class T5DecoderInputs:
|
|||
past_decode_sequence_length: int,
|
||||
device: torch.device,
|
||||
float16: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
): # -> T5DecoderInputs:
|
||||
"""Create dummy inputs for T5Decoder.
|
||||
|
||||
|
|
@ -145,6 +146,7 @@ class T5DecoderInputs:
|
|||
past_decode_sequence_length (int): past sequence length of input_ids for decoder
|
||||
device (torch.device): device of output tensors
|
||||
float16 (bool): whether the model uses float32 or float16 in input
|
||||
use_int32_inputs(bool): whether use int32 instead of int64 for some inputs
|
||||
|
||||
Returns:
|
||||
T5DecoderInputs: dummy inputs for decoder
|
||||
|
|
@ -159,11 +161,17 @@ class T5DecoderInputs:
|
|||
low=0,
|
||||
high=vocab_size - 1,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=torch.int64,
|
||||
dtype=(torch.int32 if use_int32_inputs else torch.int64),
|
||||
device=device,
|
||||
)
|
||||
|
||||
encoder_inputs = T5EncoderInputs.create_dummy(batch_size, encode_sequence_length, vocab_size, device)
|
||||
encoder_inputs = T5EncoderInputs.create_dummy(
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
vocab_size,
|
||||
device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
float_type = torch.float16 if float16 else torch.float32
|
||||
encoder_hidden_state = torch.rand(
|
||||
|
|
@ -211,7 +219,7 @@ class T5DecoderInputs:
|
|||
|
||||
def to_fp32(self):
|
||||
encoder_hidden_state = self.encoder_hidden_states.to(dtype=torch.float32)
|
||||
past = [p.to(dtype=torch.float32) for p in self.past_key_values]
|
||||
past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
|
||||
return T5DecoderInputs(
|
||||
self.decoder_input_ids.clone(),
|
||||
self.encoder_attention_mask.clone(),
|
||||
|
|
@ -228,6 +236,7 @@ class T5DecoderHelper:
|
|||
onnx_model_path: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
"""Export decoder to ONNX
|
||||
|
||||
|
|
@ -237,6 +246,7 @@ class T5DecoderHelper:
|
|||
onnx_model_path (str): onnx path
|
||||
verbose (bool, optional): print verbose information. Defaults to True.
|
||||
use_external_data_format (bool, optional): use external data format or not. Defaults to False.
|
||||
use_int32_inputs (bool, optional): use int32 inputs
|
||||
"""
|
||||
assert isinstance(decoder, (T5Decoder, T5DecoderInit))
|
||||
|
||||
|
|
@ -246,10 +256,9 @@ class T5DecoderHelper:
|
|||
encode_sequence_length=3,
|
||||
past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
input_list = inputs.to_list()
|
||||
with torch.no_grad():
|
||||
outputs = decoder(*input_list)
|
||||
|
||||
past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
|
||||
present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
|
||||
|
|
@ -351,7 +360,8 @@ class T5DecoderHelper:
|
|||
model: Union[T5Decoder, T5DecoderInit],
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
max_cases=4,
|
||||
use_int32_inputs: bool,
|
||||
max_cases: int = 4,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
float16: bool = TypeHelper.get_input_type(ort_session, "encoder_hidden_states") == "tensor(float16)"
|
||||
|
|
@ -373,6 +383,7 @@ class T5DecoderHelper:
|
|||
past_decode_sequence_length,
|
||||
device=device,
|
||||
float16=float16,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
# We use fp32 PyTroch model as baseline even when ONNX model is fp16
|
||||
|
|
|
|||
|
|
@ -42,28 +42,30 @@ class T5EncoderInputs:
|
|||
|
||||
@staticmethod
|
||||
def create_dummy(
|
||||
batch_size: int, sequence_length: int, vocab_size: int, device: torch.device
|
||||
batch_size: int, sequence_length: int, vocab_size: int, device: torch.device, use_int32_inputs: bool = False
|
||||
): # -> T5EncoderInputs
|
||||
"""Create dummy inputs for T5 encoder.
|
||||
|
||||
Args:
|
||||
batch_size (int): batch size
|
||||
sequence_length (int): sequence length
|
||||
vocab_size (int): vocaburary size
|
||||
vocab_size (int): vocabulary size
|
||||
device (torch.device): device of output tensors
|
||||
|
||||
Returns:
|
||||
T5EncoderInputs: dummy inputs for encoder
|
||||
"""
|
||||
dtype = torch.int32 if use_int32_inputs else torch.int64
|
||||
|
||||
input_ids = torch.randint(
|
||||
low=0,
|
||||
high=vocab_size - 1,
|
||||
size=(batch_size, sequence_length),
|
||||
dtype=torch.int64,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
|
||||
attention_mask = torch.ones([batch_size, sequence_length], dtype=torch.int64, device=device)
|
||||
attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
|
||||
if sequence_length >= 2:
|
||||
for i in range(batch_size):
|
||||
padding_position = random.randint(0, sequence_length - 1)
|
||||
|
|
@ -83,6 +85,7 @@ class T5EncoderHelper:
|
|||
onnx_model_path: str,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
"""Export encoder to ONNX
|
||||
|
||||
|
|
@ -95,12 +98,13 @@ class T5EncoderHelper:
|
|||
"""
|
||||
config = encoder.config
|
||||
encoder_inputs = T5EncoderInputs.create_dummy(
|
||||
batch_size=2, sequence_length=4, vocab_size=config.vocab_size, device=device
|
||||
batch_size=2,
|
||||
sequence_length=4,
|
||||
vocab_size=config.vocab_size,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
outputs = encoder(encoder_inputs.input_ids, encoder_inputs.attention_mask)
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch_onnx_export(
|
||||
encoder,
|
||||
|
|
@ -131,13 +135,16 @@ class T5EncoderHelper:
|
|||
return ort_session.run(None, ort_inputs)
|
||||
|
||||
@staticmethod
|
||||
def verify_onnx(model: T5Encoder, ort_session: InferenceSession, device: torch.device):
|
||||
def verify_onnx(
|
||||
model: T5Encoder, ort_session: InferenceSession, device: torch.device, use_int32_inputs: bool = False
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
inputs = T5EncoderInputs.create_dummy(
|
||||
batch_size=4,
|
||||
sequence_length=11,
|
||||
vocab_size=model.config.vocab_size,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
input_list = inputs.to_list()
|
||||
torch_outputs = model(*input_list)
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@
|
|||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy
|
||||
import onnx
|
||||
import torch
|
||||
from past_helper import PastKeyValuesHelper
|
||||
from t5_decoder import T5DecoderInit
|
||||
|
|
@ -34,7 +36,7 @@ class T5EncoderDecoderInit(torch.nn.Module):
|
|||
decoder: torch.nn.Module,
|
||||
lm_head: torch.nn.Module,
|
||||
config: T5Config,
|
||||
decoder_start_token_id: int = None,
|
||||
decoder_start_token_id: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
|
@ -67,15 +69,19 @@ class T5EncoderDecoderInitInputs:
|
|||
encode_sequence_length: int,
|
||||
use_decoder_input_ids: int,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool = False,
|
||||
): # -> T5EncoderDecoderInitInputs:
|
||||
encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
|
||||
batch_size, encode_sequence_length, config.vocab_size, device
|
||||
batch_size,
|
||||
encode_sequence_length,
|
||||
config.vocab_size,
|
||||
device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
decoder_input_ids = None
|
||||
if use_decoder_input_ids:
|
||||
decoder_input_ids = (
|
||||
torch.ones((batch_size, 1), dtype=torch.long, device=device) * config.decoder_start_token_id
|
||||
)
|
||||
dtype = torch.int32 if use_int32_inputs else torch.int64
|
||||
decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
|
||||
|
||||
return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
|
||||
|
||||
|
|
@ -95,6 +101,7 @@ class T5EncoderDecoderInitHelper:
|
|||
use_decoder_input_ids: bool = True,
|
||||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
"""Export decoder to ONNX
|
||||
|
||||
|
|
@ -113,9 +120,9 @@ class T5EncoderDecoderInitHelper:
|
|||
encode_sequence_length=3,
|
||||
use_decoder_input_ids=use_decoder_input_ids,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
input_list = inputs.to_list()
|
||||
outputs = model(*input_list)
|
||||
|
||||
present_names = PastKeyValuesHelper.get_past_names(model.config.num_layers, present=True)
|
||||
|
||||
|
|
@ -135,7 +142,8 @@ class T5EncoderDecoderInitHelper:
|
|||
|
||||
input_names = ["encoder_input_ids", "encoder_attention_mask"]
|
||||
|
||||
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2'. Use more friendly string here.
|
||||
# ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference.
|
||||
# We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
|
||||
sequence_length = "1"
|
||||
num_heads = str(model.config.num_heads)
|
||||
hidden_size = str(model.config.d_model)
|
||||
|
|
@ -149,12 +157,18 @@ class T5EncoderDecoderInitHelper:
|
|||
1: "encode_sequence_length",
|
||||
2: hidden_size,
|
||||
},
|
||||
"logits": {0: "batch_size", 1: sequence_length},
|
||||
"logits": {
|
||||
0: "batch_size",
|
||||
1: sequence_length,
|
||||
},
|
||||
}
|
||||
|
||||
if use_decoder_input_ids:
|
||||
input_names.append("decoder_input_ids")
|
||||
dynamic_axes["decoder_input_ids"] = {0: "batch_size", 1: sequence_length}
|
||||
dynamic_axes["decoder_input_ids"] = {
|
||||
0: "batch_size",
|
||||
1: sequence_length,
|
||||
}
|
||||
|
||||
for name in present_names:
|
||||
if "cross" in name:
|
||||
|
|
@ -173,20 +187,47 @@ class T5EncoderDecoderInitHelper:
|
|||
3: head_size,
|
||||
}
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch_onnx_export(
|
||||
model,
|
||||
args=tuple(input_list),
|
||||
f=onnx_model_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_data_format,
|
||||
verbose=verbose,
|
||||
)
|
||||
with tempfile.TemporaryDirectory() as tmp_dir_name:
|
||||
temp_onnx_model_path = os.path.join(tmp_dir_name, "model.onnx")
|
||||
Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
torch_onnx_export(
|
||||
model,
|
||||
args=tuple(input_list),
|
||||
f=temp_onnx_model_path,
|
||||
export_params=True,
|
||||
input_names=input_names,
|
||||
output_names=output_names,
|
||||
dynamic_axes=dynamic_axes,
|
||||
opset_version=12,
|
||||
do_constant_folding=True,
|
||||
use_external_data_format=use_external_data_format,
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
# Workaround as mentioned earlier: change numeric dim_param to dim_value
|
||||
model = onnx.load(temp_onnx_model_path)
|
||||
for tensor in model.graph.output:
|
||||
for dim_proto in tensor.type.tensor_type.shape.dim:
|
||||
if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
|
||||
sequence_length,
|
||||
num_heads,
|
||||
hidden_size,
|
||||
head_size,
|
||||
]:
|
||||
dim_value = int(dim_proto.dim_param)
|
||||
dim_proto.Clear()
|
||||
dim_proto.dim_value = dim_value
|
||||
|
||||
Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
onnx.save_model(
|
||||
model,
|
||||
onnx_model_path,
|
||||
save_as_external_data=use_external_data_format,
|
||||
all_tensors_to_one_file=True,
|
||||
location=onnx_model_path + ".data",
|
||||
size_threshold=4096,
|
||||
convert_attribute=False,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
|
||||
|
|
@ -208,7 +249,8 @@ class T5EncoderDecoderInitHelper:
|
|||
model: T5EncoderDecoderInit,
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
max_cases=4,
|
||||
use_int32_inputs: bool,
|
||||
max_cases: int = 4,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
ort_inputs = ort_session.get_inputs()
|
||||
|
|
@ -223,6 +265,7 @@ class T5EncoderDecoderInitHelper:
|
|||
encode_sequence_length,
|
||||
use_decoder_input_ids=use_decoder_input_ids,
|
||||
device=device,
|
||||
use_int32_inputs=use_int32_inputs,
|
||||
)
|
||||
|
||||
ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from optimizer import optimize_model
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3B", "t5-11B"]
|
||||
PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
|
||||
|
||||
|
||||
class T5Helper:
|
||||
|
|
@ -110,9 +110,17 @@ class T5Helper:
|
|||
verbose: bool = True,
|
||||
use_external_data_format: bool = False,
|
||||
use_decoder_input_ids: bool = True,
|
||||
use_int32_inputs: bool = False,
|
||||
):
|
||||
if isinstance(model, T5Encoder):
|
||||
T5EncoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format)
|
||||
T5EncoderHelper.export_onnx(
|
||||
model,
|
||||
device,
|
||||
onnx_model_path,
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_int32_inputs,
|
||||
)
|
||||
elif isinstance(model, T5EncoderDecoderInit):
|
||||
T5EncoderDecoderInitHelper.export_onnx(
|
||||
model,
|
||||
|
|
@ -121,9 +129,17 @@ class T5Helper:
|
|||
use_decoder_input_ids,
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_int32_inputs,
|
||||
)
|
||||
else:
|
||||
T5DecoderHelper.export_onnx(model, device, onnx_model_path, verbose, use_external_data_format)
|
||||
T5DecoderHelper.export_onnx(
|
||||
model,
|
||||
device,
|
||||
onnx_model_path,
|
||||
verbose,
|
||||
use_external_data_format,
|
||||
use_int32_inputs,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def auto_mixed_precision(
|
||||
|
|
@ -234,11 +250,13 @@ class T5Helper:
|
|||
model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
|
||||
ort_session: InferenceSession,
|
||||
device: torch.device,
|
||||
use_int32_inputs: bool,
|
||||
):
|
||||
"""Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
|
||||
if isinstance(model, T5Encoder):
|
||||
return T5EncoderHelper.verify_onnx(model, ort_session, device)
|
||||
elif isinstance(model, T5EncoderDecoderInit):
|
||||
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device)
|
||||
else:
|
||||
return T5DecoderHelper.verify_onnx(model, ort_session, device)
|
||||
return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
|
||||
if isinstance(model, T5EncoderDecoderInit):
|
||||
return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
|
||||
return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ py3nvml
|
|||
packaging
|
||||
transformers >= 4.0
|
||||
scipy
|
||||
sentencepiece
|
||||
|
||||
# please follow https://pytorch.org/ to install PyTorch for your OS
|
||||
torch >= 1.8
|
||||
|
|
@ -10,63 +10,189 @@ import os
|
|||
import unittest
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from parity_utilities import find_transformers_source
|
||||
|
||||
if find_transformers_source():
|
||||
from onnxruntime import get_available_providers
|
||||
|
||||
if find_transformers_source() and find_transformers_source(["models", "t5"]):
|
||||
from benchmark_helper import Precision
|
||||
from convert_beam_search import main as run
|
||||
from models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
|
||||
else:
|
||||
from onnxruntime.transformers.benchmark_helper import Precision
|
||||
from onnxruntime.transformers.convert_beam_search import main as run
|
||||
from onnxruntime.transformers.models.t5.convert_to_onnx import export_onnx_models as export_t5_onnx_models
|
||||
|
||||
|
||||
class TestBeamSearch(unittest.TestCase):
|
||||
class TestBeamSearchGpt(unittest.TestCase):
|
||||
"""Test BeamSearch for GPT-2 model"""
|
||||
|
||||
def setUp(self):
|
||||
# TODO: use a smaller model and enable tests in CI pipeline
|
||||
self.model_name = "gpt2"
|
||||
self.gpt2_onnx_path = os.path.join(".", "onnx_models", "gpt2_past_fp32_shape.onnx")
|
||||
self.beam_search_onnx_path = os.path.join(".", "onnx_models", "gpt2_beam_search.onnx")
|
||||
self.cpu_params = f"-m {self.model_name} --decoder_onnx {self.gpt2_onnx_path} --output {self.beam_search_onnx_path} --output_sequences_score --repetition_penalty 2.0"
|
||||
self.default_arguments = [
|
||||
f"-m {self.model_name}",
|
||||
f"--decoder_onnx {self.gpt2_onnx_path}",
|
||||
f"--output {self.beam_search_onnx_path}",
|
||||
"--output_sequences_score",
|
||||
"--repetition_penalty 2.0",
|
||||
]
|
||||
self.sentences = [
|
||||
"The product is released",
|
||||
"I enjoy walking in the park",
|
||||
"Test best way to invest",
|
||||
]
|
||||
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
|
||||
self.remove_onnx_files()
|
||||
|
||||
def run_beam_search(self, arguments: str, sentences=None):
|
||||
return run(arguments.split(), sentences=sentences)
|
||||
def tearDown(self):
|
||||
self.remove_onnx_files()
|
||||
|
||||
def remove_onnx_files(self):
|
||||
if os.path.exists(self.gpt2_onnx_path):
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
|
||||
if os.path.exists(self.beam_search_onnx_path):
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
def run_beam_search(self, extra_arguments: str, sentences=None):
|
||||
arguments = " ".join(self.default_arguments + [extra_arguments]).split()
|
||||
|
||||
# Test CPU
|
||||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
|
||||
|
||||
# Test GPU
|
||||
if self.enable_cuda:
|
||||
if "--use_gpu" not in arguments:
|
||||
arguments.append("--use_gpu")
|
||||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
|
||||
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_cpu(self):
|
||||
result = self.run_beam_search(
|
||||
self.cpu_params + " --num_return_sequences 2",
|
||||
sentences=["The product is released"],
|
||||
)
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
def test_return_sequences(self):
|
||||
for return_sequences in [1, 2]:
|
||||
self.run_beam_search(f"--num_return_sequences {return_sequences}")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_early_stopping(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --early_stopping")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
self.run_beam_search("--early_stopping")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_temperature(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --temperature 0.5")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
# @pytest.mark.slow
|
||||
# def test_temperature(self):
|
||||
# self.run_beam_search("--temperature 0.5")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_length_penalty(self):
|
||||
result = self.run_beam_search(self.cpu_params + " --length_penalty 0.5")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
for length_penalty in [0.5, 2.0]:
|
||||
self.run_beam_search(f"--length_penalty {length_penalty}")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_no_repeat_ngram(self):
|
||||
for ngram_size in [1, 2]:
|
||||
result = self.run_beam_search(self.cpu_params + f" --no_repeat_ngram_size {ngram_size}")
|
||||
os.remove(self.gpt2_onnx_path)
|
||||
self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}")
|
||||
|
||||
|
||||
class TestBeamSearchT5(unittest.TestCase):
|
||||
"""Test BeamSearch for T5 model"""
|
||||
|
||||
def setUp(self):
|
||||
self.model_name = "t5-small"
|
||||
self.decoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_decoder.onnx")
|
||||
self.encoder_onnx_path = os.path.join(".", "onnx_models", "t5-small_encoder_decoder_init.onnx")
|
||||
self.beam_search_onnx_path = os.path.join(".", "onnx_models", "t5_small_beam_search.onnx")
|
||||
self.default_arguments = [
|
||||
f"-m {self.model_name}",
|
||||
"--model_type t5",
|
||||
f"--decoder_onnx {self.decoder_onnx_path}",
|
||||
f"--encoder_decoder_init_onnx {self.encoder_onnx_path}",
|
||||
f"--output {self.beam_search_onnx_path}",
|
||||
"--output_sequences_score",
|
||||
"--repetition_penalty 2.0",
|
||||
]
|
||||
|
||||
self.enable_cuda = torch.cuda.is_available() and "CUDAExecutionProvider" in get_available_providers()
|
||||
|
||||
export_t5_onnx_models(
|
||||
self.model_name,
|
||||
os.path.join(".", "cache_models"),
|
||||
os.path.join(".", "onnx_models"),
|
||||
use_gpu=False,
|
||||
use_external_data_format=False,
|
||||
optimize_onnx=False,
|
||||
precision=Precision.FLOAT32,
|
||||
verbose=False,
|
||||
use_decoder_start_token=False,
|
||||
merge_encoder_and_decoder_init=True,
|
||||
overwrite=True,
|
||||
disable_auto_mixed_precision=False,
|
||||
use_int32_inputs=True,
|
||||
)
|
||||
|
||||
self.sentences = [
|
||||
"translate English to French: The product is released",
|
||||
"summarize: research continues to show that pets bring real health benefits to their owners."
|
||||
+ "Having a dog around can lead to lower levels of stress for both adults and kids.",
|
||||
]
|
||||
|
||||
if os.path.exists(self.beam_search_onnx_path):
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
self.assertTrue(result["parity"], "ORT and PyTorch result is different")
|
||||
|
||||
def tearDown(self):
|
||||
self.remove_onnx_files()
|
||||
|
||||
def remove_onnx_files(self):
|
||||
if os.path.exists(self.beam_search_onnx_path):
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
if os.path.exists(self.decoder_onnx_path):
|
||||
os.remove(self.decoder_onnx_path)
|
||||
|
||||
if os.path.exists(self.encoder_onnx_path):
|
||||
os.remove(self.encoder_onnx_path)
|
||||
|
||||
def run_beam_search(self, extra_arguments: str, sentences=None):
|
||||
arguments = " ".join(self.default_arguments + [extra_arguments]).split()
|
||||
|
||||
# Test CPU
|
||||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on CPU for arguments {arguments}")
|
||||
|
||||
# Test GPU
|
||||
if self.enable_cuda:
|
||||
if "--use_gpu" not in arguments:
|
||||
arguments.append("--use_gpu")
|
||||
result = run(arguments, sentences=self.sentences if sentences is None else sentences)
|
||||
self.assertTrue(result["parity"], f"ORT and PyTorch result is different on GPU for arguments {arguments}")
|
||||
|
||||
os.remove(self.beam_search_onnx_path)
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_return_sequences(self):
|
||||
for return_sequences in [1, 2]:
|
||||
self.run_beam_search(f"--num_return_sequences {return_sequences}")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_early_stopping(self):
|
||||
self.run_beam_search("--early_stopping")
|
||||
|
||||
# @pytest.mark.slow
|
||||
# def test_temperature(self):
|
||||
# self.run_beam_search("--temperature 0.5")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_length_penalty(self):
|
||||
for length_penalty in [0.5, 2.0]:
|
||||
self.run_beam_search(f"--length_penalty {length_penalty}")
|
||||
|
||||
@pytest.mark.slow
|
||||
def test_no_repeat_ngram(self):
|
||||
for ngram_size in [1, 2]:
|
||||
self.run_beam_search(f"--no_repeat_ngram_size {ngram_size}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
|
|
|||
Loading…
Reference in a new issue