mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-06 00:03:22 +00:00
Sampling op (#13426)
### Description
<!-- Describe your changes. -->
Sampling op for cpu and cuda
support huggingface case and custom case
### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
parent
4d2dc8bbbd
commit
68518a1b72
41 changed files with 2404 additions and 509 deletions
|
|
@ -75,10 +75,13 @@ set(contrib_ops_excluded_files
|
|||
"transformers/beam_search.h"
|
||||
"transformers/generation_device_helper.cc"
|
||||
"transformers/generation_device_helper.h"
|
||||
"transformers/beam_search_impl.cu"
|
||||
"transformers/beam_search_impl.h"
|
||||
"transformers/generation_cuda_impl.cu"
|
||||
"transformers/generation_cuda_impl.h"
|
||||
"transformers/greedy_search.cc"
|
||||
"transformers/greedy_search.h"
|
||||
"transformers/sampling.cc"
|
||||
"transformers/sampling.h"
|
||||
"transformers/sampling_cuda_helper.h"
|
||||
"transformers/dump_cuda_tensor.cc"
|
||||
"transformers/dump_cuda_tensor.h"
|
||||
"conv_transpose_with_dynamic_pads.cc"
|
||||
|
|
|
|||
|
|
@ -74,6 +74,7 @@ Do not modify directly.*
|
|||
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
|
||||
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
|
||||
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
|
||||
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
|
||||
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
|
||||
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
|
||||
* <a href="#com.microsoft.SparseToDenseMatMul">com.microsoft.SparseToDenseMatMul</a>
|
||||
|
|
@ -3810,6 +3811,89 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.Sampling"></a><a name="com.microsoft.sampling">**com.microsoft.Sampling**</a>
|
||||
|
||||
Greedy Sampling for text generation.
|
||||
|
||||
#### Version
|
||||
|
||||
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
|
||||
|
||||
#### Attributes
|
||||
|
||||
<dl>
|
||||
<dt><tt>custom</tt> : int</dt>
|
||||
<dd>If 1, use custom sampling logic.</dd>
|
||||
<dt><tt>decoder</tt> : graph (required)</dt>
|
||||
<dd>Decoder subgraph to execute in a loop.</dd>
|
||||
<dt><tt>decoder_start_token_id</tt> : int</dt>
|
||||
<dd>The id of the token that indicates decoding starts.</dd>
|
||||
<dt><tt>encoder</tt> : graph</dt>
|
||||
<dd>The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
|
||||
<dt><tt>eos_token_id</tt> : int (required)</dt>
|
||||
<dd>The id of the end-of-sequence token</dd>
|
||||
<dt><tt>filter_value</tt> : float</dt>
|
||||
<dd>All filtered values will be set to this float value.</dd>
|
||||
<dt><tt>init_decoder</tt> : graph</dt>
|
||||
<dd>The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs</dd>
|
||||
<dt><tt>min_tokens_to_keep</tt> : int</dt>
|
||||
<dd>Minimum number of tokens we keep per batch example in the output.</dd>
|
||||
<dt><tt>model_type</tt> : int</dt>
|
||||
<dd>Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart</dd>
|
||||
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
|
||||
<dd>no repeat ngrams size</dd>
|
||||
<dt><tt>pad_token_id</tt> : int (required)</dt>
|
||||
<dd>The id of the padding token</dd>
|
||||
<dt><tt>presence_penalty</tt> : float</dt>
|
||||
<dd>Presence penalty for custom sampling</dd>
|
||||
<dt><tt>temperature</tt> : float</dt>
|
||||
<dd>The value used to modulate the next token probabilities.</dd>
|
||||
<dt><tt>top_p</tt> : float</dt>
|
||||
<dd>If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.</dd>
|
||||
<dt><tt>vocab_size</tt> : int</dt>
|
||||
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
|
||||
</dl>
|
||||
|
||||
#### Inputs (2 - 8)
|
||||
|
||||
<dl>
|
||||
<dt><tt>input_ids</tt> : I</dt>
|
||||
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
|
||||
<dt><tt>max_length</tt> : I</dt>
|
||||
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
|
||||
<dt><tt>min_length</tt> (optional) : I</dt>
|
||||
<dd>The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)</dd>
|
||||
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
|
||||
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
|
||||
<dt><tt>vocab_mask</tt> (optional) : I</dt>
|
||||
<dd>Mask of vocabulary. Words that are masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (vocab_size)</dd>
|
||||
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
|
||||
<dd>Mask of vocabulary for first step. Words that are masked with 0 are not allowed to be generated, and words masked with 1 are allowed. Shape is (batch_size, vocab_size)</dd>
|
||||
<dt><tt>attention_mask</tt> (optional) : I</dt>
|
||||
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
|
||||
<dt><tt>presence_mask</tt> (optional) : I</dt>
|
||||
<dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Outputs (1 - 2)
|
||||
|
||||
<dl>
|
||||
<dt><tt>sequences</tt> : I</dt>
|
||||
<dd>Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)</dd>
|
||||
<dt><tt>filtered_logits</tt> (optional) : T</dt>
|
||||
<dd>Filtered logits as input to the multinomial function, for debugging purposes. Shape is (batch_size, vocab_size)</dd>
|
||||
</dl>
|
||||
|
||||
#### Type Constraints
|
||||
|
||||
<dl>
|
||||
<dt><tt>T</tt> : tensor(float)</dt>
|
||||
<dd>Constrain input and output types to float tensors.</dd>
|
||||
<dt><tt>I</tt> : tensor(int32)</dt>
|
||||
<dd>Constrain to integer types</dd>
|
||||
</dl>
|
||||
|
||||
|
||||
### <a name="com.microsoft.SkipLayerNormalization"></a><a name="com.microsoft.skiplayernormalization">**com.microsoft.SkipLayerNormalization**</a>
|
||||
|
||||
Skip and Layer Normalization Fusion
|
||||
|
|
|
|||
|
|
@ -438,6 +438,7 @@ Do not modify directly.*
|
|||
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|
||||
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|
||||
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|
||||
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|
||||
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
|
||||
|
|
@ -797,6 +798,7 @@ Do not modify directly.*
|
|||
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|
||||
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|
||||
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|
||||
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
|
||||
|
|
@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>,
|
||||
|
|
|
|||
|
|
@ -61,16 +61,17 @@ void BeamSearch::Init(const OpKernelInfo& info) {
|
|||
parameters_.ParseFromAttributes(info);
|
||||
|
||||
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
|
||||
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
|
||||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
|
||||
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
|
||||
parameters_.model_type == IGenerationParameters::kModelTypeT5);
|
||||
|
||||
ONNX_NAMESPACE::GraphProto proto;
|
||||
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
|
||||
|
||||
if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
|
||||
// Make sure the encoder sub-graph attribute is present for the T5 model.
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
|
||||
}
|
||||
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
|
||||
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
|
||||
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
|
||||
has_init_decoder_ = true;
|
||||
|
|
@ -87,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
|||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) {
|
||||
const auto& node = Node();
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
|
||||
if (attribute_name == "decoder") {
|
||||
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
|
||||
|
|
@ -113,8 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
|
|||
init_run_gpt_subgraph_ = std::move(res.second);
|
||||
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
|
||||
}
|
||||
|
||||
} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
|
||||
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
|
||||
if (attribute_name == "encoder") {
|
||||
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
|
||||
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
|
|
@ -167,7 +167,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
// Make a copy of parameters since we will update it based on inputs later
|
||||
BeamSearchParameters parameters = parameters_;
|
||||
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
|
||||
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
|
||||
BeamSearchGpt<float> impl{
|
||||
*ctx_internal,
|
||||
|
|
|
|||
|
|
@ -191,7 +191,8 @@ Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
|
|||
context.Input<Tensor>(0), // input_ids
|
||||
context.Input<Tensor>(7), // vocab_mask
|
||||
context.Input<Tensor>(8), // prefix_vocab_mask
|
||||
context.Input<Tensor>(9))); // attention_mask
|
||||
context.Input<Tensor>(9), // attention_mask
|
||||
nullptr)); // presence_mask
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const {
|
|||
}
|
||||
|
||||
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
|
||||
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IBeamSearchParameters::kModelTypeGpt));
|
||||
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeGpt));
|
||||
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
|
||||
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
|
||||
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
|
||||
|
|
@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
batch_size = static_cast<int>(dims[0]);
|
||||
|
||||
// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
|
||||
sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
|
||||
sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
|
||||
|
||||
auto* max_length_tensor = context->Input<Tensor>(1);
|
||||
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
|
||||
|
|
@ -71,10 +71,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
|
|||
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
|
||||
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
|
||||
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
|
||||
if (vocab_size == -1) {
|
||||
if (vocab_size == -1 || vocab_size == 0) {
|
||||
vocab_size = vocabulary_size;
|
||||
}
|
||||
|
||||
num_heads = heads;
|
||||
head_size = hidden_size_per_head;
|
||||
num_layers = layers;
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
struct BeamSearchParameters : public IBeamSearchParameters {
|
||||
struct BeamSearchParameters : public IGenerationParameters {
|
||||
Status Validate() const;
|
||||
|
||||
int BatchBeamSize() const { return batch_size * num_beams; }
|
||||
|
|
|
|||
|
|
@ -87,7 +87,8 @@ class GenerateBase {
|
|||
const Tensor* input_ids,
|
||||
const Tensor* vocab_mask,
|
||||
const Tensor* prefix_vocab_mask,
|
||||
const Tensor* attention_mask) const {
|
||||
const Tensor* attention_mask,
|
||||
const Tensor* presence_mask) const {
|
||||
const auto& dims = input_ids->Shape().GetDims();
|
||||
if (dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
|
|
@ -149,6 +150,28 @@ class GenerateBase {
|
|||
}
|
||||
}
|
||||
|
||||
if (presence_mask != nullptr) {
|
||||
const auto& dims_presence = presence_mask->Shape().GetDims();
|
||||
if (dims_presence.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size());
|
||||
}
|
||||
|
||||
// presence_mask first dimension should be same as the first dimension of input_ids
|
||||
if (static_cast<int>(dims_presence[0]) != static_cast<int>(dims[0])) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"input_ids and presence_mask must have the same batch_size");
|
||||
}
|
||||
|
||||
if (static_cast<int>(dims_presence[1]) != parameters->vocab_size) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]);
|
||||
}
|
||||
|
||||
// store presence mask in parameters.
|
||||
parameters->presence_mask = presence_mask->DataAsSpan<int32_t>();
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,11 +6,13 @@
|
|||
#include <memory>
|
||||
#include "core/providers/cpu/math/top_k.h"
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
#include "core/providers/cpu/generator/random.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/common/gsl.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
|
||||
#include "contrib_ops/cpu/transformers/generation_device_helper.h"
|
||||
#include "contrib_ops/cpu/transformers/sampling_cpu_helper.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
|
||||
|
|
@ -175,6 +177,13 @@ Status CreateGptInputs(
|
|||
|
||||
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
|
||||
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
|
||||
if (num_beams == 1) {
|
||||
expanded_input_ids = std::move(input_ids);
|
||||
expanded_position_ids = std::move(position_ids);
|
||||
expanded_attention_mask = std::move(attention_mask);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
|
||||
ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
|
||||
ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
|
||||
|
|
@ -243,7 +252,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
transformers::IBeamScorer* beam_scorer, // beam scorer
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper) { // tensor dumper
|
||||
|
|
@ -400,17 +409,16 @@ template <typename T>
|
|||
Status GreedySearchProcessLogits(
|
||||
const OrtValue& logits, // logits output of subgraph
|
||||
transformers::IGreedySearchState<T>* greedy_state, // state
|
||||
transformers::ISamplingState<T>* sampling_state, // sampling_state
|
||||
transformers::ISequences* sequences, // sequences
|
||||
AllocatorPtr& allocator, // default allocator
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
bool do_sampling, // whether to do sampling
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper) { // tensor dumper
|
||||
#ifndef DEBUG_GENERATION
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
#endif
|
||||
|
||||
int batch_size = parameters->batch_size;
|
||||
int vocab_size = parameters->vocab_size;
|
||||
|
|
@ -448,6 +456,18 @@ Status GreedySearchProcessLogits(
|
|||
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size);
|
||||
#endif
|
||||
|
||||
if (do_sampling) {
|
||||
ORT_RETURN_IF_ERROR(SamplingCpuHelper::Sample(allocator,
|
||||
thread_pool,
|
||||
next_token_scores,
|
||||
sampling_state,
|
||||
greedy_state,
|
||||
parameters,
|
||||
dumper));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// next_tokens = torch.argmax(scores, dim=-1)
|
||||
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
|
||||
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
|
||||
|
|
@ -460,28 +480,27 @@ Status GreedySearchProcessLogits(
|
|||
next_token_scores_value);
|
||||
const Tensor& input = next_token_scores_value.Get<Tensor>();
|
||||
|
||||
constexpr int axis = 1;
|
||||
constexpr unsigned top_k = 1;
|
||||
constexpr int axis = 1;
|
||||
constexpr bool largest = true;
|
||||
constexpr bool sorted = false;
|
||||
|
||||
Tensor topk_scores;
|
||||
Tensor topk_indices;
|
||||
ORT_RETURN_IF_ERROR(
|
||||
TopK(&input,
|
||||
axis,
|
||||
top_k,
|
||||
largest,
|
||||
sorted,
|
||||
allocator,
|
||||
stream,
|
||||
thread_pool,
|
||||
topk_scores,
|
||||
topk_indices));
|
||||
ORT_RETURN_IF_ERROR(TopK(&input,
|
||||
axis,
|
||||
top_k,
|
||||
largest,
|
||||
sorted,
|
||||
allocator,
|
||||
stream,
|
||||
thread_pool,
|
||||
topk_scores,
|
||||
topk_indices));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("topk_scores", topk_scores);
|
||||
dumper->Print("topk_indices", topk_indices);
|
||||
dumper->Print("topk_scores", topk_scores);
|
||||
dumper->Print("topk_indices", topk_indices);
|
||||
#endif
|
||||
|
||||
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
|
||||
|
|
@ -829,7 +848,7 @@ template Status ProcessLogits<float>(
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
transformers::ILogitsProcessorList* logits_processors,
|
||||
transformers::IBeamScorer* beam_scorer,
|
||||
const transformers::IBeamSearchParameters* parameters,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
int step,
|
||||
Stream* stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
|
@ -837,11 +856,13 @@ template Status ProcessLogits<float>(
|
|||
template Status GreedySearchProcessLogits<float>(
|
||||
const OrtValue& logits,
|
||||
transformers::IGreedySearchState<float>* greedy_state,
|
||||
transformers::ISamplingState<float>* sampling_state,
|
||||
transformers::ISequences* sequences,
|
||||
AllocatorPtr& allocator,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
transformers::ILogitsProcessorList* logits_processors,
|
||||
const transformers::IBeamSearchParameters* parameters,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
bool do_sampling,
|
||||
int step,
|
||||
Stream* ort_stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ using ProcessLogitsFunc = std::function<Status(
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
transformers::IBeamScorer* beam_scorer, // beam scorer
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper)>; // tensor dumper
|
||||
|
|
@ -92,11 +92,13 @@ template <typename T>
|
|||
using GreedySearchProcessLogitsFunc = std::function<Status(
|
||||
const OrtValue& logits, // logits output of subgraph
|
||||
transformers::IGreedySearchState<T>* greedy_state, // state
|
||||
transformers::ISamplingState<T>* sampling_state, // sampling buffers
|
||||
transformers::ISequences* sequences, // sequences
|
||||
AllocatorPtr& allocator, // default allocator
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
bool do_sampling, // whether to do sampling
|
||||
int step, // iteration counter
|
||||
Stream* ort_stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper)>; // tensor dumper
|
||||
|
|
@ -203,7 +205,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
transformers::IBeamScorer* beam_scorer, // beam scorer
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper); // tensor dumper
|
||||
|
|
@ -211,11 +213,13 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
template <typename T>
|
||||
Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
transformers::IGreedySearchState<T>* greedy_state, // state
|
||||
transformers::ISamplingState<T>* sampling_state, // sampling buffers
|
||||
transformers::ISequences* sequences, // sequences
|
||||
AllocatorPtr& allocator, // default allocator
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
bool do_sampling, // whether to do sampling
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper); // tensor dumper
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
#pragma once
|
||||
|
||||
#include <utility>
|
||||
#include <random>
|
||||
#include "core/common/gsl.h"
|
||||
#include "core/framework/allocator.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
|
|
@ -33,7 +34,7 @@ struct IBeamSearchState {
|
|||
gsl::span<float> scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
|
||||
gsl::span<float> remaining_scores; // portion of scores that is available for appending next token scores.
|
||||
gsl::span<float> topk_buffer; // temp buffer for topk computation, including:
|
||||
// 1st stage needs:
|
||||
// 1st stage needs:
|
||||
// temp score: (batch_size * num_beams * parts_vocab, 2 * num_beams)
|
||||
// temp token: (batch_size * num_beams * parts_vocab, 2 * num_beams)
|
||||
// 2nd stage needs:
|
||||
|
|
@ -65,6 +66,28 @@ struct IGreedySearchState {
|
|||
gsl::span<int32_t> next_tokens; // shape (batch_size)
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct ISamplingState {
|
||||
gsl::span<int> d_index_in;
|
||||
gsl::span<int> d_index_out;
|
||||
gsl::span<int> d_offset;
|
||||
gsl::span<T> d_sorted_score;
|
||||
gsl::span<float> d_sorted_softmaxed_score;
|
||||
gsl::span<float> d_softmaxed_score;
|
||||
gsl::span<float> h_softmaxed_score;
|
||||
gsl::span<float> d_sampled;
|
||||
gsl::span<float> h_sampled_all;
|
||||
gsl::span<int64_t> d_indices;
|
||||
gsl::span<int> d_presence_mask;
|
||||
|
||||
BufferUniquePtr storage_buffer;
|
||||
size_t temp_storage_bytes;
|
||||
std::default_random_engine generator;
|
||||
|
||||
gsl::span<T> sorted_scores;
|
||||
gsl::span<T> cumulative_probs;
|
||||
};
|
||||
|
||||
class ISequences {
|
||||
public:
|
||||
virtual ~ISequences() {}
|
||||
|
|
@ -96,7 +119,7 @@ class IBeamScorer {
|
|||
Tensor* output_sequence_scores) = 0;
|
||||
};
|
||||
|
||||
struct IBeamSearchParameters {
|
||||
struct IGenerationParameters {
|
||||
static constexpr int kModelTypeGpt = 0;
|
||||
static constexpr int kModelTypeT5 = 1;
|
||||
|
||||
|
|
@ -120,6 +143,7 @@ struct IBeamSearchParameters {
|
|||
|
||||
gsl::span<const int32_t> vocab_mask;
|
||||
gsl::span<const int32_t> prefix_vocab_mask;
|
||||
gsl::span<const int32_t> presence_mask;
|
||||
|
||||
// Parameters from outputs.
|
||||
bool output_scores; // whether scores existed in output
|
||||
|
|
@ -129,6 +153,15 @@ struct IBeamSearchParameters {
|
|||
int num_heads;
|
||||
int head_size;
|
||||
int num_layers;
|
||||
|
||||
// Parameters for TopK/TopP sampling.
|
||||
float presence_penalty;
|
||||
float filter_value;
|
||||
float temperature = 1.0f;
|
||||
float top_p = 0.0f;
|
||||
int seed = 0;
|
||||
int min_tokens_to_keep = 1;
|
||||
bool custom_sampling = false;
|
||||
};
|
||||
|
||||
class IConsoleDumper {
|
||||
|
|
|
|||
|
|
@ -81,16 +81,15 @@ void GreedySearch::Init(const OpKernelInfo& info) {
|
|||
parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size);
|
||||
|
||||
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
|
||||
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
|
||||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
|
||||
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt);
|
||||
|
||||
ONNX_NAMESPACE::GraphProto proto;
|
||||
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
|
||||
// Make sure the encoder sub-graph attribute is present for the T5 model.
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
|
||||
}
|
||||
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
|
||||
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
|
||||
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
|
||||
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
|
||||
has_init_decoder_ = true;
|
||||
|
|
@ -105,7 +104,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
|
|||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) {
|
||||
const auto& node = Node();
|
||||
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { // GPT-2
|
||||
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2
|
||||
if (attribute_name == "decoder") {
|
||||
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
|
||||
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
|
||||
|
|
@ -131,8 +130,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
|
|||
init_run_gpt_subgraph_ = std::move(res.second);
|
||||
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
|
||||
}
|
||||
|
||||
} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { // encoder-decoder like T5
|
||||
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5
|
||||
ORT_THROW("Not Implemented");
|
||||
// if (attribute_name == "encoder") {
|
||||
// ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
|
||||
|
|
@ -186,7 +184,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
|
|||
if (parameters_.model_type == 0) { // GPT-2
|
||||
// Subgraph has constraint that the output is either float or float16
|
||||
if (!gpt_subgraph_->IsOutputFloat16()) {
|
||||
GreedySearchGpt<float> impl{
|
||||
GreedySearchGpt<float, GreedySearchParameters> impl{
|
||||
*ctx_internal,
|
||||
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
|
||||
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
|
||||
|
|
@ -207,7 +205,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
|
|||
|
||||
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
|
||||
} else {
|
||||
GreedySearchGpt<MLFloat16> impl{
|
||||
GreedySearchGpt<MLFloat16, GreedySearchParameters> impl{
|
||||
*ctx_internal,
|
||||
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
|
||||
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
|
||||
|
|
|
|||
|
|
@ -21,7 +21,6 @@ namespace transformers {
|
|||
|
||||
using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
|
||||
|
||||
// bugbug: refactor
|
||||
class GreedySearch : public IControlFlowKernel {
|
||||
public:
|
||||
explicit GreedySearch(const OpKernelInfo& info)
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <random>
|
||||
#include <vector>
|
||||
#include "contrib_ops/cpu/transformers/generation_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/generate_impl_base.h"
|
||||
|
|
@ -11,6 +12,63 @@ namespace contrib {
|
|||
|
||||
namespace transformers {
|
||||
|
||||
template <typename T>
|
||||
struct SamplingState : public ISamplingState<T> {
|
||||
void Init(AllocatorPtr allocator,
|
||||
AllocatorPtr cpu_allocator,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
int max_iter,
|
||||
int seed,
|
||||
bool is_cuda) {
|
||||
int total_count = batch_size * vocab_size;
|
||||
|
||||
this->h_softmaxed_score = AllocateBuffer<float>(cpu_allocator, h_softmaxed_score_buffer_, SafeInt<size_t>(total_count));
|
||||
|
||||
this->generator = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
|
||||
|
||||
if (is_cuda) {
|
||||
this->d_index_in = AllocateBuffer<int>(allocator, d_index_in_buffer_, SafeInt<size_t>(total_count));
|
||||
this->d_index_out = AllocateBuffer<int>(allocator, d_index_out_buffer_, SafeInt<size_t>(total_count));
|
||||
this->d_offset = AllocateBuffer<int>(allocator, d_offset_buffer_, SafeInt<size_t>(batch_size + 1));
|
||||
this->d_sorted_score = AllocateBuffer<T>(allocator, d_sorted_score_buffer_, SafeInt<size_t>(total_count));
|
||||
this->d_sorted_softmaxed_score = AllocateBuffer<float>(allocator, d_sorted_softmaxed_score_buffer_, SafeInt<size_t>(total_count));
|
||||
this->d_softmaxed_score = AllocateBuffer<float>(allocator, d_softmaxed_score_buffer_, SafeInt<size_t>(total_count));
|
||||
this->d_sampled = AllocateBuffer<float>(allocator, d_sampled_buffer_, SafeInt<size_t>(batch_size));
|
||||
this->h_sampled_all = AllocateBuffer<float>(cpu_allocator, h_sampled_all_buffer_, SafeInt<size_t>(batch_size * max_iter));
|
||||
this->d_indices = AllocateBuffer<int64_t>(allocator, d_indices_buffer_, SafeInt<size_t>(batch_size));
|
||||
this->temp_storage_bytes = 0;
|
||||
// TODO: Do not allocate this buffer if there's no presence_mask
|
||||
this->d_presence_mask = AllocateBuffer<int>(allocator, d_presence_mask_buffer_, SafeInt<size_t>(total_count));
|
||||
|
||||
std::uniform_real_distribution<float> distribution(0.0, 1.0);
|
||||
static_cast<void>(distribution(this->generator));
|
||||
for (size_t i = 0; i < this->h_sampled_all.size(); ++i) {
|
||||
this->h_sampled_all[i] = distribution(this->generator);
|
||||
}
|
||||
} else {
|
||||
// TODO: Some buffer can be reused for CPU
|
||||
this->sorted_scores = AllocateBuffer<T>(cpu_allocator, sorted_scores_buffer_, SafeInt<size_t>(total_count));
|
||||
this->cumulative_probs = AllocateBuffer<T>(cpu_allocator, cumulative_probs_buffer_, SafeInt<size_t>(total_count));
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
BufferUniquePtr d_index_in_buffer_;
|
||||
BufferUniquePtr d_index_out_buffer_;
|
||||
BufferUniquePtr d_offset_buffer_;
|
||||
BufferUniquePtr d_sorted_score_buffer_;
|
||||
BufferUniquePtr d_sorted_softmaxed_score_buffer_;
|
||||
BufferUniquePtr d_softmaxed_score_buffer_;
|
||||
BufferUniquePtr h_softmaxed_score_buffer_;
|
||||
BufferUniquePtr d_sampled_buffer_;
|
||||
BufferUniquePtr h_sampled_all_buffer_;
|
||||
BufferUniquePtr d_indices_buffer_;
|
||||
BufferUniquePtr d_presence_mask_buffer_;
|
||||
BufferUniquePtr sorted_scores_buffer_;
|
||||
BufferUniquePtr cumulative_probs_buffer_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
struct GreedySearchState : public IGreedySearchState<T> {
|
||||
Sequences sequences;
|
||||
|
|
@ -68,7 +126,7 @@ struct GreedySearchState : public IGreedySearchState<T> {
|
|||
};
|
||||
|
||||
// Base class of greedy search implementation that is common for both GPT-2 and Bart/T5.
|
||||
template <typename T>
|
||||
template <typename T, typename ParametersT>
|
||||
class GreedySearchBase : public GenerateBase {
|
||||
public:
|
||||
GreedySearchBase(OpKernelContextInternal& context,
|
||||
|
|
@ -76,7 +134,7 @@ class GreedySearchBase : public GenerateBase {
|
|||
concurrency::ThreadPool* thread_pool,
|
||||
Stream* ort_stream,
|
||||
IConsoleDumper* cuda_dumper,
|
||||
GreedySearchParameters& params,
|
||||
ParametersT& params,
|
||||
const GenerationDeviceHelper::TopkFunc& topk_func,
|
||||
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc<T>& process_logits_func,
|
||||
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func)
|
||||
|
|
@ -105,23 +163,25 @@ class GreedySearchBase : public GenerateBase {
|
|||
Status GenerateNextToken(const OrtValue& logits,
|
||||
gsl::span<int32_t>& next_tokens,
|
||||
GreedySearchState<T>& greedy_state,
|
||||
ISamplingState<T>& sampling_state,
|
||||
int counter,
|
||||
int eos_token_id);
|
||||
|
||||
// Calculate scores from logits, then apply filtering and select next token for each beam.
|
||||
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
GreedySearchState<T>& greedy_state,
|
||||
ISamplingState<T>& sampling_state,
|
||||
AllocatorPtr& allocator,
|
||||
int counter);
|
||||
|
||||
GreedySearchParameters* parameters_;
|
||||
ParametersT* parameters_;
|
||||
|
||||
// Device specific functions
|
||||
GenerationDeviceHelper::GreedySearchProcessLogitsFunc<T> process_logits_func_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchBase<T, ParametersT>::CheckInputs(const OpKernelContextInternal& context) {
|
||||
// Input shapes:
|
||||
// input_ids : (batch_size, sequence_length)
|
||||
// vocab_mask : (vocab_size) or nullptr
|
||||
|
|
@ -129,13 +189,14 @@ Status GreedySearchBase<T>::CheckInputs(const OpKernelContextInternal& context)
|
|||
context.Input<Tensor>(0), // input_ids
|
||||
context.Input<Tensor>(4), // vocab_mask
|
||||
context.Input<Tensor>(5), // prefix_vocab_mask
|
||||
nullptr)); // attention_mask
|
||||
context.Input<Tensor>(6), // attention_mask
|
||||
context.Input<Tensor>(7))); // presence_mask
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchBase<T>::Initialize() {
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchBase<T, ParametersT>::Initialize() {
|
||||
ORT_RETURN_IF_ERROR(this->context_.GetTempSpaceAllocator(&this->temp_space_allocator_));
|
||||
|
||||
ORT_RETURN_IF_ERROR(CheckScalarInput("max_length", 1, true));
|
||||
|
|
@ -155,26 +216,29 @@ Status GreedySearchBase<T>::Initialize() {
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchBase<T>::ProcessLogits(
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchBase<T, ParametersT>::ProcessLogits(
|
||||
const OrtValue& logits,
|
||||
GreedySearchState<T>& greedy_state,
|
||||
ISamplingState<T>& sampling_state,
|
||||
AllocatorPtr& allocator,
|
||||
int counter) {
|
||||
return process_logits_func_(logits, &greedy_state, &(greedy_state.sequences), allocator,
|
||||
this->thread_pool_, &this->logits_processors_,
|
||||
parameters_, counter, this->ort_stream_, this->GetConsoleDumper());
|
||||
bool use_sampling = std::is_same<ParametersT, SamplingParameters>::value;
|
||||
return process_logits_func_(logits, &greedy_state, &sampling_state, &(greedy_state.sequences), allocator,
|
||||
this->thread_pool_, &this->logits_processors_, parameters_,
|
||||
use_sampling, counter, this->ort_stream_, this->GetConsoleDumper());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchBase<T>::GenerateNextToken(
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchBase<T, ParametersT>::GenerateNextToken(
|
||||
const OrtValue& logits,
|
||||
gsl::span<int32_t>& next_tokens,
|
||||
GreedySearchState<T>& greedy_state,
|
||||
ISamplingState<T>& sampling_state,
|
||||
int counter,
|
||||
int eos_token_id) {
|
||||
// Process logits to get next token scores
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, this->temp_space_allocator_, counter));
|
||||
ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, sampling_state, this->temp_space_allocator_, counter));
|
||||
|
||||
next_tokens = greedy_state.next_tokens;
|
||||
for (size_t i = 0; i < next_tokens.size(); i++) {
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ std::pair<Status, std::unique_ptr<GptSubgraph>> CreateGptSubgraphAndUpdateParame
|
|||
} // namespace gpt_details
|
||||
|
||||
// Greedy search implementation for GPT-2 model.
|
||||
template <typename T>
|
||||
class GreedySearchGpt : public GreedySearchBase<T> {
|
||||
template <typename T, typename ParametersT>
|
||||
class GreedySearchGpt : public GreedySearchBase<T, ParametersT> {
|
||||
public:
|
||||
GreedySearchGpt(OpKernelContextInternal& context,
|
||||
const SessionState* init_run_decoder_session_state,
|
||||
|
|
@ -35,7 +35,7 @@ class GreedySearchGpt : public GreedySearchBase<T> {
|
|||
concurrency::ThreadPool* thread_pool,
|
||||
Stream* ort_stream,
|
||||
IConsoleDumper* cuda_dumper,
|
||||
GreedySearchParameters& params,
|
||||
ParametersT& params,
|
||||
const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func,
|
||||
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const GenerationDeviceHelper::TopkFunc& topk_func,
|
||||
|
|
@ -43,15 +43,15 @@ class GreedySearchGpt : public GreedySearchBase<T> {
|
|||
const GenerationDeviceHelper::InitGreedyStateFunc<T>& init_greedy_state_func,
|
||||
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const GenerationDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func)
|
||||
: GreedySearchBase<T>(context,
|
||||
decoder_session_state,
|
||||
thread_pool,
|
||||
ort_stream,
|
||||
cuda_dumper,
|
||||
params,
|
||||
topk_func,
|
||||
process_logits_func,
|
||||
device_copy_func),
|
||||
: GreedySearchBase<T, ParametersT>(context,
|
||||
decoder_session_state,
|
||||
thread_pool,
|
||||
ort_stream,
|
||||
cuda_dumper,
|
||||
params,
|
||||
topk_func,
|
||||
process_logits_func,
|
||||
device_copy_func),
|
||||
init_run_decoder_session_state_(init_run_decoder_session_state),
|
||||
init_run_gpt_subgraph_(init_run_gpt_subgraph),
|
||||
gpt_subgraph_(gpt_subgraph),
|
||||
|
|
@ -94,8 +94,8 @@ class GreedySearchGpt : public GreedySearchBase<T> {
|
|||
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchGpt<T, ParametersT>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
|
||||
OrtValue& expanded_input_ids,
|
||||
std::vector<OrtValue>& feeds,
|
||||
IAllocatorUniquePtr<char>& buffer) {
|
||||
|
|
@ -134,8 +134,8 @@ Status GreedySearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengt
|
|||
this->parameters_->max_length);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchGpt<T>::UpdateFeeds(
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchGpt<T, ParametersT>::UpdateFeeds(
|
||||
const std::vector<OrtValue>& last_outputs,
|
||||
std::vector<OrtValue>& next_inputs,
|
||||
int current_length,
|
||||
|
|
@ -161,11 +161,11 @@ Status GreedySearchGpt<T>::UpdateFeeds(
|
|||
);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager,
|
||||
const FeedsFetchesManager& feeds_fetches_manager) {
|
||||
template <typename T, typename ParametersT>
|
||||
Status GreedySearchGpt<T, ParametersT>::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager,
|
||||
const FeedsFetchesManager& feeds_fetches_manager) {
|
||||
auto status = Status::OK();
|
||||
const GreedySearchParameters* parameters = this->parameters_;
|
||||
const ParametersT* parameters = this->parameters_;
|
||||
|
||||
// Allocate output tensors.
|
||||
int64_t sequences_dims[] = {parameters->batch_size, parameters->max_length};
|
||||
|
|
@ -184,6 +184,17 @@ Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fet
|
|||
parameters->max_length,
|
||||
this->IsCuda());
|
||||
|
||||
SamplingState<T> sampling_state;
|
||||
if (std::is_same<ParametersT, SamplingParameters>::value) {
|
||||
sampling_state.Init(this->temp_space_allocator_,
|
||||
this->cpu_allocator_,
|
||||
static_cast<int>(parameters->BatchBeamSize()),
|
||||
static_cast<int>(parameters->vocab_size),
|
||||
static_cast<int>(parameters->max_length - parameters->sequence_length),
|
||||
parameters->seed,
|
||||
this->IsCuda());
|
||||
}
|
||||
|
||||
IAllocatorUniquePtr<char> buffer;
|
||||
OrtValue expanded_input_ids_in_cpu;
|
||||
ORT_RETURN_IF_ERROR(CreateInitialFeeds(greedy_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
|
||||
|
|
@ -276,6 +287,7 @@ Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fet
|
|||
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
|
||||
next_tokens,
|
||||
greedy_state,
|
||||
sampling_state,
|
||||
iteration_counter,
|
||||
parameters->eos_token_id));
|
||||
|
||||
|
|
@ -324,6 +336,27 @@ Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fet
|
|||
gsl::copy(sequence_source, batch_output);
|
||||
}
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
// Debug the one step filtered logits for sampling
|
||||
int64_t filtered_logits_dims[] = {parameters->batch_size, parameters->vocab_size};
|
||||
TensorShape filtered_logits_shape(&filtered_logits_dims[0],
|
||||
sizeof(filtered_logits_dims) / sizeof(filtered_logits_dims[0]));
|
||||
Tensor* filtered_logits = this->context_.Output(1, filtered_logits_shape);
|
||||
if (filtered_logits != nullptr) {
|
||||
gsl::span<float> filtered_logits_span = filtered_logits->MutableDataAsSpan<float>();
|
||||
for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) {
|
||||
auto batch_output = filtered_logits_span.subspan(
|
||||
static_cast<size_t>(batch_id) * parameters->vocab_size,
|
||||
parameters->vocab_size);
|
||||
gsl::span<const float> batch_filtered_logits = gsl::make_span(sampling_state.h_softmaxed_score.data() +
|
||||
batch_id * parameters->vocab_size,
|
||||
parameters->vocab_size);
|
||||
|
||||
gsl::copy(batch_filtered_logits, batch_output);
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
return status;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,8 +6,12 @@
|
|||
#include "core/common/narrow.h"
|
||||
#include "core/common/safeint.h"
|
||||
#include "core/common/span_utils.h"
|
||||
#include "core/providers/cpu/math/softmax_shared.h"
|
||||
#include "contrib_ops/cpu/transformers/logits_processor.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
#include <vector>
|
||||
#include <numeric>
|
||||
#include <algorithm>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -187,6 +191,53 @@ void PrefixVocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
|||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
TemperatureLogitsProcessor<T>::TemperatureLogitsProcessor(float temperature) : temperature_(temperature) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void TemperatureLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
if (temperature_ == 1.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
T* p = next_token_scores.scores.data();
|
||||
for (size_t i = 0; i < next_token_scores.scores.size(); i++) {
|
||||
*p /= temperature_;
|
||||
++p;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
DumpScores("TemperatureLogitsProcessor", next_token_scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
PresencePenaltyLogitsProcessor<T>::PresencePenaltyLogitsProcessor(const gsl::span<const int32_t>& presence_mask,
|
||||
float presence_penalty)
|
||||
: presence_mask_(presence_mask), presence_penalty_(presence_penalty) {
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
|
||||
NextTokenScores<T>& next_token_scores) {
|
||||
if (presence_penalty_ == 0.0f) {
|
||||
return;
|
||||
}
|
||||
|
||||
assert(!presence_mask_.empty());
|
||||
|
||||
T* p = next_token_scores.scores.data();
|
||||
for (size_t i = 0; i < next_token_scores.scores.size(); i++) {
|
||||
*p -= presence_mask_[i] * presence_penalty_;
|
||||
}
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
DumpScores("PresencePenaltyLogitsProcessor", next_token_scores);
|
||||
#endif
|
||||
}
|
||||
|
||||
void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
|
||||
LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
|
||||
}
|
||||
|
|
@ -195,6 +246,10 @@ void LogitsProcessorList::Init(const GreedySearchParameters& parameters) {
|
|||
LogitsProcessorInitImpl<GreedySearchParameters>(parameters);
|
||||
}
|
||||
|
||||
void LogitsProcessorList::Init(const SamplingParameters& parameters) {
|
||||
LogitsProcessorInitImpl<SamplingParameters>(parameters);
|
||||
}
|
||||
|
||||
void LogitsProcessorList::Process(const ISequences* sequences,
|
||||
gsl::span<float>& next_token_scores,
|
||||
int step) {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
|
||||
#include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
|
||||
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
|
||||
#include "contrib_ops/cpu/transformers/generation_shared.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
|
|
@ -96,11 +97,53 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor<T> {
|
|||
const int batch_size_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
class TemperatureLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
TemperatureLogitsProcessor(float temperature);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
float temperature_;
|
||||
};
|
||||
|
||||
// template <typename T>
|
||||
// class TopPLogitsProcessor : public ILogitsProcessor<T> {
|
||||
// public:
|
||||
// TopPLogitsProcessor(float top_p, float filter_value,
|
||||
// onnxruntime::concurrency::ThreadPool* thread_pool);
|
||||
|
||||
// void Process(const ISequences* sequences,
|
||||
// NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
// private:
|
||||
// float top_p_;
|
||||
// float filter_value_;
|
||||
// onnxruntime::concurrency::ThreadPool* thread_pool_;
|
||||
// };
|
||||
|
||||
template <typename T>
|
||||
class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
|
||||
public:
|
||||
PresencePenaltyLogitsProcessor(const gsl::span<const int32_t>& presence_mask,
|
||||
float presence_penalty);
|
||||
|
||||
void Process(const ISequences* sequences,
|
||||
NextTokenScores<T>& next_token_scores) override;
|
||||
|
||||
private:
|
||||
gsl::span<const int32_t> presence_mask_;
|
||||
float presence_penalty_;
|
||||
};
|
||||
|
||||
class LogitsProcessorList : public ILogitsProcessorList {
|
||||
public:
|
||||
LogitsProcessorList() = default;
|
||||
void Init(const BeamSearchParameters& parameters);
|
||||
void Init(const GreedySearchParameters& parameters);
|
||||
void Init(const SamplingParameters& parameters);
|
||||
void Process(const ISequences* sequences, gsl::span<float>& next_token_scores, int step);
|
||||
|
||||
private:
|
||||
|
|
@ -140,6 +183,19 @@ class LogitsProcessorList : public ILogitsProcessorList {
|
|||
processor_list_.push_back(min_length_processor_.get());
|
||||
}
|
||||
|
||||
if (parameters.temperature > 0) {
|
||||
temperature_processor_ = std::make_unique<TemperatureLogitsProcessor<float>>(parameters.temperature);
|
||||
processor_list_.push_back(temperature_processor_.get());
|
||||
}
|
||||
|
||||
if (!parameters.presence_mask.empty()) {
|
||||
presence_penalty_processor_ = std::make_unique<
|
||||
PresencePenaltyLogitsProcessor<float>
|
||||
>(parameters.presence_mask,
|
||||
parameters.presence_penalty);
|
||||
processor_list_.push_back(presence_penalty_processor_.get());
|
||||
}
|
||||
|
||||
batch_beam_size_ = parameters.BatchBeamSize();
|
||||
vocab_size_ = parameters.vocab_size;
|
||||
}
|
||||
|
|
@ -153,6 +209,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
|
|||
std::unique_ptr<VocabMaskLogitsProcessor<float>> vocab_mask_processor_;
|
||||
std::unique_ptr<PrefixVocabMaskLogitsProcessor<float>> prefix_vocab_mask_processor_;
|
||||
std::unique_ptr<MinLengthLogitsProcessor<float>> min_length_processor_;
|
||||
std::unique_ptr<TemperatureLogitsProcessor<float>> temperature_processor_;
|
||||
std::unique_ptr<PresencePenaltyLogitsProcessor<float>> presence_penalty_processor_;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
|
|
|
|||
175
onnxruntime/contrib_ops/cpu/transformers/sampling.cc
Normal file
175
onnxruntime/contrib_ops/cpu/transformers/sampling.cc
Normal file
|
|
@ -0,0 +1,175 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
// there's no way to use a raw pointer as the copy destination with std::copy_n
|
||||
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
|
||||
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
|
||||
#ifdef _MSC_VER
|
||||
#pragma warning(push)
|
||||
#pragma warning(disable : 4996)
|
||||
#endif
|
||||
|
||||
#include "core/framework/op_kernel_context_internal.h"
|
||||
#include "core/framework/utils.h"
|
||||
#include "contrib_ops/cpu/transformers/sampling.h"
|
||||
#include "contrib_ops/cpu/transformers/logits_processor.h"
|
||||
#include "contrib_ops/cpu/transformers/sequences.h"
|
||||
#include "contrib_ops/cpu/transformers/dump_tensor.h"
|
||||
#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h"
|
||||
|
||||
using namespace ONNX_NAMESPACE;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
Sampling, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T, \
|
||||
kCpuExecutionProvider, \
|
||||
(*KernelDefBuilder::Create()) \
|
||||
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
|
||||
transformers::Sampling);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float)
|
||||
|
||||
namespace transformers {
|
||||
|
||||
void Sampling::Init(const OpKernelInfo& info) {
|
||||
parameters_.ParseFromAttributes(info);
|
||||
parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size);
|
||||
|
||||
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
|
||||
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt);
|
||||
|
||||
ONNX_NAMESPACE::GraphProto proto;
|
||||
if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
|
||||
// Make sure the encoder sub-graph attribute is present for the T5 model.
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
|
||||
}
|
||||
|
||||
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
|
||||
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
|
||||
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
|
||||
has_init_decoder_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure the decoder sub-graph attribute is present for all model types.
|
||||
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("decoder", &proto).IsOK());
|
||||
}
|
||||
|
||||
// Builds the GPT sub-graph wrapper for the "decoder" or "init_decoder" attribute and caches
// its feeds/fetches manager. Other attribute names are ignored. Throws for the T5 model type.
Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state,
                                            const std::string& attribute_name,
                                            const SessionState& subgraph_session_state) {
  const auto& node = Node();

  if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {  // encoder-decoder like T5
    ORT_THROW("Not Implemented");
  }

  if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
    return Status::OK();
  }

  // GPT-2 from here on.
  const bool is_decoder = (attribute_name == "decoder");
  const bool is_init_decoder = !is_decoder && (attribute_name == "init_decoder");

  if (is_decoder) {
    ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");

    auto created = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
                                                                     subgraph_session_state, parameters_);
    if (!created.first.IsOK()) {
      return created.first;
    }

    gpt_subgraph_ = std::move(created.second);
    decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
  } else if (is_init_decoder) {
    ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");

    auto created = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
                                                                     subgraph_session_state, parameters_);
    if (!created.first.IsOK()) {
      return created.first;
    }

    init_run_gpt_subgraph_ = std::move(created.second);
    init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
  }

  return Status::OK();
}
|
||||
|
||||
// Runs generation for one kernel invocation.
//
// Verifies that the sub-graph session states and feeds/fetches managers are available, then
// builds a GreedySearchGpt<T, SamplingParameters> (T is float or MLFloat16 depending on the
// decoder sub-graph output type), initializes it and executes it.
Status Sampling::Compute(OpKernelContext* ctx) const {
  auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);

  auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder");
  ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
  ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");

  auto* init_run_decoder_session_state = ctx_internal->SubgraphSessionState("init_decoder");
  if (has_init_decoder_) {
    // Name the attribute that was actually looked up; the message used to say 'decoder'.
    ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'init_decoder' attribute.");
    ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
    ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_
                && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
                "past_present_share_buffer mode must be same for init decoder and decoder subgraphes");
  }

  concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();

  // make a copy since we will update the parameters based on inputs later
  SamplingParameters parameters = parameters_;

  if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {  // GPT-2
    // Subgraph has constraint that the output is either float or float16
    if (!gpt_subgraph_->IsOutputFloat16()) {
      // Device helpers that were not set fall back to the CPU implementations.
      GreedySearchGpt<float, SamplingParameters> impl{
          *ctx_internal,
          has_init_decoder_ ? init_run_decoder_session_state : nullptr,
          has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
          *decoder_session_state,
          *gpt_subgraph_,
          thread_pool,
          ctx->GetComputeStream(),
          dumper_,
          parameters,
          GenerationCpuDeviceHelper::CreateGptInputs,
          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
          process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::GreedySearchProcessLogits<float>,
          init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState<float>,
          device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
          update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
      ORT_RETURN_IF_ERROR(impl.Initialize());

      return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
    } else {
      // NOTE(review): the fp16 helpers and device_copy_func_ are passed without a CPU fallback,
      // so this path presumably relies on them having been set beforehand - confirm.
      GreedySearchGpt<MLFloat16, SamplingParameters> impl{
          *ctx_internal,
          has_init_decoder_ ? init_run_decoder_session_state : nullptr,
          has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
          *decoder_session_state,
          *gpt_subgraph_,
          thread_pool,
          ctx->GetComputeStream(),
          dumper_,
          parameters,
          GenerationCpuDeviceHelper::CreateGptInputs,
          add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
          topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
          process_logits_fp16_func_,
          init_greedy_state_fp16_func_,
          device_copy_func_,
          update_gpt_feeds_fp16_func_};
      ORT_RETURN_IF_ERROR(impl.Initialize());

      return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
    }
  }

  return Status::OK();
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
103
onnxruntime/contrib_ops/cpu/transformers/sampling.h
Normal file
103
onnxruntime/contrib_ops/cpu/transformers/sampling.h
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cpu/controlflow/utils.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
#include "contrib_ops/cpu/transformers/generation_device_helper.h"
|
||||
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class FeedsFetchesManager;
|
||||
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
|
||||
|
||||
class Sampling : public IControlFlowKernel {
|
||||
public:
|
||||
explicit Sampling(const OpKernelInfo& info)
|
||||
: IControlFlowKernel(info),
|
||||
decoder_feeds_fetches_manager_(nullptr),
|
||||
dumper_(nullptr) {
|
||||
Init(info);
|
||||
}
|
||||
|
||||
void Init(const OpKernelInfo& info);
|
||||
|
||||
Status Compute(OpKernelContext* ctx) const override;
|
||||
|
||||
Status SetupSubgraphExecutionInfo(const SessionState& session_state,
|
||||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) override;
|
||||
|
||||
protected:
|
||||
void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; }
|
||||
|
||||
// device helpers that is same for both GPT and encoder-decoder models.
|
||||
void SetDeviceHelpers(
|
||||
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const GenerationDeviceHelper::TopkFunc& topk_func,
|
||||
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc<float>& process_logits_func,
|
||||
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
|
||||
const GenerationDeviceHelper::InitGreedyStateFunc<float>& init_greedy_state_func,
|
||||
const GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16>& init_greedy_state_fp16_func) {
|
||||
add_to_feeds_func_ = add_to_feeds_func;
|
||||
topk_func_ = topk_func;
|
||||
device_copy_func_ = device_copy_func;
|
||||
process_logits_func_ = process_logits_func;
|
||||
process_logits_fp16_func_ = process_logits_fp16_func;
|
||||
init_greedy_state_func_ = init_greedy_state_func;
|
||||
init_greedy_state_fp16_func_ = init_greedy_state_fp16_func;
|
||||
}
|
||||
|
||||
void SetDeviceHelpers_Gpt(
|
||||
const GenerationDeviceHelper::UpdateGptFeedsFunc<float>& update_gpt_feeds_func,
|
||||
const GenerationDeviceHelper::UpdateGptFeedsFunc<MLFloat16>& update_gpt_feeds_fp16_func) {
|
||||
update_gpt_feeds_func_ = update_gpt_feeds_func;
|
||||
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
|
||||
}
|
||||
|
||||
private:
|
||||
// Device specific functions
|
||||
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
GenerationDeviceHelper::TopkFunc topk_func_;
|
||||
GenerationDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
|
||||
|
||||
GenerationDeviceHelper::GreedySearchProcessLogitsFunc<float> process_logits_func_;
|
||||
GenerationDeviceHelper::GreedySearchProcessLogitsFunc<MLFloat16> process_logits_fp16_func_;
|
||||
|
||||
GenerationDeviceHelper::InitGreedyStateFunc<float> init_greedy_state_func_;
|
||||
GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16> init_greedy_state_fp16_func_;
|
||||
|
||||
//------------------------------------------------------------
|
||||
// Device specific functions for GPT
|
||||
//------------------------------------------------------------
|
||||
GenerationDeviceHelper::UpdateGptFeedsFunc<float> update_gpt_feeds_func_;
|
||||
GenerationDeviceHelper::UpdateGptFeedsFunc<MLFloat16> update_gpt_feeds_fp16_func_;
|
||||
|
||||
//------------------------------------------------------------
|
||||
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
|
||||
//------------------------------------------------------------
|
||||
std::unique_ptr<GptSubgraph> init_run_gpt_subgraph_;
|
||||
std::unique_ptr<GptSubgraph> gpt_subgraph_;
|
||||
|
||||
FeedsFetchesManager* decoder_feeds_fetches_manager_;
|
||||
FeedsFetchesManager* init_run_decoder_feeds_fetches_manager_;
|
||||
|
||||
IConsoleDumper* dumper_;
|
||||
|
||||
SamplingParameters parameters_;
|
||||
|
||||
bool has_init_decoder_ = false;
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
161
onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
Normal file
161
onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace SamplingCpuHelper {
|
||||
|
||||
template <typename T>
|
||||
void filter_scores(std::vector<size_t>& sorted_indice,
|
||||
gsl::span<T>& next_token_score,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
size_t index) {
|
||||
size_t real_index = sorted_indice[index];
|
||||
next_token_score[real_index] = (T)parameters->filter_value;
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void cumulate_and_filter_custom(gsl::span<T>& next_token_scores,
|
||||
gsl::span<T>& cumulative_probs,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
std::vector<size_t>& sorted_indices) {
|
||||
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
|
||||
size_t offset = i * parameters->vocab_size;
|
||||
if (cumulative_probs[offset] > parameters->top_p) {
|
||||
filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset);
|
||||
}
|
||||
for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - 1; j++) {
|
||||
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
|
||||
if (cumulative_probs[j + offset] > parameters->top_p) {
|
||||
filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void cumulate_and_filter(gsl::span<T>& next_token_scores,
|
||||
gsl::span<T>& cumulative_probs,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
std::vector<size_t>& sorted_indices) {
|
||||
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
|
||||
size_t offset = i * parameters->vocab_size;
|
||||
if (cumulative_probs[offset] <= 1 - parameters->top_p) {
|
||||
filter_scores(sorted_indices, next_token_scores, parameters, offset);
|
||||
}
|
||||
for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - static_cast<size_t>(parameters->min_tokens_to_keep); j++) {
|
||||
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
|
||||
if (cumulative_probs[j + offset] <= 1 - parameters->top_p) {
|
||||
filter_scores(sorted_indices, next_token_scores, parameters, j + offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status Sample(AllocatorPtr& allocator,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
gsl::span<T>& next_token_scores,
|
||||
transformers::ISamplingState<T>* sampling_state,
|
||||
transformers::IGreedySearchState<T>* greedy_state,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
const transformers::IConsoleDumper* dumper) {
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
|
||||
gsl::span<T>& sorted_scores = sampling_state->sorted_scores;
|
||||
memcpy(sorted_scores.data(), next_token_scores.data(), next_token_scores.size_bytes());
|
||||
std::vector<size_t> sorted_indices(static_cast<size_t>(parameters->batch_size) * static_cast<size_t>(parameters->vocab_size));
|
||||
|
||||
std::function<bool(T, T)> predicator;
|
||||
if (parameters->custom_sampling) {
|
||||
predicator = std::greater<T>();
|
||||
} else {
|
||||
predicator = std::less<T>();
|
||||
}
|
||||
|
||||
// TODO: This could be optimized with allocated buffer and handwritten sort algorithm
|
||||
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
|
||||
auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size;
|
||||
auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size;
|
||||
std::iota(indices_begin, indices_end, 0);
|
||||
std::sort(indices_begin, indices_end,
|
||||
[&next_token_scores, &predicator](size_t i1, size_t i2) {
|
||||
return !predicator(next_token_scores[i1], next_token_scores[i2]);
|
||||
});
|
||||
|
||||
std::sort(sorted_scores.begin() + i * parameters->vocab_size,
|
||||
sorted_scores.begin() + (i + 1) * parameters->vocab_size,
|
||||
predicator);
|
||||
}
|
||||
|
||||
gsl::span<T>& cumulative_probs = sampling_state->cumulative_probs;
|
||||
|
||||
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
|
||||
parameters->vocab_size,
|
||||
sorted_scores.data(),
|
||||
cumulative_probs.data(),
|
||||
false,
|
||||
thread_pool));
|
||||
|
||||
if (parameters->custom_sampling) {
|
||||
cumulate_and_filter_custom(next_token_scores, cumulative_probs, parameters, sorted_indices);
|
||||
} else {
|
||||
cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices);
|
||||
}
|
||||
|
||||
gsl::span<T>& next_token_probs = sampling_state->h_softmaxed_score;
|
||||
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
|
||||
parameters->vocab_size,
|
||||
next_token_scores.data(),
|
||||
next_token_probs.data(),
|
||||
false,
|
||||
thread_pool));
|
||||
|
||||
// torch.multinomial()
|
||||
int64_t next_token_probs_dims[] = {static_cast<int64_t>(parameters->batch_size), parameters->vocab_size};
|
||||
TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2);
|
||||
auto element_type = DataTypeImpl::GetType<T>();
|
||||
OrtValue next_token_probs_value;
|
||||
Tensor::InitOrtValue(element_type,
|
||||
next_token_probs_shape,
|
||||
next_token_probs.data(),
|
||||
allocator->Info(),
|
||||
next_token_probs_value);
|
||||
const Tensor& input = next_token_probs_value.Get<Tensor>();
|
||||
|
||||
std::default_random_engine& generator = sampling_state->generator;
|
||||
|
||||
int64_t sampled_idx_dims[] = {static_cast<int64_t>(parameters->batch_size), 1};
|
||||
TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2);
|
||||
|
||||
gsl::span<int64_t>& next_token_idx = greedy_state->next_tokens_cpu;
|
||||
|
||||
OrtValue sampled_idx_ov;
|
||||
Tensor::InitOrtValue(DataTypeImpl::GetType<int64_t>(),
|
||||
sampled_idx_shape,
|
||||
next_token_idx.data(),
|
||||
allocator->Info(),
|
||||
sampled_idx_ov);
|
||||
Tensor* sampled_idx = sampled_idx_ov.GetMutable<Tensor>();
|
||||
|
||||
// Copy the allocator because MultinomialComputeShared() uses move(allocator)
|
||||
AllocatorPtr allocatortemp = allocator;
|
||||
ORT_RETURN_IF_ERROR(MultinomialComputeShared<int64_t>(allocatortemp,
|
||||
input,
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size,
|
||||
1,
|
||||
generator,
|
||||
*sampled_idx));
|
||||
// TODO: update presense_mask()
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("sampled_idx", *sampled_idx);
|
||||
#endif
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace SamplingCpuHelper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
|
||||
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", 0));
|
||||
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
|
||||
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
|
||||
decoder_start_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("decoder_start_token_id", -1));
|
||||
no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
|
||||
temperature = info.GetAttrOrDefault<float>("temperature", 1.0f);
|
||||
top_p = info.GetAttrOrDefault<float>("top_p", 0.0f);
|
||||
filter_value = info.GetAttrOrDefault<float>("filter_value", -std::numeric_limits<float>::infinity());
|
||||
min_tokens_to_keep = static_cast<int>(info.GetAttrOrDefault<int64_t>("min_tokens_to_keep", 0));
|
||||
presence_penalty = info.GetAttrOrDefault<float>("presence_penalty", 0.0f);
|
||||
custom_sampling = static_cast<int>(info.GetAttrOrDefault<int64_t>("custom", 0));
|
||||
vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
|
||||
}
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,19 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace transformers {
|
||||
|
||||
// Parameter set of the Sampling op. It adds no fields of its own; everything is
// inherited from GreedySearchParameters.
struct SamplingParameters : public GreedySearchParameters {
|
||||
// Fills the parameter fields from the node attributes exposed by `info`.
void ParseFromAttributes(const OpKernelInfo& info);
|
||||
};
|
||||
|
||||
} // namespace transformers
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -75,6 +75,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
|
||||
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
|
||||
|
|
@ -189,6 +190,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,
|
||||
|
|
|
|||
|
|
@ -1,286 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "cub/util_type.cuh"
|
||||
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
// Writes the initial beam scores: an element whose position modulo num_beams is 0
// gets 0.0f, every other element gets -1e9. One thread handles one element.
__global__ void InitKernel(float* beam_scores,
                           int num_beams,
                           int total_elements) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;  // surplus thread of the last block
  }

  const int beam = tid % num_beams;
  if (beam > 0) {
    beam_scores[tid] = static_cast<float>(-1e9);
  } else {
    beam_scores[tid] = 0.0f;
  }
}
|
||||
|
||||
// Launches InitKernel on `stream` with one thread per (batch, beam) element, using
// blocks of 256 threads.
void LaunchInitKernel(
    float* beam_scores,
    int batch_size,
    int num_beams,
    cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_size * num_beams;
  // Round up so that a partially filled last block is still launched.
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  InitKernel<<<block_count, kThreadsPerBlock, 0, stream>>>(beam_scores, num_beams, element_count);
}
|
||||
|
||||
// Splits each flat value v of next_token_indices into v / vocab_size (stored in
// next_indices) and v % vocab_size (stored in next_tokens). One thread per element.
__global__ void NextTokenKernel(const int64_t* next_token_indices,
                                int32_t* next_indices,
                                int32_t* next_tokens,
                                int vocab_size,
                                int total_elements) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;
  }

  const int64_t flat = next_token_indices[tid];
  next_indices[tid] = flat / vocab_size;
  next_tokens[tid] = flat % vocab_size;
}
|
||||
|
||||
// Launches NextTokenKernel on `stream` over batch_size * top_k elements, using
// blocks of 256 threads.
void LaunchNextTokenKernel(const int64_t* next_token_indices,
                           int32_t* next_indices,
                           int32_t* next_tokens,
                           int batch_size,
                           int top_k,
                           int vocab_size,
                           cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_size * top_k;
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  NextTokenKernel<<<block_count, kThreadsPerBlock, 0, stream>>>(
      next_token_indices, next_indices, next_tokens, vocab_size, element_count);
}
|
||||
|
||||
// Applies the logits processors to one score per thread. `index` is decomposed into
// a row (index / padded_vocab_size) and a token id (index % padded_vocab_size).
// Order: padding region, repetition penalty, no-repeat n-gram, vocab mask,
// prefix vocab mask, min-length demotion.
template <typename T>
__global__ void LogitsProcessKernel(
    T* next_token_scores,
    const int* vocab_mask,
    const int* prefix_vocab_mask,
    int num_beams,
    int vocab_size,
    int padded_vocab_size,
    int total_elements,
    int demote_token_id,
    int32_t* sequences,
    int max_sequence_length,
    int current_sequence_length,
    float repetition_penalty,
    int no_repeat_ngram_size) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;
  }

  const int row = tid / padded_vocab_size;
  const int token = tid % padded_vocab_size;

  if (token >= vocab_size) {
    // Set any value within the padding region to the lowest value so that it isn't picked
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
    return;
  }

  // RepetitionPenaltyLogitsProcessor
  if (repetition_penalty != 1.0f) {
    int32_t* history = sequences + row * max_sequence_length;
    bool seen = false;
    for (int pos = 0; pos < current_sequence_length && !seen; pos++) {
      seen = (history[pos] == token);
    }
    if (seen) {
      const float raw = (float)next_token_scores[tid];
      next_token_scores[tid] = (T)(raw < 0 ? raw * repetition_penalty : raw / repetition_penalty);
    }
  }

  // NoRepeatNGramLogitsProcessor
  if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) {
    int32_t* history = sequences + row * max_sequence_length;
    const int prefix_len = no_repeat_ngram_size - 1;
    bool banned = false;
    for (int end = prefix_len; end < current_sequence_length && !banned; end++) {
      if (history[end] != token) {
        continue;
      }
      // last token of n-gram matched; match the remaining N-1 tokens
      banned = true;
      for (int k = 0; k < prefix_len; k++) {
        if (history[end - k - 1] != history[current_sequence_length - 1 - k]) {
          banned = false;
          break;
        }
      }
    }

    if (banned) {
      next_token_scores[tid] = cub::FpLimits<T>::Lowest();
      return;
    }
  }

  // VocabMaskLogitsProcessor
  if (vocab_mask != nullptr && vocab_mask[token] == 0) {
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
    return;
  }

  // PrefixVocabMaskLogitsProcessor
  const int batch_id = row / num_beams;
  if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + token] == 0) {
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
    return;
  }

  // MinLengthLogitsProcessor
  if (token == demote_token_id) {
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
  }
}
|
||||
|
||||
// Launches LogitsProcessKernel<T> on `stream` with one thread per element of the
// batch_size * num_beams * padded_vocab_size score buffer, in blocks of 256 threads.
template <typename T>
void LaunchLogitsProcessKernel(
    T* next_token_scores,
    const int* vocab_mask,
    const int* prefix_vocab_mask,
    int batch_size,
    int num_beams,
    int vocab_size,
    int padded_vocab_size,
    int demote_token_id,
    int32_t* sequences,
    int max_sequence_length,
    int current_sequence_length,
    float repetition_penalty,
    int no_repeat_ngram_size,
    cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_size * num_beams * padded_vocab_size;
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;

  LogitsProcessKernel<T><<<block_count, kThreadsPerBlock, 0, stream>>>(
      next_token_scores,
      vocab_mask,
      prefix_vocab_mask,
      num_beams,
      vocab_size,
      padded_vocab_size,
      element_count,
      demote_token_id,
      sequences,
      max_sequence_length,
      current_sequence_length,
      repetition_penalty,
      no_repeat_ngram_size);
}
|
||||
|
||||
// Instantiation
|
||||
template void LaunchLogitsProcessKernel(
|
||||
float* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void LaunchLogitsProcessKernel(
|
||||
half* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Adds cum_log_probs[index / vocab_size] to log_probs[index]; one thread per element.
__global__ void AddProbsKernel(float* log_probs,
                               float* cum_log_probs,
                               const int vocab_size,
                               const int total_elements) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;
  }

  const int row = tid / vocab_size;
  log_probs[tid] += cum_log_probs[row];
}
|
||||
|
||||
// Launches AddProbsKernel on `stream` over batch_size * num_beams * vocab_size
// elements, using blocks of 256 threads.
template <typename T>
void LaunchAddProbsKernel(T* log_probs,
                          T* cum_log_probs,
                          const int batch_size,
                          const int num_beams,
                          const int vocab_size,
                          cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_size * num_beams * vocab_size;
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  AddProbsKernel<<<block_count, kThreadsPerBlock, 0, stream>>>(
      log_probs, cum_log_probs, vocab_size, element_count);
}
|
||||
|
||||
template void LaunchAddProbsKernel(
|
||||
float* log_probs,
|
||||
float* cum_log_probs,
|
||||
const int batch_size,
|
||||
const int num_beams,
|
||||
const int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Builds the attention mask for the next step: every row of length current_length
// receives the matching row of the old mask (length current_length - 1) followed by
// a 1. When next_positions is not null, the first batch_beam_size threads also
// increment their entry of next_positions.
template <typename T>
__global__ void UpdateGptInputsKernel(const T* old_mask_data,
                                      T* mask_data,
                                      int32_t* next_positions,
                                      int batch_beam_size,
                                      int current_length) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= batch_beam_size * current_length) {
    return;
  }

  // Update attention mask.
  const int row = tid / current_length;
  const int col = tid % current_length;
  const int old_length = current_length - 1;
  if (col < old_length) {
    mask_data[tid] = old_mask_data[row * old_length + col];
  } else {
    mask_data[tid] = static_cast<T>(1);
  }

  // Update sequence length (or next positions).
  if (next_positions != nullptr && tid < batch_beam_size) {
    next_positions[tid]++;
  }
}
|
||||
|
||||
// Launches UpdateGptInputsKernel<int32_t> on `stream` over
// batch_beam_size * current_length elements, using blocks of 256 threads.
// current_length must be positive (asserted).
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
                           int32_t* mask_data,
                           int32_t* next_positions,
                           int batch_beam_size,
                           int current_length,
                           cudaStream_t stream) {
  assert(current_length > 0);
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_beam_size * current_length;
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  UpdateGptInputsKernel<int32_t><<<block_count, kThreadsPerBlock, 0, stream>>>(
      old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -1,61 +0,0 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void LaunchInitKernel(
|
||||
float* beam_scores,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void LaunchAddProbsKernel(T* log_probs,
|
||||
T* cum_log_probs,
|
||||
const int batch_size,
|
||||
const int num_beams,
|
||||
const int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void LaunchLogitsProcessKernel(
|
||||
T* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
void LaunchNextTokenKernel(const int64_t* next_token_indices,
|
||||
int32_t* next_indices,
|
||||
int32_t* next_tokens,
|
||||
int batch_size,
|
||||
int top_k,
|
||||
int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,758 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "cub/util_type.cuh"
|
||||
#include <cub/cub.cuh>
|
||||
#include <cub/device/device_segmented_radix_sort.cuh>
|
||||
#include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
|
||||
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
// Writes the initial beam scores: an element whose position modulo num_beams is 0
// gets 0.0f, every other element gets -1e9. One thread handles one element.
__global__ void InitKernel(float* beam_scores,
                           int num_beams,
                           int total_elements) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;  // surplus thread of the last block
  }

  const int beam = tid % num_beams;
  if (beam > 0) {
    beam_scores[tid] = static_cast<float>(-1e9);
  } else {
    beam_scores[tid] = 0.0f;
  }
}
|
||||
|
||||
// Launches InitKernel on `stream` with one thread per (batch, beam) element, using
// blocks of 256 threads.
void LaunchInitKernel(
    float* beam_scores,
    int batch_size,
    int num_beams,
    cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_size * num_beams;
  // Round up so that a partially filled last block is still launched.
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  InitKernel<<<block_count, kThreadsPerBlock, 0, stream>>>(beam_scores, num_beams, element_count);
}
|
||||
|
||||
// Splits each flat value v of next_token_indices into v / vocab_size (stored in
// next_indices) and v % vocab_size (stored in next_tokens). One thread per element.
__global__ void NextTokenKernel(const int64_t* next_token_indices,
                                int32_t* next_indices,
                                int32_t* next_tokens,
                                int vocab_size,
                                int total_elements) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;
  }

  const int64_t flat = next_token_indices[tid];
  next_indices[tid] = flat / vocab_size;
  next_tokens[tid] = flat % vocab_size;
}
|
||||
|
||||
// Launches NextTokenKernel on `stream` over batch_size * top_k elements, using
// blocks of 256 threads.
void LaunchNextTokenKernel(const int64_t* next_token_indices,
                           int32_t* next_indices,
                           int32_t* next_tokens,
                           int batch_size,
                           int top_k,
                           int vocab_size,
                           cudaStream_t stream) {
  constexpr int kThreadsPerBlock = 256;
  const int element_count = batch_size * top_k;
  const int block_count = (element_count + kThreadsPerBlock - 1) / kThreadsPerBlock;
  NextTokenKernel<<<block_count, kThreadsPerBlock, 0, stream>>>(
      next_token_indices, next_indices, next_tokens, vocab_size, element_count);
}
|
||||
|
||||
// Applies the logits processors to one score per thread. `index` is decomposed into
// a row (index / padded_vocab_size) and a token id (index % padded_vocab_size).
// Order: padding region, repetition penalty, no-repeat n-gram, vocab mask,
// prefix vocab mask, min-length demotion, presence penalty, temperature.
// The padding, n-gram and mask processors return right after writing the lowest
// value; the min-length demotion does not, so presence penalty and temperature are
// still applied to a demoted token.
template <typename T>
__global__ void LogitsProcessKernel(
    T* next_token_scores,
    const int* vocab_mask,
    const int* prefix_vocab_mask,
    const int* presence_mask,
    float presence_penalty,
    float temperature,
    int num_beams,
    int vocab_size,
    int padded_vocab_size,
    int total_elements,
    int demote_token_id,
    int32_t* sequences,
    int max_sequence_length,
    int current_sequence_length,
    float repetition_penalty,
    int no_repeat_ngram_size) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid >= total_elements) {
    return;
  }

  const int row = tid / padded_vocab_size;
  const int token = tid % padded_vocab_size;

  if (token >= vocab_size) {
    // Set any value within the padding region to the lowest value so that it isn't picked
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
    return;
  }

  // RepetitionPenaltyLogitsProcessor
  if (repetition_penalty != 1.0f) {
    int32_t* history = sequences + row * max_sequence_length;
    bool seen = false;
    for (int pos = 0; pos < current_sequence_length && !seen; pos++) {
      seen = (history[pos] == token);
    }
    if (seen) {
      const float raw = (float)next_token_scores[tid];
      next_token_scores[tid] = (T)(raw < 0 ? raw * repetition_penalty : raw / repetition_penalty);
    }
  }

  // NoRepeatNGramLogitsProcessor
  if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) {
    int32_t* history = sequences + row * max_sequence_length;
    const int prefix_len = no_repeat_ngram_size - 1;
    bool banned = false;
    for (int end = prefix_len; end < current_sequence_length && !banned; end++) {
      if (history[end] != token) {
        continue;
      }
      // last token of n-gram matched; match the remaining N-1 tokens
      banned = true;
      for (int k = 0; k < prefix_len; k++) {
        if (history[end - k - 1] != history[current_sequence_length - 1 - k]) {
          banned = false;
          break;
        }
      }
    }

    if (banned) {
      next_token_scores[tid] = cub::FpLimits<T>::Lowest();
      return;
    }
  }

  // VocabMaskLogitsProcessor
  if (vocab_mask != nullptr && vocab_mask[token] == 0) {
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
    return;
  }

  // PrefixVocabMaskLogitsProcessor
  const int batch_id = row / num_beams;
  if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + token] == 0) {
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
    return;
  }

  // MinLengthLogitsProcessor
  if (token == demote_token_id) {
    next_token_scores[tid] = cub::FpLimits<T>::Lowest();
  }

  // PresencePenaltyLogitsProcessor
  if (presence_mask != nullptr && presence_mask[tid] == 1) {
    const float penalized = (float)next_token_scores[tid] - presence_penalty;
    next_token_scores[tid] = (T)penalized;
  }

  // TemperatureLogitsProcessor
  if (temperature != 1.0f) {
    const float raw = (float)(next_token_scores[tid]);
    next_token_scores[tid] = (T)(raw / temperature);
  }
}
|
||||
|
||||
template <typename T>
|
||||
void LaunchLogitsProcessKernel(
|
||||
T* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int* presence_mask,
|
||||
float presence_penalty,
|
||||
float temperature,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream) {
|
||||
int total_elements = batch_size * num_beams * vocab_size;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(
|
||||
next_token_scores,
|
||||
vocab_mask,
|
||||
prefix_vocab_mask,
|
||||
presence_mask,
|
||||
presence_penalty,
|
||||
temperature,
|
||||
num_beams,
|
||||
vocab_size,
|
||||
padded_vocab_size,
|
||||
total_elements,
|
||||
demote_token_id,
|
||||
sequences,
|
||||
max_sequence_length,
|
||||
current_sequence_length,
|
||||
repetition_penalty,
|
||||
no_repeat_ngram_size);
|
||||
}
|
||||
|
||||
// Instantiation
|
||||
template void LaunchLogitsProcessKernel(
|
||||
float* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int* presence_mask,
|
||||
float presence_penalty,
|
||||
float temperature,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template void LaunchLogitsProcessKernel(
|
||||
half* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int* presence_mask,
|
||||
float presence_penalty,
|
||||
float temperature,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
// Adds a per-row value to every entry of that row, in place.
//
// `log_probs` is read as consecutive rows of `vocab_size` entries and
// `cum_log_probs` supplies one value per row. One thread handles one entry;
// threads whose flat index is >= total_elements do nothing.
__global__ void AddProbsKernel(float* log_probs,
                               float* cum_log_probs,
                               const int vocab_size,
                               const int total_elements) {
  const int flat_idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (flat_idx >= total_elements) {
    return;
  }

  const int row = flat_idx / vocab_size;
  log_probs[flat_idx] += cum_log_probs[row];
}
|
||||
|
||||
template <typename T>
|
||||
void LaunchAddProbsKernel(T* log_probs,
|
||||
T* cum_log_probs,
|
||||
const int batch_size,
|
||||
const int num_beams,
|
||||
const int vocab_size,
|
||||
cudaStream_t stream) {
|
||||
int total_elements = batch_size * num_beams * vocab_size;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
AddProbsKernel<<<gridSize, blockSize, 0, stream>>>(log_probs, cum_log_probs, vocab_size, total_elements);
|
||||
}
|
||||
|
||||
template void LaunchAddProbsKernel(
|
||||
float* log_probs,
|
||||
float* cum_log_probs,
|
||||
const int batch_size,
|
||||
const int num_beams,
|
||||
const int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
__global__ void UpdateGptInputsKernel(const T* old_mask_data,
|
||||
T* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (index < batch_beam_size * current_length) {
|
||||
// Update attention mask.
|
||||
int i = index / current_length;
|
||||
int j = index % current_length;
|
||||
mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast<T>(1);
|
||||
|
||||
if (next_positions != nullptr) {
|
||||
// Update sequence length (or next positions).
|
||||
if (index < batch_beam_size) {
|
||||
next_positions[index]++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream) {
|
||||
assert(current_length > 0);
|
||||
int total_elements = batch_beam_size * current_length;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
UpdateGptInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(
|
||||
old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void GetTempStorageSize(const T *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes) {
|
||||
if (is_descending) {
|
||||
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
|
||||
temp_storage_bytes,
|
||||
d_keys_in,
|
||||
(T*)nullptr,
|
||||
d_values_in,
|
||||
(int*)nullptr,
|
||||
num_items,
|
||||
num_segments,
|
||||
d_offsets,
|
||||
d_offsets + 1,
|
||||
0,
|
||||
sizeof(T) * 8,
|
||||
stream);
|
||||
} else {
|
||||
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
|
||||
temp_storage_bytes,
|
||||
d_keys_in,
|
||||
(T*)nullptr,
|
||||
d_values_in,
|
||||
(int*)nullptr,
|
||||
num_items,
|
||||
num_segments,
|
||||
d_offsets,
|
||||
d_offsets + 1,
|
||||
0,
|
||||
sizeof(T) * 8,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void GetTempStorageSize(
|
||||
const float *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
|
||||
template void GetTempStorageSize(
|
||||
const half *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
|
||||
// TODO: merge to one kernel
|
||||
// Fills the two helper arrays consumed by the segmented sort:
//   d_values_in[i] = i % vocab_size   for i in [0, batch_size * vocab_size)
//   d_offsets[b]   = b * vocab_size   for b in [0, batch_size]
// One thread writes at most one entry of each array.
__global__ void SetupParamsKernel(int* d_values_in,
                                  int* d_offsets,
                                  int batch_size,
                                  int vocab_size) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;

  const int value_count = batch_size * vocab_size;
  if (tid < value_count) {
    d_values_in[tid] = tid % vocab_size;
  }

  // batch_size + 1 offsets: one begin offset per segment plus the closing one.
  if (tid <= batch_size) {
    d_offsets[tid] = tid * vocab_size;
  }
}
|
||||
|
||||
void LaunchSetupParamsKernel(int* d_values_in,
|
||||
int* d_offsets,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
cudaStream_t stream) {
|
||||
int total_elements = batch_size * vocab_size;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
SetupParamsKernel<<<gridSize, blockSize, 0, stream>>>(d_values_in,
|
||||
d_offsets,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LaunchSortPairs(void *d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const T *d_keys_in,
|
||||
T *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending) {
|
||||
if (is_descending) {
|
||||
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
|
||||
temp_storage_bytes,
|
||||
d_keys_in,
|
||||
d_keys_out,
|
||||
d_values_in,
|
||||
d_values_out,
|
||||
num_items,
|
||||
num_segments,
|
||||
d_offsets,
|
||||
d_offsets + 1,
|
||||
0,
|
||||
sizeof(T) * 8,
|
||||
stream);
|
||||
} else {
|
||||
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
|
||||
temp_storage_bytes,
|
||||
d_keys_in,
|
||||
d_keys_out,
|
||||
d_values_in,
|
||||
d_values_out,
|
||||
num_items,
|
||||
num_segments,
|
||||
d_offsets,
|
||||
d_offsets + 1,
|
||||
0,
|
||||
sizeof(T) * 8,
|
||||
stream);
|
||||
}
|
||||
}
|
||||
|
||||
template void LaunchSortPairs(void *d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const float *d_keys_in,
|
||||
float *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
template void LaunchSortPairs(void *d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const half *d_keys_in,
|
||||
half *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
template <typename T>
|
||||
__global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
T* d_logits_in_out,
|
||||
float top_p_threshold,
|
||||
float filter_value,
|
||||
int batch_size,
|
||||
int vocab_size) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (index >= batch_size * vocab_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int vocab_idx = index % vocab_size;
|
||||
int batch_id = index / vocab_size;
|
||||
int start_index = batch_id * vocab_size;
|
||||
|
||||
int count = vocab_idx;
|
||||
float sum = 0.0f;
|
||||
while (count >= 0) {
|
||||
sum += d_sorted_logits_in[start_index];
|
||||
++start_index;
|
||||
--count;
|
||||
}
|
||||
|
||||
if (sum > top_p_threshold) {
|
||||
// Shift the indices to the right by one according to the custom implementation.
|
||||
int shifted_index = index + 1;
|
||||
if (shifted_index % vocab_size != 0) {
|
||||
int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index];
|
||||
d_logits_in_out[original_index] = (T)filter_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void FilterLogitsKernel(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
T* d_logits_in_out,
|
||||
float top_p_threshold,
|
||||
float filter_value,
|
||||
int min_tokens_to_keep,
|
||||
int batch_size,
|
||||
int vocab_size) {
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
|
||||
if (index >= batch_size * vocab_size) {
|
||||
return;
|
||||
}
|
||||
|
||||
int vocab_idx = index % vocab_size;
|
||||
int batch_id = index / vocab_size;
|
||||
int start_index = batch_id * vocab_size;
|
||||
|
||||
int count = vocab_idx;
|
||||
float sum = 0.0f;
|
||||
// TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities.
|
||||
while (count >= 0) {
|
||||
sum += d_sorted_logits_in[start_index];
|
||||
++start_index;
|
||||
--count;
|
||||
}
|
||||
|
||||
if (sum <= top_p_threshold) {
|
||||
if (index % vocab_size + min_tokens_to_keep < vocab_size) {
|
||||
int original_index = batch_id * vocab_size + d_sorted_indices[index];
|
||||
d_logits_in_out[original_index] = (T)filter_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
T* d_logits_in_out,
|
||||
float top_p,
|
||||
float filter_value,
|
||||
int min_tokens_to_keep,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
cudaStream_t stream,
|
||||
bool is_descending) {
|
||||
int total_elements = batch_size * vocab_size;
|
||||
constexpr int blockSize = 256;
|
||||
const int gridSize = (total_elements + blockSize - 1) / blockSize;
|
||||
if (is_descending) {
|
||||
FilterLogitsKernelCustom<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
|
||||
d_sorted_indices,
|
||||
d_logits_in_out,
|
||||
top_p,
|
||||
filter_value,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
} else {
|
||||
FilterLogitsKernel<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
|
||||
d_sorted_indices,
|
||||
d_logits_in_out,
|
||||
1 - top_p,
|
||||
filter_value,
|
||||
min_tokens_to_keep,
|
||||
batch_size,
|
||||
vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
float* d_logits_in_out,
|
||||
float top_p,
|
||||
float filter_value,
|
||||
int min_tokens_to_keep,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
half* d_logits_in_out,
|
||||
float top_p,
|
||||
float filter_value,
|
||||
int min_tokens_to_keep,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
|
||||
// Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu
|
||||
template <typename scalar_t, typename accscalar_t>
|
||||
__global__ void sampleMultinomialOnce(int64_t* dest,
|
||||
int distributions,
|
||||
int categories,
|
||||
scalar_t* sampled,
|
||||
scalar_t* dist,
|
||||
int stride_dist, // dist->stride(0)
|
||||
int stride_categories, // dist->stride(1)
|
||||
int* d_presence_mask) {
|
||||
extern __shared__ unsigned char my_smem[];
|
||||
__shared__ bool found;
|
||||
__shared__ unsigned foundPos;
|
||||
accscalar_t *smem = reinterpret_cast<accscalar_t *>(my_smem);
|
||||
accscalar_t accZero = static_cast<accscalar_t>(0);
|
||||
scalar_t zero = static_cast<scalar_t>(0);
|
||||
for (int curDist = blockIdx.x;
|
||||
curDist < distributions; curDist += gridDim.x) {
|
||||
|
||||
// Assume sum = 1 in Top P sampling as the input is softmaxed.
|
||||
accscalar_t sum = 1;
|
||||
|
||||
// Broadcast sum and sample value
|
||||
if (threadIdx.x == 0) {
|
||||
// Make sure the sum of our distribution didn't overflow
|
||||
// CUDA_KERNEL_ASSERT(!_isinf(val));
|
||||
// CUDA_KERNEL_ASSERT(sum > accZero);
|
||||
foundPos = 0;
|
||||
smem[0] = sum;
|
||||
smem[1] = sampled[curDist];
|
||||
}
|
||||
__syncthreads();
|
||||
sum = smem[0];
|
||||
scalar_t sample = static_cast<scalar_t>(smem[1]);
|
||||
__syncthreads();
|
||||
if (sum == accZero) {
|
||||
// Choose the first element
|
||||
if (threadIdx.x == 0) {
|
||||
dest[curDist] = 0;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
int chunks = (categories + (int)blockDim.x - 1) / blockDim.x;
|
||||
accscalar_t prevHighProb = accZero;
|
||||
found = false;
|
||||
for (int chunk = 0; chunk < chunks && !found; ++chunk) {
|
||||
// All threads in bounds load a value
|
||||
int cat = chunk * blockDim.x + threadIdx.x;
|
||||
accscalar_t dist_val = cat < categories ?
|
||||
static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum :
|
||||
accZero;
|
||||
smem[threadIdx.x] = dist_val;
|
||||
__syncthreads();
|
||||
// Perform an inclusive prefix sum of the shared memory contents
|
||||
for (int offset = 1; offset < blockDim.x; offset *= 2) {
|
||||
accscalar_t val = accZero;
|
||||
if (threadIdx.x >= offset) {
|
||||
val = smem[threadIdx.x - offset] + smem[threadIdx.x];
|
||||
}
|
||||
__syncthreads();
|
||||
if (threadIdx.x >= offset) {
|
||||
smem[threadIdx.x] = val;
|
||||
}
|
||||
__syncthreads();
|
||||
}
|
||||
// Each thread will check to see if the sample falls in its bucket
|
||||
scalar_t curBucket =
|
||||
static_cast<scalar_t>(smem[threadIdx.x] + prevHighProb);
|
||||
scalar_t prevBucket = static_cast<scalar_t>(
|
||||
threadIdx.x == 0 ? prevHighProb
|
||||
: smem[threadIdx.x - 1] + prevHighProb);
|
||||
bool inBucket =
|
||||
(cat < categories) &&
|
||||
(!(sample >= curBucket) &&
|
||||
(sample >= prevBucket) &&
|
||||
(dist_val > zero));
|
||||
if (inBucket) {
|
||||
// We're done; we have the sample
|
||||
// Torch indices are 1-based
|
||||
atomicMax(&foundPos, cat);
|
||||
found = true;
|
||||
}
|
||||
// Store the previous scan's high value for future use
|
||||
prevHighProb = prevHighProb + smem[blockDim.x - 1];
|
||||
__syncthreads();
|
||||
}
|
||||
if (threadIdx.x == 0) {
|
||||
if (found) {
|
||||
dest[curDist] = foundPos;
|
||||
} else {
|
||||
// This should address a rare bug where we don't select a valid index. This likely occurs when
|
||||
// due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but
|
||||
// and our uniform sample is greater than this value. In this case we likely have unitialized memory
|
||||
// in dest[curDist]. So basically we will loop through the distribution and pick the largest index
|
||||
// where the distribution is non-zero. This is obviously terribly inefficient, but due to the
|
||||
// rarity in which this occurs, this should not be an issue.
|
||||
for (int cat = categories - 1; cat >= 0; --cat) {
|
||||
if (dist[curDist * stride_dist + cat * stride_categories] > zero) {
|
||||
dest[curDist] = cat;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update presence mask
|
||||
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
||||
if (index >= distributions * categories) {
|
||||
return;
|
||||
}
|
||||
int dist_idx = index / categories;
|
||||
int cat_idx = index % categories;
|
||||
if (dest[dist_idx] == cat_idx) {
|
||||
d_presence_mask[index] = 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Only support n_sample = 1
|
||||
void TorchMultinomialKernelLauncher(float* d_input,
|
||||
float* d_sampled,
|
||||
int64_t* d_output,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
int* d_presence_mask,
|
||||
cudaStream_t stream)
|
||||
{
|
||||
// Store the props in class variables
|
||||
int device;
|
||||
cudaGetDevice(&device);
|
||||
cudaDeviceProp props;
|
||||
cudaGetDeviceProperties(&props, device);
|
||||
|
||||
int numSM = props.multiProcessorCount;
|
||||
int maxThreads = props.maxThreadsPerBlock;
|
||||
int warp_size = 32; //at::cuda::warp_size();
|
||||
int requiredWarps = (vocab_size + warp_size - 1) / warp_size;
|
||||
int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
|
||||
int requiredShared = requiredThreads * sizeof(float);
|
||||
|
||||
dim3 block(requiredThreads);
|
||||
dim3 grid(std::min(batch_size, numSM * 4));
|
||||
|
||||
sampleMultinomialOnce<float, float>
|
||||
<<<grid, block, requiredShared, stream>>>(d_output,
|
||||
batch_size,
|
||||
vocab_size,
|
||||
d_sampled,
|
||||
d_input,
|
||||
vocab_size,
|
||||
1,
|
||||
d_presence_mask);
|
||||
|
||||
}
|
||||
|
||||
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
114
onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
Normal file
114
onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <stdint.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <curand_kernel.h>
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void LaunchInitKernel(
|
||||
float* beam_scores,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void LaunchAddProbsKernel(T* log_probs,
|
||||
T* cum_log_probs,
|
||||
const int batch_size,
|
||||
const int num_beams,
|
||||
const int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void LaunchLogitsProcessKernel(
|
||||
T* next_token_scores,
|
||||
const int* vocab_mask,
|
||||
const int* prefix_vocab_mask,
|
||||
int* presence_mask,
|
||||
float presence_penalty,
|
||||
float temperature,
|
||||
int batch_size,
|
||||
int num_beams,
|
||||
int vocab_size,
|
||||
int padded_vocab_size,
|
||||
int demote_token_id,
|
||||
int32_t* sequences,
|
||||
int max_sequence_length,
|
||||
int current_sequence_length,
|
||||
float repetition_penalty,
|
||||
int no_repeat_ngram_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
void LaunchNextTokenKernel(const int64_t* next_token_indices,
|
||||
int32_t* next_indices,
|
||||
int32_t* next_tokens,
|
||||
int batch_size,
|
||||
int top_k,
|
||||
int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
|
||||
int32_t* mask_data,
|
||||
int32_t* next_positions,
|
||||
int batch_beam_size,
|
||||
int current_length,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void GetTempStorageSize(const T *d_keys_in,
|
||||
const int* d_values_in,
|
||||
int* d_offsets,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
cudaStream_t stream,
|
||||
bool is_descending,
|
||||
size_t& temp_storage_bytes);
|
||||
|
||||
void LaunchSetupParamsKernel(int* d_values_in,
|
||||
int* d_offsets,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
cudaStream_t stream);
|
||||
|
||||
template <typename T>
|
||||
void LaunchSortPairs(void *d_temp_storage,
|
||||
size_t temp_storage_bytes,
|
||||
const T *d_keys_in,
|
||||
T *d_keys_out,
|
||||
const int *d_values_in,
|
||||
int *d_values_out,
|
||||
int num_items,
|
||||
int num_segments,
|
||||
int *d_offsets,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
template <typename T>
|
||||
void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
|
||||
const int* d_sorted_indices,
|
||||
T* d_logits_in_out,
|
||||
float top_p,
|
||||
float filter_value,
|
||||
int min_tokens_to_keep,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
cudaStream_t stream,
|
||||
bool is_descending);
|
||||
|
||||
void TorchMultinomialKernelLauncher(float* d_input,
|
||||
float* d_sampled,
|
||||
int64_t* d_output,
|
||||
int batch_size,
|
||||
int vocab_size,
|
||||
int* d_presence_mask,
|
||||
cudaStream_t stream);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -4,19 +4,25 @@
|
|||
#include <utility>
|
||||
#include <memory>
|
||||
#include "core/providers/shared_library/provider_api.h"
|
||||
#include "core/providers/cuda/cuda_kernel.h"
|
||||
#include "core/providers/cuda/math/topk_impl.h"
|
||||
#include "core/providers/cuda/math/softmax.h"
|
||||
#include "core/providers/cuda/shared_inc/accumulation_type.h"
|
||||
#include "core/framework/ort_value.h"
|
||||
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
|
||||
#include <cuda_runtime.h>
|
||||
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
|
||||
#include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
|
||||
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
|
||||
#include "contrib_ops/cuda/transformers/beam_search_topk.h"
|
||||
#include "core/providers/cuda/nvtx_profile.h"
|
||||
#include "core/providers/cuda/nvtx_profile_context.h"
|
||||
#include "sampling_cuda_helper.h"
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace concurrency {
|
||||
|
|
@ -128,7 +134,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
|
|||
cudaStream_t stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
|
||||
auto pinned_buffer = IAllocator::MakeUniquePtr<void>(pinned_allocator, total_bytes);
|
||||
char* pinned_data = static_cast<char*>(pinned_buffer.get());
|
||||
|
||||
// Copy tensors to one pinned memory buffer (so that we only need copy to GPU once)
|
||||
char* destination = pinned_data;
|
||||
for (auto& input : inputs) {
|
||||
|
|
@ -145,16 +150,13 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
|
|||
"AddToFeeds: An implementation for the input type ",
|
||||
dataType, " is not supported yet");
|
||||
}
|
||||
|
||||
// Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs.
|
||||
destination += bytes;
|
||||
}
|
||||
}
|
||||
|
||||
if (!buffer) {
|
||||
buffer = provider->GetScratchBuffer<char>(total_bytes, ort_stream, WaitCudaNotificationOnDevice);
|
||||
}
|
||||
|
||||
char* gpu_data = buffer.get();
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream));
|
||||
|
||||
|
|
@ -164,7 +166,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
|
|||
CUDA_RETURN_IF_ERROR(cudaEventCreate(&isCopyDone));
|
||||
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream));
|
||||
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone));
|
||||
|
||||
// TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call.
|
||||
const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info();
|
||||
for (auto& input : inputs) {
|
||||
|
|
@ -253,7 +254,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
transformers::IBeamScorer* beam_scorer, // beam scorer
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
int step, // iteration counter
|
||||
Stream* ort_stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper) { // tensor dumper
|
||||
|
|
@ -363,6 +364,9 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
next_token_scores.data(),
|
||||
parameters->vocab_mask.data(),
|
||||
step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only.
|
||||
nullptr, // parameters->presence_mask.data(),
|
||||
parameters->presence_penalty,
|
||||
parameters->temperature,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
parameters->vocab_size,
|
||||
|
|
@ -504,11 +508,13 @@ template <typename T>
|
|||
Status GreedySearchProcessLogits(
|
||||
const OrtValue& logits, // logits output of subgraph
|
||||
transformers::IGreedySearchState<T>* greedy_state, // state
|
||||
transformers::ISamplingState<T>* sampling_state, // buffers
|
||||
transformers::ISequences* sequences, // sequences
|
||||
AllocatorPtr& allocator, // default allocator
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
bool do_sampling, // whether to do sampling
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper) { // tensor dumper
|
||||
|
|
@ -519,7 +525,6 @@ Status GreedySearchProcessLogits(
|
|||
#endif
|
||||
|
||||
ORT_UNUSED_PARAMETER(logits_processors);
|
||||
|
||||
#ifndef DEBUG_GENERATION
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
#endif
|
||||
|
|
@ -596,6 +601,13 @@ Status GreedySearchProcessLogits(
|
|||
cudaMemcpyHostToDevice, cuda_stream));
|
||||
}
|
||||
|
||||
// Copy parameters->presence_mask to sampling_state->presence_mask
|
||||
gsl::span<int>& presence_mask = sampling_state->d_presence_mask;
|
||||
if (step == 1 && parameters->presence_mask.data() != nullptr) {
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(),
|
||||
sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
|
||||
}
|
||||
|
||||
// TODO(hasesh): Can we avoid the const_cast by changing the interface of
|
||||
// GreedySearchProcessLogits() to take in a non-const OrtValue for logits
|
||||
// as this is the only place we will ever use the logits and it may be reasonable
|
||||
|
|
@ -605,11 +617,15 @@ Status GreedySearchProcessLogits(
|
|||
: reinterpret_cast<CudaT*>(next_token_scores.data()),
|
||||
parameters->vocab_mask.data(),
|
||||
step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only.
|
||||
parameters->presence_mask.data() ? presence_mask.data() : nullptr,
|
||||
parameters->presence_penalty,
|
||||
parameters->temperature,
|
||||
parameters->batch_size,
|
||||
parameters->num_beams,
|
||||
parameters->vocab_size,
|
||||
is_reuse_logits_buffer ? padded_vocab_size : parameters->vocab_size,
|
||||
(parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1,
|
||||
(parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length)
|
||||
? parameters->eos_token_id : -1,
|
||||
reinterpret_cast<int32_t*>(sequences_buffer.get()),
|
||||
parameters->max_length,
|
||||
current_sequence_length,
|
||||
|
|
@ -629,6 +645,19 @@ Status GreedySearchProcessLogits(
|
|||
// TODO(wy): support output_scores in greedy search
|
||||
ORT_UNUSED_PARAMETER(output_scores);
|
||||
|
||||
if (do_sampling) {
|
||||
ORT_RETURN_IF_ERROR(SamplingCudaHelper::Sample(allocator,
|
||||
cuda_stream,
|
||||
next_token_scores,
|
||||
sampling_state,
|
||||
greedy_state,
|
||||
parameters,
|
||||
step,
|
||||
dumper));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// next_tokens = torch.argmax(scores, dim=-1)
|
||||
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size),
|
||||
is_reuse_logits_buffer ? padded_vocab_size : vocab_size};
|
||||
|
|
@ -657,8 +686,8 @@ Status GreedySearchProcessLogits(
|
|||
*topk_scores, *topk_indices));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
dumper->Print("topk_indices", *(topk_indices.get()));
|
||||
dumper->Print("topk_scores", *(topk_scores.get()));
|
||||
dumper->Print("topk_indices", *(topk_indices.get()));
|
||||
#endif
|
||||
|
||||
const int64_t* next_token_indices = topk_indices->Data<int64_t>();
|
||||
|
|
@ -996,7 +1025,7 @@ template Status ProcessLogits<float>(
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
transformers::ILogitsProcessorList* logits_processors,
|
||||
transformers::IBeamScorer* beam_scorer,
|
||||
const transformers::IBeamSearchParameters* parameters,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
int step,
|
||||
Stream* ort_stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
|
@ -1004,11 +1033,13 @@ template Status ProcessLogits<float>(
|
|||
template Status GreedySearchProcessLogits<float>(
|
||||
const OrtValue& logits,
|
||||
transformers::IGreedySearchState<float>* greedy_state,
|
||||
transformers::ISamplingState<float>* sampling_state,
|
||||
transformers::ISequences* sequences,
|
||||
AllocatorPtr& allocator,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
transformers::ILogitsProcessorList* logits_processors,
|
||||
const transformers::IBeamSearchParameters* parameters,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
bool do_sampling,
|
||||
int step,
|
||||
Stream* ort_stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
|
@ -1063,7 +1094,7 @@ template Status ProcessLogits<MLFloat16>(
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
transformers::ILogitsProcessorList* logits_processors,
|
||||
transformers::IBeamScorer* beam_scorer,
|
||||
const transformers::IBeamSearchParameters* parameters,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
int step,
|
||||
Stream* ort_stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
|
@ -1071,11 +1102,13 @@ template Status ProcessLogits<MLFloat16>(
|
|||
template Status GreedySearchProcessLogits<MLFloat16>(
|
||||
const OrtValue& logits,
|
||||
transformers::IGreedySearchState<MLFloat16>* greedy_state,
|
||||
transformers::ISamplingState<MLFloat16>* sampling_state,
|
||||
transformers::ISequences* sequences,
|
||||
AllocatorPtr& allocator,
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool,
|
||||
transformers::ILogitsProcessorList* logits_processors,
|
||||
const transformers::IBeamSearchParameters* parameters,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
bool do_sampling,
|
||||
int step,
|
||||
Stream* ort_stream,
|
||||
const transformers::IConsoleDumper* dumper);
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
transformers::IBeamScorer* beam_scorer, // beam scorer
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper); // tensor dumper
|
||||
|
|
@ -63,11 +63,13 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
template <typename T>
|
||||
Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph
|
||||
transformers::IGreedySearchState<T>* greedy_state, // state
|
||||
transformers::ISamplingState<T>* sampling_state,// sampling buffers
|
||||
transformers::ISequences* sequences, // sequences
|
||||
AllocatorPtr& allocator, // default allocator
|
||||
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
|
||||
transformers::ILogitsProcessorList* logits_processors, // logits processors
|
||||
const transformers::IBeamSearchParameters* parameters, // parameters
|
||||
const transformers::IGenerationParameters* parameters, // parameters
|
||||
bool do_sampling, // whether to do sampling
|
||||
int step, // iteration counter
|
||||
Stream* stream, // cuda stream (for CUDA only)
|
||||
const transformers::IConsoleDumper* dumper); // tensor dumper
|
||||
|
|
|
|||
68
onnxruntime/contrib_ops/cuda/transformers/sampling.cc
Normal file
68
onnxruntime/contrib_ops/cuda/transformers/sampling.cc
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/cuda_execution_provider.h"
|
||||
#include "contrib_ops/cuda/transformers/sampling.h"
|
||||
#include "contrib_ops/cuda/transformers/generation_device_helper.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// Registers the CUDA kernel for the Sampling contrib op (kMSDomain, version 1).
// Inputs 0-3 and 6 and both outputs are declared as living in CPU memory;
// the remaining inputs keep the execution provider's default placement.
ONNX_OPERATOR_KERNEL_EX(
|
||||
Sampling,
|
||||
kMSDomain,
|
||||
1,
|
||||
kCudaExecutionProvider,
|
||||
(*KernelDefBuilder::Create())
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'input_ids' needs to be on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 1) // 'max_length' needs to be on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 2) // 'min_length' needs to be on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 3) // 'repetition_penalty' needs to be on CPU
|
||||
.InputMemoryType(OrtMemTypeCPUInput, 6) // 'attention_mask' (custom attention mask) needs to be on CPU
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
|
||||
.OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'filtered_logits' (debug) output on CPU
|
||||
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
|
||||
DataTypeImpl::GetTensorType<MLFloat16>()}),
|
||||
Sampling);
|
||||
|
||||
transformers::CudaTensorConsoleDumper g_cuda_dumper_sampling;
|
||||
|
||||
// Constructs the CUDA Sampling kernel on top of the CPU transformers::Sampling
// base class and installs the CUDA implementations of the device helpers
// (feeds setup, TopK, device copy, logits processing and greedy-state init for
// float and MLFloat16, plus the GPT feed-update helpers). Tensor dumping is
// routed through the file-level CUDA console dumper.
Sampling::Sampling(const OpKernelInfo& info)
|
||||
: onnxruntime::contrib::transformers::Sampling(info) {
|
||||
SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds,
|
||||
GenerationCudaDeviceHelper::TopK,
|
||||
GenerationCudaDeviceHelper::DeviceCopy<float>,
|
||||
GenerationCudaDeviceHelper::GreedySearchProcessLogits<float>,
|
||||
GenerationCudaDeviceHelper::GreedySearchProcessLogits<MLFloat16>,
|
||||
GenerationCudaDeviceHelper::InitGreedyState<float>,
|
||||
GenerationCudaDeviceHelper::InitGreedyState<MLFloat16>);
|
||||
|
||||
SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds<float>,
|
||||
GenerationCudaDeviceHelper::UpdateGptFeeds<MLFloat16>);
|
||||
|
||||
SetConsoleDumper(&g_cuda_dumper_sampling);
|
||||
}
|
||||
|
||||
// Runs the shared (base-class) Sampling implementation; the device-specific
// work happens through the helpers installed in the constructor.
Status Sampling::ComputeInternal(OpKernelContext* context) const {
|
||||
return onnxruntime::contrib::transformers::Sampling::Compute(context);
|
||||
}
|
||||
|
||||
// Kernel entry point. Delegates to ComputeInternal and, when that reports
// success, additionally checks cudaGetLastError() so that a pending CUDA
// error is turned into a failing Status instead of being silently dropped.
// A failure Status from ComputeInternal is returned unchanged.
Status Sampling::Compute(OpKernelContext* context) const {
|
||||
auto s = ComputeInternal(context);
|
||||
|
||||
if (s.IsOK()) {
|
||||
auto err = cudaGetLastError();
|
||||
if (err != cudaSuccess) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err));
|
||||
}
|
||||
}
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
26
onnxruntime/contrib_ops/cuda/transformers/sampling.h
Normal file
26
onnxruntime/contrib_ops/cuda/transformers/sampling.h
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "contrib_ops/cpu/transformers/sampling.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
class SessionState;
|
||||
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// CUDA flavour of the Sampling contrib op. It derives from the CPU
// transformers::Sampling kernel and overrides Compute; ComputeInternal is the
// private step that Compute wraps.
class Sampling final : public onnxruntime::contrib::transformers::Sampling {
|
||||
public:
|
||||
Sampling(const OpKernelInfo& info);
|
||||
|
||||
Status Compute(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
Status ComputeInternal(OpKernelContext* context) const;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
186
onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
Normal file
186
onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h
Normal file
|
|
@ -0,0 +1,186 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cpu/tensor/utils.h"
|
||||
#include "contrib_ops/cpu/transformers/generation_shared.h"
|
||||
#include "core/providers/cuda/math/softmax.h"
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
#include <iostream>
|
||||
#endif
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace SamplingCudaHelper {
|
||||
|
||||
// Performs one sampling step on the GPU and leaves the chosen token ids in
// greedy_state->next_tokens_cpu.
//
// Stages, in order:
//   1. (step == 1 only) query the sort scratch size, initialise d_index_in /
//      d_offset, and allocate the scratch buffer kept in sampling_state.
//   2. Sort next_token_scores into d_sorted_score / d_index_out.
//   3. Softmax the sorted scores into d_sorted_softmaxed_score.
//   4. Filter next_token_scores in place using top_p, filter_value and
//      min_tokens_to_keep.
//   5. Softmax the filtered scores into d_softmaxed_score.
//   6. Copy this step's slice of h_sampled_all to the device and run the
//      multinomial kernel to produce d_indices.
//   7. Copy d_indices and d_softmaxed_score back to the host and synchronize
//      the stream.
//
// Returns Status::OK() on success; any failing CUDA call checked with
// CUDA_RETURN_IF_ERROR is returned as an error Status.
// NOTE(review): the launcher calls themselves are not status-checked here.
template <typename T>
|
||||
Status Sample(AllocatorPtr& allocator,
|
||||
cudaStream_t cuda_stream,
|
||||
gsl::span<T>& next_token_scores,
|
||||
transformers::ISamplingState<T>* sampling_state,
|
||||
transformers::IGreedySearchState<T>* greedy_state,
|
||||
const transformers::IGenerationParameters* parameters,
|
||||
int step,
|
||||
const transformers::IConsoleDumper* dumper) {
|
||||
ORT_UNUSED_PARAMETER(dumper);
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
gsl::span<int>& d_index_in = sampling_state->d_index_in;
|
||||
gsl::span<int>& d_offset = sampling_state->d_offset;
|
||||
|
||||
// References into sampling_state: the scratch buffer and its size are written
|
||||
// on the first step and reused by the sort on every call.
BufferUniquePtr& storage_buffer = sampling_state->storage_buffer;
|
||||
size_t& temp_storage_bytes = sampling_state->temp_storage_bytes;
|
||||
|
||||
// The sort direction (and the filter kernel's mode) follows the custom_sampling flag.
bool is_descending = parameters->custom_sampling;
|
||||
// One-time setup; assumes the first call arrives with step == 1 — TODO confirm against caller.
if (step == 1) {
|
||||
cuda::GetTempStorageSize<CudaT>(reinterpret_cast<CudaT*>(next_token_scores.data()),
|
||||
d_index_in.data(),
|
||||
d_offset.data(),
|
||||
parameters->batch_size * parameters->vocab_size,
|
||||
parameters->batch_size,
|
||||
cuda_stream,
|
||||
is_descending,
|
||||
temp_storage_bytes);
|
||||
|
||||
cuda::LaunchSetupParamsKernel(d_index_in.data(),
|
||||
d_offset.data(),
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size,
|
||||
cuda_stream);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("d_offset_buffer", d_offset.data(), parameters->batch_size + 1, 1);
|
||||
#endif
|
||||
|
||||
// Allocate the sort scratch space and hand ownership to sampling_state.
void* temp_storage = allocator->Alloc(sampling_state->temp_storage_bytes);
|
||||
BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator));
|
||||
storage_buffer = std::move(temp_storage_buffer);
|
||||
}
|
||||
|
||||
gsl::span<T>& d_sorted_score = sampling_state->d_sorted_score;
|
||||
gsl::span<int>& d_index_out = sampling_state->d_index_out;
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("temp_storage_bytes", sampling_state->temp_storage_bytes, true);
|
||||
#endif
|
||||
|
||||
// Sort scores (keys) together with their indices (values); d_offset presumably
|
||||
// delimits one segment per batch entry — verify against LaunchSortPairs.
cuda::LaunchSortPairs<CudaT>(storage_buffer.get(),
|
||||
temp_storage_bytes,
|
||||
reinterpret_cast<CudaT*>(next_token_scores.data()),
|
||||
reinterpret_cast<CudaT*>(d_sorted_score.data()),
|
||||
d_index_in.data(),
|
||||
d_index_out.data(),
|
||||
parameters->batch_size * parameters->vocab_size,
|
||||
parameters->batch_size,
|
||||
d_offset.data(),
|
||||
cuda_stream,
|
||||
is_descending);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("d_sorted_score_buffer",
|
||||
reinterpret_cast<T*>(d_sorted_score.data()),
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size);
|
||||
dumper->Print("d_index_buffer_in", d_index_in.data(), parameters->batch_size, parameters->vocab_size);
|
||||
dumper->Print("d_index_buffer_out", d_index_out.data(), parameters->batch_size, parameters->vocab_size);
|
||||
#endif
|
||||
|
||||
// Softmax over the sorted scores; the result is always stored as float.
gsl::span<float>& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score;
|
||||
dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
|
||||
d_sorted_softmaxed_score.data(),
|
||||
reinterpret_cast<CudaT*>(d_sorted_score.data()),
|
||||
parameters->vocab_size,
|
||||
parameters->vocab_size,
|
||||
parameters->vocab_size,
|
||||
parameters->batch_size);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("d_sorted_softmaxed_score_buffer",
|
||||
d_sorted_softmaxed_score.data(),
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size);
|
||||
#endif
|
||||
|
||||
// Top-p filtering: next_token_scores is passed as the buffer to be rewritten.
cuda::LaunchFilterLogitsKernel<CudaT>(d_sorted_softmaxed_score.data(),
|
||||
d_index_out.data(),
|
||||
reinterpret_cast<CudaT*>(next_token_scores.data()),
|
||||
parameters->top_p,
|
||||
parameters->filter_value,
|
||||
parameters->min_tokens_to_keep,
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size,
|
||||
cuda_stream,
|
||||
is_descending);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("next_token_scores after filtering logits",
|
||||
reinterpret_cast<T*>(next_token_scores.data()),
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size);
|
||||
#endif
|
||||
|
||||
// Softmax over the filtered scores; this is the distribution sampled from below.
gsl::span<float>& d_softmaxed_score = sampling_state->d_softmaxed_score;
|
||||
dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
|
||||
d_softmaxed_score.data(),
|
||||
reinterpret_cast<CudaT*>(next_token_scores.data()),
|
||||
parameters->vocab_size,
|
||||
parameters->vocab_size,
|
||||
parameters->vocab_size,
|
||||
parameters->batch_size);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("d_softmaxed_score_buffer",
|
||||
d_softmaxed_score.data(),
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size);
|
||||
#endif
|
||||
|
||||
// Multinomial sampling
|
||||
// Upload batch_size host values for this step from h_sampled_all (presumably
|
||||
// pre-generated random numbers, one per batch entry per step — TODO confirm).
gsl::span<float>& d_sampled = sampling_state->d_sampled;
|
||||
gsl::span<float>& h_sampled_all = sampling_state->h_sampled_all;
|
||||
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(),
|
||||
h_sampled_all.data() + (step - 1) * parameters->batch_size,
|
||||
sizeof(float) * parameters->batch_size,
|
||||
cudaMemcpyHostToDevice,
|
||||
cuda_stream));
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("d_sampled", d_sampled.data(), parameters->batch_size, 1);
|
||||
#endif
|
||||
|
||||
gsl::span<int64_t>& d_indices = sampling_state->d_indices;
|
||||
gsl::span<int>& presence_mask = sampling_state->d_presence_mask;
|
||||
cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(),
|
||||
d_sampled.data(),
|
||||
d_indices.data(),
|
||||
parameters->batch_size,
|
||||
parameters->vocab_size,
|
||||
presence_mask.data(),
|
||||
cuda_stream);
|
||||
|
||||
#ifdef DEBUG_GENERATION
|
||||
dumper->Print("d_indices", d_indices.data(), parameters->batch_size, 1);
|
||||
#endif
|
||||
|
||||
// Bring the sampled token ids back to the host for the caller.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(),
|
||||
sampling_state->d_indices.data(),
|
||||
greedy_state->next_tokens_cpu.size_bytes(),
|
||||
cudaMemcpyDeviceToHost,
|
||||
cuda_stream));
|
||||
|
||||
// Also mirror the softmaxed scores to the host-side buffer.
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state->h_softmaxed_score.data(),
|
||||
sampling_state->d_softmaxed_score.data(),
|
||||
sampling_state->h_softmaxed_score.size_bytes(),
|
||||
cudaMemcpyDeviceToHost,
|
||||
cuda_stream));
|
||||
|
||||
// Wait for the asynchronous copies above so the host buffers are valid on return.
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace SamplingCudaHelper
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -510,6 +510,13 @@ void GreedySearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
|
|||
sequences_shape.add_dim()->set_dim_value(batch_size);
|
||||
sequences_shape.add_dim()->set_dim_value(max_length_value);
|
||||
updateOutputShape(ctx, 0, sequences_shape);
|
||||
|
||||
if (ctx.getNumOutputs() > 1) {
|
||||
ONNX_NAMESPACE::TensorShapeProto logits_to_debug_shape;
|
||||
logits_to_debug_shape.add_dim()->set_dim_value(batch_size);
|
||||
logits_to_debug_shape.add_dim();
|
||||
updateOutputShape(ctx, 1, logits_to_debug_shape);
|
||||
}
|
||||
}
|
||||
|
||||
constexpr const char* Gelu_ver1_doc =
|
||||
|
|
@ -1114,6 +1121,48 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1,
|
|||
GreedySearchShapeInference(ctx);
|
||||
}));
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
|
||||
OpSchema()
|
||||
.SetDoc("Greedy Sampling for text generation.")
|
||||
.Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
|
||||
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
|
||||
.Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast<int64_t>(-1))
|
||||
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("temperature", "The value used to module the next token probabilities.", AttributeProto::FLOAT, 1.0f)
|
||||
.Attr("top_p",
|
||||
"If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.",
|
||||
AttributeProto::FLOAT, 0.0f)
|
||||
.Attr("filter_value", "All filtered values will be set to this float value.", AttributeProto::FLOAT, -1e20f)
|
||||
.Attr("min_tokens_to_keep", "Minimumber of tokens we keep per batch example in the output.", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("presence_penalty", "Presence penalty for custom sampling", AttributeProto::FLOAT, 0.0f)
|
||||
.Attr("custom", "If 1 custom sampling logic", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("model_type", "Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast<int64_t>(0))
|
||||
.Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
|
||||
.Attr("init_decoder",
|
||||
"The subgraph for the first decoding run. It will be called once before `decoder` subgraph. "
|
||||
"This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs",
|
||||
AttributeProto::GRAPH, OPTIONAL_VALUE)
|
||||
.Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
|
||||
.Attr("vocab_size",
|
||||
"Size of the vocabulary. "
|
||||
"If not provided, it will be inferred from the decoder subgraph's output shape",
|
||||
AttributeProto::INT, static_cast<int64_t>(-1))
|
||||
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
|
||||
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
|
||||
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
|
||||
.Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
|
||||
.Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional)
|
||||
.Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
|
||||
.Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
|
||||
.Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
|
||||
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
|
||||
.Output(1, "filtered_logits", "Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
|
||||
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
GreedySearchShapeInference(ctx);
|
||||
}));
|
||||
|
||||
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1,
|
||||
OpSchema()
|
||||
.Input(0, "X", "input", "T")
|
||||
|
|
|
|||
|
|
@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
|
|||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
|
||||
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer);
|
||||
|
|
@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 {
|
|||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
|
||||
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer)>());
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@
|
|||
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search.h"
|
||||
#include "contrib_ops/cpu/transformers/greedy_search.h"
|
||||
#include "contrib_ops/cpu/transformers/sampling.h"
|
||||
#ifdef ENABLE_ATEN
|
||||
#include "contrib_ops/cpu/aten_ops/aten_op.h"
|
||||
#endif
|
||||
|
|
@ -248,6 +249,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
|
|||
subgraph_session_state);
|
||||
}
|
||||
|
||||
void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) override { p->contrib::transformers::Sampling::Init(info); }
|
||||
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
|
||||
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }
|
||||
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ class AttentionBase;
|
|||
namespace transformers {
|
||||
class BeamSearch;
|
||||
class GreedySearch;
|
||||
class Sampling;
|
||||
}
|
||||
} // namespace contrib
|
||||
|
||||
|
|
@ -174,6 +175,10 @@ struct ProviderHostCPU {
|
|||
const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) = 0;
|
||||
|
||||
virtual void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) = 0;
|
||||
virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
|
||||
virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
|
||||
#endif
|
||||
|
|
|
|||
|
|
@ -207,13 +207,13 @@ template <typename T, typename IndexType = int64_t>
|
|||
using EigenVector = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>>;
|
||||
|
||||
template <typename OutputType>
|
||||
static Status MultinomialCompute(OpKernelContext* ctx,
|
||||
const Tensor& X,
|
||||
const int64_t batch_size,
|
||||
const int64_t num_classes,
|
||||
const int64_t num_samples,
|
||||
std::default_random_engine& generator,
|
||||
Tensor& Y) {
|
||||
Status MultinomialComputeShared(AllocatorPtr& alloc,
|
||||
const Tensor& X,
|
||||
const int64_t batch_size,
|
||||
const int64_t num_classes,
|
||||
const int64_t num_samples,
|
||||
std::default_random_engine& generator,
|
||||
Tensor& Y) {
|
||||
if (!utils::HasType<EnabledMultinomialOutputTypes, OutputType>()) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build.");
|
||||
}
|
||||
|
|
@ -227,8 +227,6 @@ static Status MultinomialCompute(OpKernelContext* ctx,
|
|||
Matrix<OutputType> output = Matrix<OutputType>(Y.MutableData<OutputType>(), Y_dims);
|
||||
|
||||
// BEGIN create temporary tensor
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
|
||||
auto cdf_data = static_cast<double*>(alloc->Alloc(SafeInt<size_t>(sizeof(double)) * num_classes));
|
||||
BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(std::move(alloc)));
|
||||
Eigen::array<int64_t, 1> cdf_dims = {{num_classes}};
|
||||
|
|
@ -271,6 +269,20 @@ static Status MultinomialCompute(OpKernelContext* ctx,
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// OpKernelContext-based wrapper around MultinomialComputeShared: it only
// obtains the context's temp-space allocator and forwards every other
// argument unchanged. Returns the allocator lookup error, if any, otherwise
// the Status produced by MultinomialComputeShared.
template <typename OutputType>
|
||||
static Status MultinomialCompute(OpKernelContext* ctx,
|
||||
const Tensor& X,
|
||||
const int64_t batch_size,
|
||||
const int64_t num_classes,
|
||||
const int64_t num_samples,
|
||||
std::default_random_engine& generator,
|
||||
Tensor& Y) {
|
||||
// Fetch the temp-space allocator that the shared implementation allocates from.
|
||||
AllocatorPtr alloc;
|
||||
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
|
||||
return MultinomialComputeShared<OutputType>(alloc, X, batch_size, num_classes, num_samples, generator, Y);
|
||||
}
|
||||
|
||||
Status Multinomial::Compute(OpKernelContext* ctx) const {
|
||||
const auto* tensor_pointer = ctx->Input<Tensor>(0);
|
||||
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
|
||||
|
|
@ -408,4 +420,12 @@ void GenerateData(std::default_random_engine& generator, TDistribution distribut
|
|||
}
|
||||
}
|
||||
|
||||
template Status MultinomialComputeShared<int64_t>(AllocatorPtr& alloc,
|
||||
const Tensor& X,
|
||||
const int64_t batch_size,
|
||||
const int64_t num_classes,
|
||||
const int64_t num_samples,
|
||||
std::default_random_engine& generator,
|
||||
Tensor& Y);
|
||||
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -13,6 +13,15 @@
|
|||
|
||||
namespace onnxruntime {
|
||||
|
||||
template <typename OutputType>
|
||||
Status MultinomialComputeShared(AllocatorPtr& alloc,
|
||||
const Tensor& X,
|
||||
const int64_t batch_size,
|
||||
const int64_t num_classes,
|
||||
const int64_t num_samples,
|
||||
std::default_random_engine& generator,
|
||||
Tensor& Y);
|
||||
|
||||
class RandomNormal final : public OpKernel {
|
||||
public:
|
||||
RandomNormal(const OpKernelInfo& info) : OpKernel(info) {
|
||||
|
|
|
|||
|
|
@ -33,6 +33,7 @@
|
|||
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
|
||||
#include "contrib_ops/cpu/transformers/beam_search.h"
|
||||
#include "contrib_ops/cpu/transformers/greedy_search.h"
|
||||
#include "contrib_ops/cpu/transformers/sampling.h"
|
||||
#ifdef ENABLE_ATEN
|
||||
#include "contrib_ops/cpu/aten_ops/aten_op.h"
|
||||
#endif
|
||||
|
|
@ -601,6 +602,15 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
|
|||
return g_host_cpu.GreedySearch__SetupSubgraphExecutionInfo(this, session_state, attribute_name,
|
||||
subgraph_session_state);
|
||||
}
|
||||
|
||||
void Sampling::Init(const OpKernelInfo& info) { g_host_cpu.Sampling__Init(this, info); }
|
||||
|
||||
Status Sampling::Compute(OpKernelContext* ctx) const { return g_host_cpu.Sampling__Compute(this, ctx); }
|
||||
|
||||
Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name,
|
||||
const SessionState& subgraph_session_state) {
|
||||
return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); }
|
||||
|
||||
} // namespace transformers
|
||||
|
||||
#ifdef ENABLE_ATEN
|
||||
|
|
|
|||
|
|
@ -25,6 +25,9 @@ Example 4: convert MT5 model with external data file like mt5-base-beamsearch.on
|
|||
|
||||
Example 5: convert gpt2 model with greedy search:
|
||||
python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1
|
||||
|
||||
Example 6: convert gpt2 model with sampling:
|
||||
python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
|
@ -72,6 +75,7 @@ logger = logging.getLogger("")
|
|||
class GenerationType(Enum):
|
||||
BEAMSEARCH = "beam_search"
|
||||
GREEDYSEARCH = "greedy_search"
|
||||
SAMPLING = "sampling"
|
||||
|
||||
def __str__(self):
|
||||
return self.value
|
||||
|
|
@ -261,6 +265,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
|||
)
|
||||
model_group.set_defaults(custom_attention_mask=False)
|
||||
|
||||
model_group.add_argument(
|
||||
"--presence_mask",
|
||||
required=False,
|
||||
action="store_true",
|
||||
help="Presence mask for custom sampling",
|
||||
)
|
||||
model_group.set_defaults(presence_mask=False)
|
||||
|
||||
beam_parameters_group = parser.add_argument_group(
|
||||
"Beam search parameters not stored in the output model, for testing parity and performance"
|
||||
)
|
||||
|
|
@ -295,6 +307,54 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
|
|||
help="Positive. >1 to penalize and <1 to encourage.",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
required=False,
|
||||
default=1.0,
|
||||
help="The value used to module the next token probabilities.",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--top_p",
|
||||
type=float,
|
||||
required=False,
|
||||
default=1.0,
|
||||
help="Top P for sampling",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--filter_value",
|
||||
type=float,
|
||||
required=False,
|
||||
default=-float("Inf"),
|
||||
help="Filter value for Top P sampling",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--min_tokens_to_keep",
|
||||
type=int,
|
||||
required=False,
|
||||
default=1,
|
||||
help="Minimumber of tokens we keep per batch example in the output.",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--presence_penalty",
|
||||
type=float,
|
||||
required=False,
|
||||
default=0.0,
|
||||
help="presence penalty for custom sampling.",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--custom",
|
||||
type=int,
|
||||
required=False,
|
||||
default=0,
|
||||
help="If 1 customized top P logic is applied",
|
||||
)
|
||||
|
||||
beam_parameters_group.add_argument(
|
||||
"--vocab_size",
|
||||
type=int,
|
||||
|
|
@ -1201,18 +1261,20 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
args (argparse.Namespace): arguments parsed from command line
|
||||
"""
|
||||
is_gpt2: bool = args.model_type == "gpt2"
|
||||
is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
|
||||
is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
|
||||
is_sampling: bool = generation_type == GenerationType.SAMPLING
|
||||
past_present_share_buffer: bool = args.past_present_share_buffer and is_greedysearch
|
||||
|
||||
logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}")
|
||||
|
||||
if is_greedysearch:
|
||||
if is_greedysearch or is_sampling:
|
||||
if not is_gpt2:
|
||||
raise NotImplementedError("Currently only gpt2 with greedy search is supported")
|
||||
raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported")
|
||||
if args.output_sequences_scores:
|
||||
raise NotImplementedError("output_sequences_scores currently is not supported in greedy search")
|
||||
raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling")
|
||||
if args.output_token_scores:
|
||||
raise NotImplementedError("output_token_scores currently is not supported in greedy search")
|
||||
raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
|
||||
|
||||
if is_gpt2:
|
||||
if args.decoder_onnx and os.path.exists(args.decoder_onnx):
|
||||
|
|
@ -1243,7 +1305,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
args.pad_vocab_size
|
||||
and args.precision == Precision.FLOAT16
|
||||
and is_gpt2
|
||||
and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
|
||||
and (is_beamsearch or is_greedysearch or is_sampling)
|
||||
):
|
||||
logger.info(
|
||||
f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
|
||||
|
|
@ -1257,11 +1319,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
|
||||
gpt2_init_decoder_generated = False
|
||||
gpt2_init_decoder_onnx_path = None
|
||||
if (
|
||||
args.separate_gpt2_decoder_for_init_run
|
||||
and is_gpt2
|
||||
and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
|
||||
):
|
||||
if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling):
|
||||
logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
|
||||
|
||||
gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format(
|
||||
|
|
@ -1331,8 +1389,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
else:
|
||||
verify_t5_decoder_subgraph(decoder_model.graph, args.precision)
|
||||
|
||||
inputs = (
|
||||
[
|
||||
inputs = None
|
||||
if is_beamsearch:
|
||||
inputs = [
|
||||
"input_ids",
|
||||
"max_length",
|
||||
"min_length",
|
||||
|
|
@ -1341,14 +1400,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
"length_penalty",
|
||||
"repetition_penalty",
|
||||
]
|
||||
if not is_greedysearch
|
||||
else [
|
||||
elif is_greedysearch or is_sampling:
|
||||
inputs = [
|
||||
"input_ids",
|
||||
"max_length",
|
||||
"min_length",
|
||||
"repetition_penalty",
|
||||
]
|
||||
)
|
||||
|
||||
if args.vocab_mask:
|
||||
inputs.append("vocab_mask")
|
||||
|
|
@ -1365,6 +1423,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
else:
|
||||
inputs.append("")
|
||||
|
||||
if is_sampling and args.custom and args.presence_mask:
|
||||
inputs.append("presence_mask")
|
||||
|
||||
outputs = ["sequences"]
|
||||
if args.output_sequences_scores:
|
||||
outputs.append("sequences_scores")
|
||||
|
|
@ -1373,40 +1434,60 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
|
||||
outputs.append("scores")
|
||||
|
||||
node = (
|
||||
onnx.helper.make_node(
|
||||
node = None
|
||||
if is_beamsearch:
|
||||
node = onnx.helper.make_node(
|
||||
"BeamSearch",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=f"BeamSearch_{args.model_type}",
|
||||
)
|
||||
if not is_greedysearch
|
||||
else onnx.helper.make_node(
|
||||
elif is_greedysearch:
|
||||
node = onnx.helper.make_node(
|
||||
"GreedySearch",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=f"GreedySearch_{args.model_type}",
|
||||
)
|
||||
)
|
||||
elif is_sampling:
|
||||
node = onnx.helper.make_node(
|
||||
"Sampling",
|
||||
inputs=inputs,
|
||||
outputs=outputs,
|
||||
name=f"Sampling_{args.model_type}",
|
||||
)
|
||||
|
||||
node.domain = "com.microsoft"
|
||||
|
||||
attr_to_extend = (
|
||||
[
|
||||
attr_to_extend = None
|
||||
if is_beamsearch:
|
||||
attr_to_extend = [
|
||||
onnx.helper.make_attribute("eos_token_id", eos_token_id),
|
||||
onnx.helper.make_attribute("pad_token_id", pad_token_id),
|
||||
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
|
||||
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
|
||||
]
|
||||
if not is_greedysearch
|
||||
else [
|
||||
elif is_greedysearch:
|
||||
attr_to_extend = [
|
||||
onnx.helper.make_attribute("eos_token_id", eos_token_id),
|
||||
onnx.helper.make_attribute("pad_token_id", pad_token_id),
|
||||
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
|
||||
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
]
|
||||
)
|
||||
elif is_sampling:
|
||||
attr_to_extend = [
|
||||
onnx.helper.make_attribute("eos_token_id", eos_token_id),
|
||||
onnx.helper.make_attribute("pad_token_id", pad_token_id),
|
||||
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
|
||||
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
|
||||
onnx.helper.make_attribute("temperature", args.temperature),
|
||||
onnx.helper.make_attribute("top_p", args.top_p),
|
||||
onnx.helper.make_attribute("filter_value", args.filter_value),
|
||||
onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep),
|
||||
onnx.helper.make_attribute("custom", args.custom),
|
||||
onnx.helper.make_attribute("presence_penalty", args.presence_penalty),
|
||||
]
|
||||
|
||||
# Explicitly pass in the vocab size via an attribute
|
||||
if logits_matmul_weight_padded:
|
||||
|
|
@ -1481,8 +1562,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
|
||||
repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
|
||||
|
||||
graph_inputs = (
|
||||
[
|
||||
graph_inputs = None
|
||||
if is_beamsearch:
|
||||
graph_inputs = [
|
||||
input_ids,
|
||||
max_length,
|
||||
min_length,
|
||||
|
|
@ -1491,14 +1573,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
length_penalty,
|
||||
repetition_penalty,
|
||||
]
|
||||
if not is_greedysearch
|
||||
else [
|
||||
elif is_greedysearch or is_sampling:
|
||||
graph_inputs = [
|
||||
input_ids,
|
||||
max_length,
|
||||
min_length,
|
||||
repetition_penalty,
|
||||
]
|
||||
)
|
||||
|
||||
if args.vocab_mask:
|
||||
vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
|
||||
|
|
@ -1516,37 +1597,41 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
|
|||
)
|
||||
graph_inputs.append(attention_mask)
|
||||
|
||||
if args.custom and args.presence_mask:
|
||||
presence_mask = onnx.helper.make_tensor_value_info(
|
||||
"presence_mask", TensorProto.INT32, ["batch_size", vocab_size]
|
||||
)
|
||||
graph_inputs.append(presence_mask)
|
||||
|
||||
# graph outputs
|
||||
sequences = (
|
||||
onnx.helper.make_tensor_value_info(
|
||||
sequences = None
|
||||
if is_beamsearch:
|
||||
sequences = onnx.helper.make_tensor_value_info(
|
||||
"sequences",
|
||||
TensorProto.INT32,
|
||||
["batch_size", "num_return_sequences", "max_length"],
|
||||
)
|
||||
if not is_greedysearch
|
||||
else onnx.helper.make_tensor_value_info(
|
||||
elif is_greedysearch or is_sampling:
|
||||
sequences = onnx.helper.make_tensor_value_info(
|
||||
"sequences",
|
||||
TensorProto.INT32,
|
||||
["batch_size", "max_length"],
|
||||
)
|
||||
)
|
||||
|
||||
sequences_scores = onnx.helper.make_tensor_value_info(
|
||||
"sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
|
||||
)
|
||||
|
||||
scores = onnx.helper.make_tensor_value_info(
|
||||
"scores",
|
||||
TensorProto.FLOAT,
|
||||
["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
|
||||
)
|
||||
|
||||
graph_outputs = [sequences]
|
||||
|
||||
if args.output_sequences_scores:
|
||||
sequences_scores = onnx.helper.make_tensor_value_info(
|
||||
"sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
|
||||
)
|
||||
graph_outputs.append(sequences_scores)
|
||||
|
||||
if args.output_token_scores:
|
||||
scores = onnx.helper.make_tensor_value_info(
|
||||
"scores",
|
||||
TensorProto.FLOAT,
|
||||
["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
|
||||
)
|
||||
graph_outputs.append(scores)
|
||||
|
||||
new_graph = onnx.helper.make_graph(
|
||||
|
|
@ -2085,6 +2170,10 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
|
|||
is_greedy = args.num_beams == 1 and args.num_return_sequences == 1
|
||||
|
||||
if args.model_type == "gpt2" and is_greedy:
|
||||
if args.top_p > 0.0 and args.top_p < 1.0:
|
||||
convert_generation_model(args, GenerationType.SAMPLING)
|
||||
logger.info("The test for gpt2_sampling onnx model is not implemented yet")
|
||||
return
|
||||
convert_generation_model(args, GenerationType.GREEDYSEARCH)
|
||||
else:
|
||||
convert_generation_model(args)
|
||||
|
|
|
|||
Loading…
Reference in a new issue