Sampling op (#13426)

### Description
<!-- Describe your changes. -->

Sampling op for cpu and cuda
support huggingface case and custom case
            


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2022-12-22 17:34:12 -08:00 committed by GitHub
parent 4d2dc8bbbd
commit 68518a1b72
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 2404 additions and 509 deletions

View file

@ -75,10 +75,13 @@ set(contrib_ops_excluded_files
"transformers/beam_search.h"
"transformers/generation_device_helper.cc"
"transformers/generation_device_helper.h"
"transformers/beam_search_impl.cu"
"transformers/beam_search_impl.h"
"transformers/generation_cuda_impl.cu"
"transformers/generation_cuda_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
"transformers/sampling.cc"
"transformers/sampling.h"
"transformers/sampling_cuda_helper.h"
"transformers/dump_cuda_tensor.cc"
"transformers/dump_cuda_tensor.h"
"conv_transpose_with_dynamic_pads.cc"

View file

@ -74,6 +74,7 @@ Do not modify directly.*
* <a href="#com.microsoft.RestorePadding">com.microsoft.RestorePadding</a>
* <a href="#com.microsoft.Rfft">com.microsoft.Rfft</a>
* <a href="#com.microsoft.SampleOp">com.microsoft.SampleOp</a>
* <a href="#com.microsoft.Sampling">com.microsoft.Sampling</a>
* <a href="#com.microsoft.SkipLayerNormalization">com.microsoft.SkipLayerNormalization</a>
* <a href="#com.microsoft.Snpe">com.microsoft.Snpe</a>
* <a href="#com.microsoft.SparseToDenseMatMul">com.microsoft.SparseToDenseMatMul</a>
@ -3810,6 +3811,89 @@ This version of the operator has been available since version 1 of the 'com.micr
</dl>
### <a name="com.microsoft.Sampling"></a><a name="com.microsoft.sampling">**com.microsoft.Sampling**</a>
Greedy Sampling for text generation.
#### Version
This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
#### Attributes
<dl>
<dt><tt>custom</tt> : int</dt>
<dd>If 1 custom sampling logic</dd>
<dt><tt>decoder</tt> : graph (required)</dt>
<dd>Decoder subgraph to execute in a loop.</dd>
<dt><tt>decoder_start_token_id</tt> : int</dt>
<dd>The id of the token that indicates decoding starts.</dd>
<dt><tt>encoder</tt> : graph</dt>
<dd>The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.</dd>
<dt><tt>eos_token_id</tt> : int (required)</dt>
<dd>The id of the end-of-sequence token</dd>
<dt><tt>filter_value</tt> : float</dt>
<dd>All filtered values will be set to this float value.</dd>
<dt><tt>init_decoder</tt> : graph</dt>
<dd>The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs</dd>
<dt><tt>min_tokens_to_keep</tt> : int</dt>
<dd>Minimumber of tokens we keep per batch example in the output.</dd>
<dt><tt>model_type</tt> : int</dt>
<dd>Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart</dd>
<dt><tt>no_repeat_ngram_size</tt> : int</dt>
<dd>no repeat ngrams size</dd>
<dt><tt>pad_token_id</tt> : int (required)</dt>
<dd>The id of the padding token</dd>
<dt><tt>presence_penalty</tt> : float</dt>
<dd>Presence penalty for custom sampling</dd>
<dt><tt>temperature</tt> : float</dt>
<dd>The value used to module the next token probabilities.</dd>
<dt><tt>top_p</tt> : float</dt>
<dd>If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.</dd>
<dt><tt>vocab_size</tt> : int</dt>
<dd>Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape</dd>
</dl>
#### Inputs (2 - 8)
<dl>
<dt><tt>input_ids</tt> : I</dt>
<dd>The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)</dd>
<dt><tt>max_length</tt> : I</dt>
<dd>The maximum length of the sequence to be generated. Shape is (1)</dd>
<dt><tt>min_length</tt> (optional) : I</dt>
<dd>The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)</dd>
<dt><tt>repetition_penalty</tt> (optional) : T</dt>
<dd>The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)</dd>
<dt><tt>vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)</dd>
<dt><tt>prefix_vocab_mask</tt> (optional) : I</dt>
<dd>Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)</dd>
<dt><tt>attention_mask</tt> (optional) : I</dt>
<dd>Custom attention mask. Shape is (batch_size, sequence_length)</dd>
<dt><tt>presence_mask</tt> (optional) : I</dt>
<dd>Presence penalty mask. Shape is (batch_size, vocab_size)</dd>
</dl>
#### Outputs (1 - 2)
<dl>
<dt><tt>sequences</tt> : I</dt>
<dd>Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)</dd>
<dt><tt>filtered_logits</tt> (optional) : T</dt>
<dd>Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)</dd>
</dl>
#### Type Constraints
<dl>
<dt><tt>T</tt> : tensor(float)</dt>
<dd>Constrain input and output types to float tensors.</dd>
<dt><tt>I</tt> : tensor(int32)</dt>
<dd>Constrain to integer types</dd>
</dl>
### <a name="com.microsoft.SkipLayerNormalization"></a><a name="com.microsoft.skiplayernormalization">**com.microsoft.SkipLayerNormalization**</a>
Skip and Layer Normalization Fusion

View file

@ -438,6 +438,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**<br> *in* limit:**T**<br> *in* delta:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**<br> *in* B:**T1**<br> *out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)<br/> **T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(string)|
@ -797,6 +798,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**<br> *in* sequence_token_count:**M**<br> *out* output:**T**<br> *out* token_offset:**M**<br> *out* cumulated_seq_len:**M**<br> *out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**<br> *in* token_offset:**M**<br> *out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
|Sampling|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *in* presence_mask:**I**<br> *out* sequences:**I**<br> *out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**<br> *in* skip:**T**<br> *in* gamma:**T**<br> *in* beta:**T**<br> *in* bias:**T**<br> *out* output:**T**<br> *out* mean:**U**<br> *out* inv_std_var:**U**<br> *out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**<br> *in* k:**tensor(int64)**<br> *out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|

View file

@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range)>,

View file

@ -61,16 +61,17 @@ void BeamSearch::Init(const OpKernelInfo& info) {
parameters_.ParseFromAttributes(info);
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
parameters_.model_type == IGenerationParameters::kModelTypeT5);
ONNX_NAMESPACE::GraphProto proto;
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
}
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
@ -87,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
@ -113,8 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}
} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
@ -167,7 +167,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
// Make a copy of parameters since we will update it based on inputs later
BeamSearchParameters parameters = parameters_;
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
BeamSearchGpt<float> impl{
*ctx_internal,

View file

@ -191,7 +191,8 @@ Status BeamSearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
context.Input<Tensor>(0), // input_ids
context.Input<Tensor>(7), // vocab_mask
context.Input<Tensor>(8), // prefix_vocab_mask
context.Input<Tensor>(9))); // attention_mask
context.Input<Tensor>(9), // attention_mask
nullptr)); // presence_mask
return Status::OK();
}

View file

@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const {
}
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IBeamSearchParameters::kModelTypeGpt));
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", IGenerationParameters::kModelTypeGpt));
early_stopping = info.GetAttrOrDefault<int64_t>("early_stopping", 0) == 1;
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
batch_size = static_cast<int>(dims[0]);
// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast<int>(dims[1]) : 1;
auto* max_length_tensor = context->Input<Tensor>(1);
max_length = max_length_tensor ? static_cast<int>(*max_length_tensor->Data<int32_t>()) : kMaxSequenceLength;
@ -71,10 +71,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
if (vocab_size == -1) {
if (vocab_size == -1 || vocab_size == 0) {
vocab_size = vocabulary_size;
}
num_heads = heads;
head_size = hidden_size_per_head;
num_layers = layers;

View file

@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {
struct BeamSearchParameters : public IBeamSearchParameters {
struct BeamSearchParameters : public IGenerationParameters {
Status Validate() const;
int BatchBeamSize() const { return batch_size * num_beams; }

View file

@ -87,7 +87,8 @@ class GenerateBase {
const Tensor* input_ids,
const Tensor* vocab_mask,
const Tensor* prefix_vocab_mask,
const Tensor* attention_mask) const {
const Tensor* attention_mask,
const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@ -149,6 +150,28 @@ class GenerateBase {
}
}
if (presence_mask != nullptr) {
const auto& dims_presence = presence_mask->Shape().GetDims();
if (dims_presence.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size());
}
// presence_mask first dimension should be same as the first dimension of input_ids
if (static_cast<int>(dims_presence[0]) != static_cast<int>(dims[0])) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"input_ids and presence_mask must have the same batch_size");
}
if (static_cast<int>(dims_presence[1]) != parameters->vocab_size) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]);
}
// store prefix vocab mask in parameters.
parameters->presence_mask = presence_mask->DataAsSpan<int32_t>();
}
return Status::OK();
}

View file

@ -6,11 +6,13 @@
#include <memory>
#include "core/providers/cpu/math/top_k.h"
#include "core/providers/cpu/math/softmax_shared.h"
#include "core/providers/cpu/generator/random.h"
#include "core/common/safeint.h"
#include "core/common/gsl.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/generation_device_helper.h"
#include "contrib_ops/cpu/transformers/sampling_cpu_helper.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
@ -175,6 +177,13 @@ Status CreateGptInputs(
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
if (num_beams == 1) {
expanded_input_ids = std::move(input_ids);
expanded_position_ids = std::move(position_ids);
expanded_attention_mask = std::move(attention_mask);
return Status::OK();
}
ExpandInputs<int32_t>(input_ids, num_beams, allocator, expanded_input_ids);
ExpandInputs<int32_t>(position_ids, num_beams, allocator, expanded_position_ids);
ExpandInputs<int32_t>(attention_mask, num_beams, allocator, expanded_attention_mask);
@ -243,7 +252,7 @@ Status ProcessLogits(const OrtValue& logits, //
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
@ -400,17 +409,16 @@ template <typename T>
Status GreedySearchProcessLogits(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // sampling_state
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
#ifndef DEBUG_GENERATION
ORT_UNUSED_PARAMETER(dumper);
#endif
int batch_size = parameters->batch_size;
int vocab_size = parameters->vocab_size;
@ -448,6 +456,18 @@ Status GreedySearchProcessLogits(
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size);
#endif
if (do_sampling) {
ORT_RETURN_IF_ERROR(SamplingCpuHelper::Sample(allocator,
thread_pool,
next_token_scores,
sampling_state,
greedy_state,
parameters,
dumper));
return Status::OK();
}
// next_tokens = torch.argmax(scores, dim=-1)
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size), vocab_size};
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
@ -460,28 +480,27 @@ Status GreedySearchProcessLogits(
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get<Tensor>();
constexpr int axis = 1;
constexpr unsigned top_k = 1;
constexpr int axis = 1;
constexpr bool largest = true;
constexpr bool sorted = false;
Tensor topk_scores;
Tensor topk_indices;
ORT_RETURN_IF_ERROR(
TopK(&input,
axis,
top_k,
largest,
sorted,
allocator,
stream,
thread_pool,
topk_scores,
topk_indices));
ORT_RETURN_IF_ERROR(TopK(&input,
axis,
top_k,
largest,
sorted,
allocator,
stream,
thread_pool,
topk_scores,
topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", topk_scores);
dumper->Print("topk_indices", topk_indices);
dumper->Print("topk_scores", topk_scores);
dumper->Print("topk_indices", topk_indices);
#endif
gsl::span<const int64_t> next_token_indices = topk_indices.DataAsSpan<int64_t>();
@ -829,7 +848,7 @@ template Status ProcessLogits<float>(
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
transformers::IBeamScorer* beam_scorer,
const transformers::IBeamSearchParameters* parameters,
const transformers::IGenerationParameters* parameters,
int step,
Stream* stream,
const transformers::IConsoleDumper* dumper);
@ -837,11 +856,13 @@ template Status ProcessLogits<float>(
template Status GreedySearchProcessLogits<float>(
const OrtValue& logits,
transformers::IGreedySearchState<float>* greedy_state,
transformers::ISamplingState<float>* sampling_state,
transformers::ISequences* sequences,
AllocatorPtr& allocator,
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
const transformers::IBeamSearchParameters* parameters,
const transformers::IGenerationParameters* parameters,
bool do_sampling,
int step,
Stream* ort_stream,
const transformers::IConsoleDumper* dumper);

View file

@ -83,7 +83,7 @@ using ProcessLogitsFunc = std::function<Status(
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper)>; // tensor dumper
@ -92,11 +92,13 @@ template <typename T>
using GreedySearchProcessLogitsFunc = std::function<Status(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* ort_stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper)>; // tensor dumper
@ -203,7 +205,7 @@ Status ProcessLogits(const OrtValue& logits, //
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper); // tensor dumper
@ -211,11 +213,13 @@ Status ProcessLogits(const OrtValue& logits, //
template <typename T>
Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper); // tensor dumper

View file

@ -4,6 +4,7 @@
#pragma once
#include <utility>
#include <random>
#include "core/common/gsl.h"
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
@ -33,7 +34,7 @@ struct IBeamSearchState {
gsl::span<float> scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
gsl::span<float> remaining_scores; // portion of scores that is available for appending next token scores.
gsl::span<float> topk_buffer; // temp buffer for topk computation, including:
// 1st stage needs:
// 1st stage needs:
// temp score: (batch_size * num_beams * parts_vocab, 2 * num_beams)
// temp token: (batch_size * num_beams * parts_vocab, 2 * num_beams)
// 2nd stage needs:
@ -65,6 +66,28 @@ struct IGreedySearchState {
gsl::span<int32_t> next_tokens; // shape (batch_size)
};
template <typename T>
struct ISamplingState {
gsl::span<int> d_index_in;
gsl::span<int> d_index_out;
gsl::span<int> d_offset;
gsl::span<T> d_sorted_score;
gsl::span<float> d_sorted_softmaxed_score;
gsl::span<float> d_softmaxed_score;
gsl::span<float> h_softmaxed_score;
gsl::span<float> d_sampled;
gsl::span<float> h_sampled_all;
gsl::span<int64_t> d_indices;
gsl::span<int> d_presence_mask;
BufferUniquePtr storage_buffer;
size_t temp_storage_bytes;
std::default_random_engine generator;
gsl::span<T> sorted_scores;
gsl::span<T> cumulative_probs;
};
class ISequences {
public:
virtual ~ISequences() {}
@ -96,7 +119,7 @@ class IBeamScorer {
Tensor* output_sequence_scores) = 0;
};
struct IBeamSearchParameters {
struct IGenerationParameters {
static constexpr int kModelTypeGpt = 0;
static constexpr int kModelTypeT5 = 1;
@ -120,6 +143,7 @@ struct IBeamSearchParameters {
gsl::span<const int32_t> vocab_mask;
gsl::span<const int32_t> prefix_vocab_mask;
gsl::span<const int32_t> presence_mask;
// Parameters from outputs.
bool output_scores; // whether scores existed in output
@ -129,6 +153,15 @@ struct IBeamSearchParameters {
int num_heads;
int head_size;
int num_layers;
// Parameters for TopK/TopP sampling.
float presence_penalty;
float filter_value;
float temperature = 1.0f;
float top_p = 0.0f;
int seed = 0;
int min_tokens_to_keep = 1;
bool custom_sampling = false;
};
class IConsoleDumper {

View file

@ -81,16 +81,15 @@ void GreedySearch::Init(const OpKernelInfo& info) {
parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size);
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt);
ONNX_NAMESPACE::GraphProto proto;
if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
}
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
@ -105,7 +104,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { // GPT-2
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
@ -131,8 +130,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}
} else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { // encoder-decoder like T5
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5
ORT_THROW("Not Implemented");
// if (attribute_name == "encoder") {
// ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
@ -186,7 +184,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
if (parameters_.model_type == 0) { // GPT-2
// Subgraph has constraint that the output is either float or float16
if (!gpt_subgraph_->IsOutputFloat16()) {
GreedySearchGpt<float> impl{
GreedySearchGpt<float, GreedySearchParameters> impl{
*ctx_internal,
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
@ -207,7 +205,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
GreedySearchGpt<MLFloat16> impl{
GreedySearchGpt<MLFloat16, GreedySearchParameters> impl{
*ctx_internal,
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,

View file

@ -21,7 +21,6 @@ namespace transformers {
using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
// bugbug: refactor
class GreedySearch : public IControlFlowKernel {
public:
explicit GreedySearch(const OpKernelInfo& info)

View file

@ -2,6 +2,7 @@
// Licensed under the MIT License.
#pragma once
#include <random>
#include <vector>
#include "contrib_ops/cpu/transformers/generation_shared.h"
#include "contrib_ops/cpu/transformers/generate_impl_base.h"
@ -11,6 +12,63 @@ namespace contrib {
namespace transformers {
template <typename T>
struct SamplingState : public ISamplingState<T> {
void Init(AllocatorPtr allocator,
AllocatorPtr cpu_allocator,
int batch_size,
int vocab_size,
int max_iter,
int seed,
bool is_cuda) {
int total_count = batch_size * vocab_size;
this->h_softmaxed_score = AllocateBuffer<float>(cpu_allocator, h_softmaxed_score_buffer_, SafeInt<size_t>(total_count));
this->generator = std::default_random_engine{gsl::narrow_cast<uint32_t>(seed)};
if (is_cuda) {
this->d_index_in = AllocateBuffer<int>(allocator, d_index_in_buffer_, SafeInt<size_t>(total_count));
this->d_index_out = AllocateBuffer<int>(allocator, d_index_out_buffer_, SafeInt<size_t>(total_count));
this->d_offset = AllocateBuffer<int>(allocator, d_offset_buffer_, SafeInt<size_t>(batch_size + 1));
this->d_sorted_score = AllocateBuffer<T>(allocator, d_sorted_score_buffer_, SafeInt<size_t>(total_count));
this->d_sorted_softmaxed_score = AllocateBuffer<float>(allocator, d_sorted_softmaxed_score_buffer_, SafeInt<size_t>(total_count));
this->d_softmaxed_score = AllocateBuffer<float>(allocator, d_softmaxed_score_buffer_, SafeInt<size_t>(total_count));
this->d_sampled = AllocateBuffer<float>(allocator, d_sampled_buffer_, SafeInt<size_t>(batch_size));
this->h_sampled_all = AllocateBuffer<float>(cpu_allocator, h_sampled_all_buffer_, SafeInt<size_t>(batch_size * max_iter));
this->d_indices = AllocateBuffer<int64_t>(allocator, d_indices_buffer_, SafeInt<size_t>(batch_size));
this->temp_storage_bytes = 0;
// TODO: Do not allocate this buffer if there's no presence_mask
this->d_presence_mask = AllocateBuffer<int>(allocator, d_presence_mask_buffer_, SafeInt<size_t>(total_count));
std::uniform_real_distribution<float> distribution(0.0, 1.0);
static_cast<void>(distribution(this->generator));
for (size_t i = 0; i < this->h_sampled_all.size(); ++i) {
this->h_sampled_all[i] = distribution(this->generator);
}
} else {
// TODO: Some buffer can be reused for CPU
this->sorted_scores = AllocateBuffer<T>(cpu_allocator, sorted_scores_buffer_, SafeInt<size_t>(total_count));
this->cumulative_probs = AllocateBuffer<T>(cpu_allocator, cumulative_probs_buffer_, SafeInt<size_t>(total_count));
}
}
private:
BufferUniquePtr d_index_in_buffer_;
BufferUniquePtr d_index_out_buffer_;
BufferUniquePtr d_offset_buffer_;
BufferUniquePtr d_sorted_score_buffer_;
BufferUniquePtr d_sorted_softmaxed_score_buffer_;
BufferUniquePtr d_softmaxed_score_buffer_;
BufferUniquePtr h_softmaxed_score_buffer_;
BufferUniquePtr d_sampled_buffer_;
BufferUniquePtr h_sampled_all_buffer_;
BufferUniquePtr d_indices_buffer_;
BufferUniquePtr d_presence_mask_buffer_;
BufferUniquePtr sorted_scores_buffer_;
BufferUniquePtr cumulative_probs_buffer_;
};
template <typename T>
struct GreedySearchState : public IGreedySearchState<T> {
Sequences sequences;
@ -68,7 +126,7 @@ struct GreedySearchState : public IGreedySearchState<T> {
};
// Base class of gready search implementation that is common for both GPT-2 and Bart/T5.
template <typename T>
template <typename T, typename ParametersT>
class GreedySearchBase : public GenerateBase {
public:
GreedySearchBase(OpKernelContextInternal& context,
@ -76,7 +134,7 @@ class GreedySearchBase : public GenerateBase {
concurrency::ThreadPool* thread_pool,
Stream* ort_stream,
IConsoleDumper* cuda_dumper,
GreedySearchParameters& params,
ParametersT& params,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc<T>& process_logits_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func)
@ -105,23 +163,25 @@ class GreedySearchBase : public GenerateBase {
Status GenerateNextToken(const OrtValue& logits,
gsl::span<int32_t>& next_tokens,
GreedySearchState<T>& greedy_state,
ISamplingState<T>& sampling_state,
int counter,
int eos_token_id);
// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
GreedySearchState<T>& greedy_state,
ISamplingState<T>& sampling_state,
AllocatorPtr& allocator,
int counter);
GreedySearchParameters* parameters_;
ParametersT* parameters_;
// Device specific functions
GenerationDeviceHelper::GreedySearchProcessLogitsFunc<T> process_logits_func_;
};
template <typename T>
Status GreedySearchBase<T>::CheckInputs(const OpKernelContextInternal& context) {
template <typename T, typename ParametersT>
Status GreedySearchBase<T, ParametersT>::CheckInputs(const OpKernelContextInternal& context) {
// Input shapes:
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
@ -129,13 +189,14 @@ Status GreedySearchBase<T>::CheckInputs(const OpKernelContextInternal& context)
context.Input<Tensor>(0), // input_ids
context.Input<Tensor>(4), // vocab_mask
context.Input<Tensor>(5), // prefix_vocab_mask
nullptr)); // attention_mask
context.Input<Tensor>(6), // attention_mask
context.Input<Tensor>(7))); // presence_mask
return Status::OK();
}
template <typename T>
Status GreedySearchBase<T>::Initialize() {
template <typename T, typename ParametersT>
Status GreedySearchBase<T, ParametersT>::Initialize() {
ORT_RETURN_IF_ERROR(this->context_.GetTempSpaceAllocator(&this->temp_space_allocator_));
ORT_RETURN_IF_ERROR(CheckScalarInput("max_length", 1, true));
@ -155,26 +216,29 @@ Status GreedySearchBase<T>::Initialize() {
return Status::OK();
}
template <typename T>
Status GreedySearchBase<T>::ProcessLogits(
template <typename T, typename ParametersT>
Status GreedySearchBase<T, ParametersT>::ProcessLogits(
const OrtValue& logits,
GreedySearchState<T>& greedy_state,
ISamplingState<T>& sampling_state,
AllocatorPtr& allocator,
int counter) {
return process_logits_func_(logits, &greedy_state, &(greedy_state.sequences), allocator,
this->thread_pool_, &this->logits_processors_,
parameters_, counter, this->ort_stream_, this->GetConsoleDumper());
bool use_sampling = std::is_same<ParametersT, SamplingParameters>::value;
return process_logits_func_(logits, &greedy_state, &sampling_state, &(greedy_state.sequences), allocator,
this->thread_pool_, &this->logits_processors_, parameters_,
use_sampling, counter, this->ort_stream_, this->GetConsoleDumper());
}
template <typename T>
Status GreedySearchBase<T>::GenerateNextToken(
template <typename T, typename ParametersT>
Status GreedySearchBase<T, ParametersT>::GenerateNextToken(
const OrtValue& logits,
gsl::span<int32_t>& next_tokens,
GreedySearchState<T>& greedy_state,
ISamplingState<T>& sampling_state,
int counter,
int eos_token_id) {
// Process logits to get next token scores
ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, this->temp_space_allocator_, counter));
ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, sampling_state, this->temp_space_allocator_, counter));
next_tokens = greedy_state.next_tokens;
for (size_t i = 0; i < next_tokens.size(); i++) {

View file

@ -24,8 +24,8 @@ std::pair<Status, std::unique_ptr<GptSubgraph>> CreateGptSubgraphAndUpdateParame
} // namespace gpt_details
// Greedy search implementation for GPT-2 model.
template <typename T>
class GreedySearchGpt : public GreedySearchBase<T> {
template <typename T, typename ParametersT>
class GreedySearchGpt : public GreedySearchBase<T, ParametersT> {
public:
GreedySearchGpt(OpKernelContextInternal& context,
const SessionState* init_run_decoder_session_state,
@ -35,7 +35,7 @@ class GreedySearchGpt : public GreedySearchBase<T> {
concurrency::ThreadPool* thread_pool,
Stream* ort_stream,
IConsoleDumper* cuda_dumper,
GreedySearchParameters& params,
ParametersT& params,
const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
@ -43,15 +43,15 @@ class GreedySearchGpt : public GreedySearchBase<T> {
const GenerationDeviceHelper::InitGreedyStateFunc<T>& init_greedy_state_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const GenerationDeviceHelper::UpdateGptFeedsFunc<T>& update_feeds_func)
: GreedySearchBase<T>(context,
decoder_session_state,
thread_pool,
ort_stream,
cuda_dumper,
params,
topk_func,
process_logits_func,
device_copy_func),
: GreedySearchBase<T, ParametersT>(context,
decoder_session_state,
thread_pool,
ort_stream,
cuda_dumper,
params,
topk_func,
process_logits_func,
device_copy_func),
init_run_decoder_session_state_(init_run_decoder_session_state),
init_run_gpt_subgraph_(init_run_gpt_subgraph),
gpt_subgraph_(gpt_subgraph),
@ -94,8 +94,8 @@ class GreedySearchGpt : public GreedySearchBase<T> {
GenerationDeviceHelper::UpdateGptFeedsFunc<T> update_feeds_func_;
};
template <typename T>
Status GreedySearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
template <typename T, typename ParametersT>
Status GreedySearchGpt<T, ParametersT>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector<OrtValue>& feeds,
IAllocatorUniquePtr<char>& buffer) {
@ -134,8 +134,8 @@ Status GreedySearchGpt<T>::CreateInitialFeeds(gsl::span<int32_t>& sequence_lengt
this->parameters_->max_length);
}
template <typename T>
Status GreedySearchGpt<T>::UpdateFeeds(
template <typename T, typename ParametersT>
Status GreedySearchGpt<T, ParametersT>::UpdateFeeds(
const std::vector<OrtValue>& last_outputs,
std::vector<OrtValue>& next_inputs,
int current_length,
@ -161,11 +161,11 @@ Status GreedySearchGpt<T>::UpdateFeeds(
);
}
template <typename T>
Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager,
const FeedsFetchesManager& feeds_fetches_manager) {
template <typename T, typename ParametersT>
Status GreedySearchGpt<T, ParametersT>::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager,
const FeedsFetchesManager& feeds_fetches_manager) {
auto status = Status::OK();
const GreedySearchParameters* parameters = this->parameters_;
const ParametersT* parameters = this->parameters_;
// Allocate output tensors.
int64_t sequences_dims[] = {parameters->batch_size, parameters->max_length};
@ -184,6 +184,17 @@ Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fet
parameters->max_length,
this->IsCuda());
SamplingState<T> sampling_state;
if (std::is_same<ParametersT, SamplingParameters>::value) {
sampling_state.Init(this->temp_space_allocator_,
this->cpu_allocator_,
static_cast<int>(parameters->BatchBeamSize()),
static_cast<int>(parameters->vocab_size),
static_cast<int>(parameters->max_length - parameters->sequence_length),
parameters->seed,
this->IsCuda());
}
IAllocatorUniquePtr<char> buffer;
OrtValue expanded_input_ids_in_cpu;
ORT_RETURN_IF_ERROR(CreateInitialFeeds(greedy_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
@ -276,6 +287,7 @@ Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fet
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
next_tokens,
greedy_state,
sampling_state,
iteration_counter,
parameters->eos_token_id));
@ -324,6 +336,27 @@ Status GreedySearchGpt<T>::Execute(const FeedsFetchesManager* init_run_feeds_fet
gsl::copy(sequence_source, batch_output);
}
#ifdef DEBUG_GENERATION
// Debug the one step filtered logits for sampling
int64_t filtered_logits_dims[] = {parameters->batch_size, parameters->vocab_size};
TensorShape filtered_logits_shape(&filtered_logits_dims[0],
sizeof(filtered_logits_dims) / sizeof(filtered_logits_dims[0]));
Tensor* filtered_logits = this->context_.Output(1, filtered_logits_shape);
if (filtered_logits != nullptr) {
gsl::span<float> filtered_logits_span = filtered_logits->MutableDataAsSpan<float>();
for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) {
auto batch_output = filtered_logits_span.subspan(
static_cast<size_t>(batch_id) * parameters->vocab_size,
parameters->vocab_size);
gsl::span<const float> batch_filtered_logits = gsl::make_span(sampling_state.h_softmaxed_score.data() +
batch_id * parameters->vocab_size,
parameters->vocab_size);
gsl::copy(batch_filtered_logits, batch_output);
}
}
#endif
return status;
}

View file

@ -6,8 +6,12 @@
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/common/span_utils.h"
#include "core/providers/cpu/math/softmax_shared.h"
#include "contrib_ops/cpu/transformers/logits_processor.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
#include <vector>
#include <numeric>
#include <algorithm>
namespace onnxruntime {
namespace contrib {
@ -187,6 +191,53 @@ void PrefixVocabMaskLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
#endif
}
template <typename T>
TemperatureLogitsProcessor<T>::TemperatureLogitsProcessor(float temperature) : temperature_(temperature) {
}
template <typename T>
void TemperatureLogitsProcessor<T>::Process(const ISequences* /*sequences*/,
NextTokenScores<T>& next_token_scores) {
if (temperature_ == 1.0f) {
return;
}
T* p = next_token_scores.scores.data();
for (size_t i = 0; i < next_token_scores.scores.size(); i++) {
*p /= temperature_;
++p;
}
#ifdef DEBUG_GENERATION
DumpScores("TemperatureLogitsProcessor", next_token_scores);
#endif
}
template <typename T>
PresencePenaltyLogitsProcessor<T>::PresencePenaltyLogitsProcessor(const gsl::span<const int32_t>& presence_mask,
float presence_penalty)
: presence_mask_(presence_mask), presence_penalty_(presence_penalty) {
}
template <typename T>
void PresencePenaltyLogitsProcessor<T>::Process(const ISequences*,
NextTokenScores<T>& next_token_scores) {
if (presence_penalty_ == 0.0f) {
return;
}
assert(!presence_mask_.empty());
T* p = next_token_scores.scores.data();
for (size_t i = 0; i < next_token_scores.scores.size(); i++) {
*p -= presence_mask_[i] * presence_penalty_;
}
#ifdef DEBUG_GENERATION
DumpScores("PresencePenaltyLogitsProcessor", next_token_scores);
#endif
}
void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
LogitsProcessorInitImpl<BeamSearchParameters>(parameters);
}
@ -195,6 +246,10 @@ void LogitsProcessorList::Init(const GreedySearchParameters& parameters) {
LogitsProcessorInitImpl<GreedySearchParameters>(parameters);
}
void LogitsProcessorList::Init(const SamplingParameters& parameters) {
LogitsProcessorInitImpl<SamplingParameters>(parameters);
}
void LogitsProcessorList::Process(const ISequences* sequences,
gsl::span<float>& next_token_scores,
int step) {

View file

@ -7,6 +7,7 @@
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
#include "contrib_ops/cpu/transformers/generation_shared.h"
namespace onnxruntime {
@ -96,11 +97,53 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor<T> {
const int batch_size_;
};
template <typename T>
class TemperatureLogitsProcessor : public ILogitsProcessor<T> {
public:
TemperatureLogitsProcessor(float temperature);
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;
private:
float temperature_;
};
// template <typename T>
// class TopPLogitsProcessor : public ILogitsProcessor<T> {
// public:
// TopPLogitsProcessor(float top_p, float filter_value,
// onnxruntime::concurrency::ThreadPool* thread_pool);
// void Process(const ISequences* sequences,
// NextTokenScores<T>& next_token_scores) override;
// private:
// float top_p_;
// float filter_value_;
// onnxruntime::concurrency::ThreadPool* thread_pool_;
// };
template <typename T>
class PresencePenaltyLogitsProcessor : public ILogitsProcessor<T> {
public:
PresencePenaltyLogitsProcessor(const gsl::span<const int32_t>& presence_mask,
float presence_penalty);
void Process(const ISequences* sequences,
NextTokenScores<T>& next_token_scores) override;
private:
gsl::span<const int32_t> presence_mask_;
float presence_penalty_;
};
class LogitsProcessorList : public ILogitsProcessorList {
public:
LogitsProcessorList() = default;
void Init(const BeamSearchParameters& parameters);
void Init(const GreedySearchParameters& parameters);
void Init(const SamplingParameters& parameters);
void Process(const ISequences* sequences, gsl::span<float>& next_token_scores, int step);
private:
@ -140,6 +183,19 @@ class LogitsProcessorList : public ILogitsProcessorList {
processor_list_.push_back(min_length_processor_.get());
}
if (parameters.temperature > 0) {
temperature_processor_ = std::make_unique<TemperatureLogitsProcessor<float>>(parameters.temperature);
processor_list_.push_back(temperature_processor_.get());
}
if (!parameters.presence_mask.empty()) {
presence_penalty_processor_ = std::make_unique<
PresencePenaltyLogitsProcessor<float>
>(parameters.presence_mask,
parameters.presence_penalty);
processor_list_.push_back(presence_penalty_processor_.get());
}
batch_beam_size_ = parameters.BatchBeamSize();
vocab_size_ = parameters.vocab_size;
}
@ -153,6 +209,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
std::unique_ptr<VocabMaskLogitsProcessor<float>> vocab_mask_processor_;
std::unique_ptr<PrefixVocabMaskLogitsProcessor<float>> prefix_vocab_mask_processor_;
std::unique_ptr<MinLengthLogitsProcessor<float>> min_length_processor_;
std::unique_ptr<TemperatureLogitsProcessor<float>> temperature_processor_;
std::unique_ptr<PresencePenaltyLogitsProcessor<float>> presence_penalty_processor_;
};
} // namespace transformers

View file

@ -0,0 +1,175 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// there's no way to use a raw pointer as the copy destination with std::copy_n
// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
#ifdef _MSC_VER
#pragma warning(push)
#pragma warning(disable : 4996)
#endif
#include "core/framework/op_kernel_context_internal.h"
#include "core/framework/utils.h"
#include "contrib_ops/cpu/transformers/sampling.h"
#include "contrib_ops/cpu/transformers/logits_processor.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h"
using namespace ONNX_NAMESPACE;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
#define REGISTER_KERNEL_TYPED(T) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
Sampling, \
kMSDomain, \
1, \
T, \
kCpuExecutionProvider, \
(*KernelDefBuilder::Create()) \
.TypeConstraint("T", DataTypeImpl::GetTensorType<T>()), \
transformers::Sampling);
REGISTER_KERNEL_TYPED(float)
namespace transformers {
void Sampling::Init(const OpKernelInfo& info) {
parameters_.ParseFromAttributes(info);
parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size);
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt);
ONNX_NAMESPACE::GraphProto proto;
if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("encoder", &proto).IsOK());
}
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr<ONNX_NAMESPACE::GraphProto>("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
}
}
// Make sure the decoder sub-graph attribute is present for all model types.
ORT_ENFORCE(info.GetAttr<ONNX_NAMESPACE::GraphProto>("decoder", &proto).IsOK());
}
Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
subgraph_session_state, parameters_);
auto status = res.first;
if (!status.IsOK()) {
return status;
}
gpt_subgraph_ = std::move(res.second);
decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
} else if (attribute_name == "init_decoder") {
ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
subgraph_session_state, parameters_);
auto status = res.first;
if (!status.IsOK()) {
return status;
}
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}
} else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5
ORT_THROW("Not Implemented");
}
return Status::OK();
}
Status Sampling::Compute(OpKernelContext* ctx) const {
auto* ctx_internal = static_cast<OpKernelContextInternal*>(ctx);
auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder");
ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
auto* init_run_decoder_session_state = ctx_internal->SubgraphSessionState("init_decoder");
if (has_init_decoder_) {
ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute.");
ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph.");
ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_
&& init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_,
"past_present_share_buffer mode must be same for init decoder and decoder subgraphes");
}
concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool();
// make a copy since we will update the parameters based on inputs later
SamplingParameters parameters = parameters_;
if (parameters_.model_type == 0) { // GPT-2
// Subgraph has constraint that the output is either float or float16
if (!gpt_subgraph_->IsOutputFloat16()) {
GreedySearchGpt<float, SamplingParameters> impl{
*ctx_internal,
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
*decoder_session_state,
*gpt_subgraph_,
thread_pool,
ctx->GetComputeStream(),
dumper_,
parameters,
GenerationCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::GreedySearchProcessLogits<float>,
init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState<float>,
device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy<float>,
update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds<float>};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
GreedySearchGpt<MLFloat16, SamplingParameters> impl{
*ctx_internal,
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
*decoder_session_state,
*gpt_subgraph_,
thread_pool,
ctx->GetComputeStream(),
dumper_,
parameters,
GenerationCpuDeviceHelper::CreateGptInputs,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_greedy_state_fp16_func_,
device_copy_func_,
update_gpt_feeds_fp16_func_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
}
}
return Status::OK();
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,103 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <memory>
#include <string>
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cpu/controlflow/utils.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cpu/transformers/generation_device_helper.h"
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
namespace onnxruntime {
class FeedsFetchesManager;
namespace contrib {
namespace transformers {
using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
class Sampling : public IControlFlowKernel {
public:
explicit Sampling(const OpKernelInfo& info)
: IControlFlowKernel(info),
decoder_feeds_fetches_manager_(nullptr),
dumper_(nullptr) {
Init(info);
}
void Init(const OpKernelInfo& info);
Status Compute(OpKernelContext* ctx) const override;
Status SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) override;
protected:
void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; }
// device helpers that is same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc<float>& process_logits_func,
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc<MLFloat16>& process_logits_fp16_func,
const GenerationDeviceHelper::InitGreedyStateFunc<float>& init_greedy_state_func,
const GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16>& init_greedy_state_fp16_func) {
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
device_copy_func_ = device_copy_func;
process_logits_func_ = process_logits_func;
process_logits_fp16_func_ = process_logits_fp16_func;
init_greedy_state_func_ = init_greedy_state_func;
init_greedy_state_fp16_func_ = init_greedy_state_fp16_func;
}
void SetDeviceHelpers_Gpt(
const GenerationDeviceHelper::UpdateGptFeedsFunc<float>& update_gpt_feeds_func,
const GenerationDeviceHelper::UpdateGptFeedsFunc<MLFloat16>& update_gpt_feeds_fp16_func) {
update_gpt_feeds_func_ = update_gpt_feeds_func;
update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func;
}
private:
// Device specific functions
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::TopkFunc topk_func_;
GenerationDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
GenerationDeviceHelper::GreedySearchProcessLogitsFunc<float> process_logits_func_;
GenerationDeviceHelper::GreedySearchProcessLogitsFunc<MLFloat16> process_logits_fp16_func_;
GenerationDeviceHelper::InitGreedyStateFunc<float> init_greedy_state_func_;
GenerationDeviceHelper::InitGreedyStateFunc<MLFloat16> init_greedy_state_fp16_func_;
//------------------------------------------------------------
// Device specific functions for GPT
//------------------------------------------------------------
GenerationDeviceHelper::UpdateGptFeedsFunc<float> update_gpt_feeds_func_;
GenerationDeviceHelper::UpdateGptFeedsFunc<MLFloat16> update_gpt_feeds_fp16_func_;
//------------------------------------------------------------
// Subgraph and FeedsFetchesManager re-used for each subgraph execution.
//------------------------------------------------------------
std::unique_ptr<GptSubgraph> init_run_gpt_subgraph_;
std::unique_ptr<GptSubgraph> gpt_subgraph_;
FeedsFetchesManager* decoder_feeds_fetches_manager_;
FeedsFetchesManager* init_run_decoder_feeds_fetches_manager_;
IConsoleDumper* dumper_;
SamplingParameters parameters_;
bool has_init_decoder_ = false;
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,161 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
namespace onnxruntime {
namespace contrib {
namespace SamplingCpuHelper {
template <typename T>
void filter_scores(std::vector<size_t>& sorted_indice,
gsl::span<T>& next_token_score,
const transformers::IGenerationParameters* parameters,
size_t index) {
size_t real_index = sorted_indice[index];
next_token_score[real_index] = (T)parameters->filter_value;
}
template <typename T>
void cumulate_and_filter_custom(gsl::span<T>& next_token_scores,
gsl::span<T>& cumulative_probs,
const transformers::IGenerationParameters* parameters,
std::vector<size_t>& sorted_indices) {
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
size_t offset = i * parameters->vocab_size;
if (cumulative_probs[offset] > parameters->top_p) {
filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset);
}
for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - 1; j++) {
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
if (cumulative_probs[j + offset] > parameters->top_p) {
filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1);
}
}
}
}
template <typename T>
void cumulate_and_filter(gsl::span<T>& next_token_scores,
gsl::span<T>& cumulative_probs,
const transformers::IGenerationParameters* parameters,
std::vector<size_t>& sorted_indices) {
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
size_t offset = i * parameters->vocab_size;
if (cumulative_probs[offset] <= 1 - parameters->top_p) {
filter_scores(sorted_indices, next_token_scores, parameters, offset);
}
for (size_t j = 1; j < static_cast<size_t>(parameters->vocab_size) - static_cast<size_t>(parameters->min_tokens_to_keep); j++) {
cumulative_probs[j + offset] += cumulative_probs[j + offset - 1];
if (cumulative_probs[j + offset] <= 1 - parameters->top_p) {
filter_scores(sorted_indices, next_token_scores, parameters, j + offset);
}
}
}
}
template <typename T>
Status Sample(AllocatorPtr& allocator,
onnxruntime::concurrency::ThreadPool* thread_pool,
gsl::span<T>& next_token_scores,
transformers::ISamplingState<T>* sampling_state,
transformers::IGreedySearchState<T>* greedy_state,
const transformers::IGenerationParameters* parameters,
const transformers::IConsoleDumper* dumper) {
ORT_UNUSED_PARAMETER(dumper);
gsl::span<T>& sorted_scores = sampling_state->sorted_scores;
memcpy(sorted_scores.data(), next_token_scores.data(), next_token_scores.size_bytes());
std::vector<size_t> sorted_indices(static_cast<size_t>(parameters->batch_size) * static_cast<size_t>(parameters->vocab_size));
std::function<bool(T, T)> predicator;
if (parameters->custom_sampling) {
predicator = std::greater<T>();
} else {
predicator = std::less<T>();
}
// TODO: This could be optimized with allocated buffer and handwritten sort algorithm
for (size_t i = 0; i < static_cast<size_t>(parameters->batch_size); i++) {
auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size;
auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size;
std::iota(indices_begin, indices_end, 0);
std::sort(indices_begin, indices_end,
[&next_token_scores, &predicator](size_t i1, size_t i2) {
return !predicator(next_token_scores[i1], next_token_scores[i2]);
});
std::sort(sorted_scores.begin() + i * parameters->vocab_size,
sorted_scores.begin() + (i + 1) * parameters->vocab_size,
predicator);
}
gsl::span<T>& cumulative_probs = sampling_state->cumulative_probs;
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
parameters->vocab_size,
sorted_scores.data(),
cumulative_probs.data(),
false,
thread_pool));
if (parameters->custom_sampling) {
cumulate_and_filter_custom(next_token_scores, cumulative_probs, parameters, sorted_indices);
} else {
cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices);
}
gsl::span<T>& next_token_probs = sampling_state->h_softmaxed_score;
ORT_RETURN_IF_ERROR(SoftmaxCPU<T>(parameters->batch_size,
parameters->vocab_size,
next_token_scores.data(),
next_token_probs.data(),
false,
thread_pool));
// torch.multinomial()
int64_t next_token_probs_dims[] = {static_cast<int64_t>(parameters->batch_size), parameters->vocab_size};
TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2);
auto element_type = DataTypeImpl::GetType<T>();
OrtValue next_token_probs_value;
Tensor::InitOrtValue(element_type,
next_token_probs_shape,
next_token_probs.data(),
allocator->Info(),
next_token_probs_value);
const Tensor& input = next_token_probs_value.Get<Tensor>();
std::default_random_engine& generator = sampling_state->generator;
int64_t sampled_idx_dims[] = {static_cast<int64_t>(parameters->batch_size), 1};
TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2);
gsl::span<int64_t>& next_token_idx = greedy_state->next_tokens_cpu;
OrtValue sampled_idx_ov;
Tensor::InitOrtValue(DataTypeImpl::GetType<int64_t>(),
sampled_idx_shape,
next_token_idx.data(),
allocator->Info(),
sampled_idx_ov);
Tensor* sampled_idx = sampled_idx_ov.GetMutable<Tensor>();
// Copy the allocator because MultinomialComputeShared() uses move(allocator)
AllocatorPtr allocatortemp = allocator;
ORT_RETURN_IF_ERROR(MultinomialComputeShared<int64_t>(allocatortemp,
input,
parameters->batch_size,
parameters->vocab_size,
1,
generator,
*sampled_idx));
// TODO: update presense_mask()
#ifdef DEBUG_GENERATION
dumper->Print("sampled_idx", *sampled_idx);
#endif
return Status::OK();
}
} // namespace SamplingCudaHelper
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/transformers/sampling_parameters.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) {
model_type = static_cast<int>(info.GetAttrOrDefault<int64_t>("model_type", 0));
eos_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("eos_token_id", -1));
pad_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("pad_token_id", -1));
decoder_start_token_id = static_cast<int>(info.GetAttrOrDefault<int64_t>("decoder_start_token_id", -1));
no_repeat_ngram_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("no_repeat_ngram_size", 0));
temperature = info.GetAttrOrDefault<float>("temperature", 1.0f);
top_p = info.GetAttrOrDefault<float>("top_p", 0.0f);
filter_value = info.GetAttrOrDefault<float>("filter_value", -std::numeric_limits<float>::infinity());
min_tokens_to_keep = static_cast<int>(info.GetAttrOrDefault<int64_t>("min_tokens_to_keep", 0));
presence_penalty = info.GetAttrOrDefault<float>("presence_penalty", 0.0f);
custom_sampling = static_cast<int>(info.GetAttrOrDefault<int64_t>("custom", 0));
vocab_size = static_cast<int>(info.GetAttrOrDefault<int64_t>("vocab_size", -1));
}
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,19 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
namespace onnxruntime {
namespace contrib {
namespace transformers {
struct SamplingParameters : public GreedySearchParameters {
void ParseFromAttributes(const OpKernelInfo& info);
};
} // namespace transformers
} // namespace contrib
} // namespace onnxruntime

View file

@ -75,6 +75,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh);
@ -189,6 +190,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh)>,

View file

@ -1,286 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "cub/util_type.cuh"
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
__global__ void InitKernel(float* beam_scores,
int num_beams,
int total_elements) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
int beam_index = index % num_beams;
beam_scores[index] = beam_index > 0 ? static_cast<float>(-1e9) : 0.0f;
}
}
void LaunchInitKernel(
float* beam_scores,
int batch_size,
int num_beams,
cudaStream_t stream) {
int total_elements = batch_size * num_beams;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
InitKernel<<<gridSize, blockSize, 0, stream>>>(beam_scores, num_beams, total_elements);
}
__global__ void NextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,
int vocab_size,
int total_elements) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
next_indices[index] = next_token_indices[index] / vocab_size;
next_tokens[index] = next_token_indices[index] % vocab_size;
}
}
void LaunchNextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,
int batch_size,
int top_k,
int vocab_size,
cudaStream_t stream) {
int total_elements = batch_size * top_k;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
NextTokenKernel<<<gridSize, blockSize, 0, stream>>>(next_token_indices,
next_indices,
next_tokens,
vocab_size,
total_elements);
}
template <typename T>
__global__ void LogitsProcessKernel(
T* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int num_beams,
int vocab_size,
int padded_vocab_size,
int total_elements,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
int batch_beam_index = index / padded_vocab_size;
int word_id = index % padded_vocab_size;
if (word_id >= vocab_size) {
// Set any value within the padding region to the lowest value so that it isn't picked
next_token_scores[index] = cub::FpLimits<T>::Lowest();
} else {
// RepetitionPenaltyLogitsProcessor
if (repetition_penalty != 1.0f) {
int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length;
bool found = false;
for (int i = 0; i < current_sequence_length; i++) {
if (current_sequence[i] == word_id) {
found = true;
break;
}
}
if (found) {
float score = (float)next_token_scores[index];
next_token_scores[index] = (T)(score < 0 ? score * repetition_penalty : score / repetition_penalty);
}
}
// NoRepeatNGramLogitsProcessor
if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) {
int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length;
bool found = false;
for (int i = no_repeat_ngram_size - 1; i < current_sequence_length; i++) {
if (current_sequence[i] == word_id) { // last token of n-gram matched
found = true;
for (int j = 0; j < no_repeat_ngram_size - 1; j++) { // match the remaining N-1 tokens
if (current_sequence[i - j - 1] != current_sequence[current_sequence_length - 1 - j]) {
found = false;
break;
}
}
if (found) {
break;
}
}
}
if (found) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
return;
}
}
// VocabMaskLogitsProcessor
if (vocab_mask != nullptr && vocab_mask[word_id] == 0) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
return;
}
// PrefixVocabMaskLogitsProcessor
int batch_id = batch_beam_index / num_beams;
if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + word_id] == 0) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
return;
}
// MinLengthLogitsProcessor
if (word_id == demote_token_id) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
}
}
}
}
template <typename T>
void LaunchLogitsProcessKernel(
T* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream) {
int total_elements = batch_size * num_beams * padded_vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(
next_token_scores,
vocab_mask,
prefix_vocab_mask,
num_beams,
vocab_size,
padded_vocab_size,
total_elements,
demote_token_id,
sequences,
max_sequence_length,
current_sequence_length,
repetition_penalty,
no_repeat_ngram_size);
}
// Instantiation
template void LaunchLogitsProcessKernel(
float* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream);
template void LaunchLogitsProcessKernel(
half* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream);
__global__ void AddProbsKernel(float* log_probs,
float* cum_log_probs,
const int vocab_size,
const int total_elements) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int batch_beam_index = index / vocab_size;
if (index < total_elements)
log_probs[index] += cum_log_probs[batch_beam_index];
}
template <typename T>
void LaunchAddProbsKernel(T* log_probs,
T* cum_log_probs,
const int batch_size,
const int num_beams,
const int vocab_size,
cudaStream_t stream) {
int total_elements = batch_size * num_beams * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
AddProbsKernel<<<gridSize, blockSize, 0, stream>>>(log_probs, cum_log_probs, vocab_size, total_elements);
}
template void LaunchAddProbsKernel(
float* log_probs,
float* cum_log_probs,
const int batch_size,
const int num_beams,
const int vocab_size,
cudaStream_t stream);
template <typename T>
__global__ void UpdateGptInputsKernel(const T* old_mask_data,
T* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < batch_beam_size * current_length) {
// Update attention mask.
int i = index / current_length;
int j = index % current_length;
mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast<T>(1);
if (next_positions != nullptr) {
// Update sequence length (or next positions).
if (index < batch_beam_size) {
next_positions[index]++;
}
}
}
}
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream) {
assert(current_length > 0);
int total_elements = batch_beam_size * current_length;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
UpdateGptInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(
old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -1,61 +0,0 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include <cuda_fp16.h>
namespace onnxruntime {
namespace contrib {
namespace cuda {
void LaunchInitKernel(
float* beam_scores,
int batch_size,
int num_beams,
cudaStream_t stream);
template <typename T>
void LaunchAddProbsKernel(T* log_probs,
T* cum_log_probs,
const int batch_size,
const int num_beams,
const int vocab_size,
cudaStream_t stream);
template <typename T>
void LaunchLogitsProcessKernel(
T* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream);
void LaunchNextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,
int batch_size,
int top_k,
int vocab_size,
cudaStream_t stream);
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,758 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cu_inc/common.cuh"
#include "cub/util_type.cuh"
#include <cub/cub.cuh>
#include <cub/device/device_segmented_radix_sort.cuh>
#include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
__global__ void InitKernel(float* beam_scores,
int num_beams,
int total_elements) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
int beam_index = index % num_beams;
beam_scores[index] = beam_index > 0 ? static_cast<float>(-1e9) : 0.0f;
}
}
void LaunchInitKernel(
float* beam_scores,
int batch_size,
int num_beams,
cudaStream_t stream) {
int total_elements = batch_size * num_beams;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
InitKernel<<<gridSize, blockSize, 0, stream>>>(beam_scores, num_beams, total_elements);
}
__global__ void NextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,
int vocab_size,
int total_elements) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
next_indices[index] = next_token_indices[index] / vocab_size;
next_tokens[index] = next_token_indices[index] % vocab_size;
}
}
void LaunchNextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,
int batch_size,
int top_k,
int vocab_size,
cudaStream_t stream) {
int total_elements = batch_size * top_k;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
NextTokenKernel<<<gridSize, blockSize, 0, stream>>>(next_token_indices,
next_indices,
next_tokens,
vocab_size,
total_elements);
}
template <typename T>
__global__ void LogitsProcessKernel(
T* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
const int* presence_mask,
float presence_penalty,
float temperature,
int num_beams,
int vocab_size,
int padded_vocab_size,
int total_elements,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < total_elements) {
int batch_beam_index = index / padded_vocab_size;
int word_id = index % padded_vocab_size;
if (word_id >= vocab_size) {
// Set any value within the padding region to the lowest value so that it isn't picked
next_token_scores[index] = cub::FpLimits<T>::Lowest();
} else {
// RepetitionPenaltyLogitsProcessor
if (repetition_penalty != 1.0f) {
int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length;
bool found = false;
for (int i = 0; i < current_sequence_length; i++) {
if (current_sequence[i] == word_id) {
found = true;
break;
}
}
if (found) {
float score = (float)next_token_scores[index];
next_token_scores[index] = (T)(score < 0 ? score * repetition_penalty : score / repetition_penalty);
}
}
// NoRepeatNGramLogitsProcessor
if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) {
int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length;
bool found = false;
for (int i = no_repeat_ngram_size - 1; i < current_sequence_length; i++) {
if (current_sequence[i] == word_id) { // last token of n-gram matched
found = true;
for (int j = 0; j < no_repeat_ngram_size - 1; j++) { // match the remaining N-1 tokens
if (current_sequence[i - j - 1] != current_sequence[current_sequence_length - 1 - j]) {
found = false;
break;
}
}
if (found) {
break;
}
}
}
if (found) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
return;
}
}
// VocabMaskLogitsProcessor
if (vocab_mask != nullptr && vocab_mask[word_id] == 0) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
return;
}
// PrefixVocabMaskLogitsProcessor
int batch_id = batch_beam_index / num_beams;
if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + word_id] == 0) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
return;
}
// MinLengthLogitsProcessor
if (word_id == demote_token_id) {
next_token_scores[index] = cub::FpLimits<T>::Lowest();
}
// PresencePenaltyLogitsProcessor
if (presence_mask != nullptr && presence_mask[index] == 1) {
float score = (float)next_token_scores[index] - presence_penalty;
next_token_scores[index] = (T)score;
}
// TemperatureLogitsProcessor
if (temperature != 1.0f) {
float score = (float)(next_token_scores[index]);
next_token_scores[index] = (T)(score / temperature);
}
}
}
}
template <typename T>
void LaunchLogitsProcessKernel(
T* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int* presence_mask,
float presence_penalty,
float temperature,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream) {
int total_elements = batch_size * num_beams * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(
next_token_scores,
vocab_mask,
prefix_vocab_mask,
presence_mask,
presence_penalty,
temperature,
num_beams,
vocab_size,
padded_vocab_size,
total_elements,
demote_token_id,
sequences,
max_sequence_length,
current_sequence_length,
repetition_penalty,
no_repeat_ngram_size);
}
// Instantiation
template void LaunchLogitsProcessKernel(
float* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int* presence_mask,
float presence_penalty,
float temperature,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream);
template void LaunchLogitsProcessKernel(
half* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int* presence_mask,
float presence_penalty,
float temperature,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream);
__global__ void AddProbsKernel(float* log_probs,
float* cum_log_probs,
const int vocab_size,
const int total_elements) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int batch_beam_index = index / vocab_size;
if (index < total_elements)
log_probs[index] += cum_log_probs[batch_beam_index];
}
template <typename T>
void LaunchAddProbsKernel(T* log_probs,
T* cum_log_probs,
const int batch_size,
const int num_beams,
const int vocab_size,
cudaStream_t stream) {
int total_elements = batch_size * num_beams * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
AddProbsKernel<<<gridSize, blockSize, 0, stream>>>(log_probs, cum_log_probs, vocab_size, total_elements);
}
template void LaunchAddProbsKernel(
float* log_probs,
float* cum_log_probs,
const int batch_size,
const int num_beams,
const int vocab_size,
cudaStream_t stream);
template <typename T>
__global__ void UpdateGptInputsKernel(const T* old_mask_data,
T* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index < batch_beam_size * current_length) {
// Update attention mask.
int i = index / current_length;
int j = index % current_length;
mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast<T>(1);
if (next_positions != nullptr) {
// Update sequence length (or next positions).
if (index < batch_beam_size) {
next_positions[index]++;
}
}
}
}
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream) {
assert(current_length > 0);
int total_elements = batch_beam_size * current_length;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
UpdateGptInputsKernel<int32_t><<<gridSize, blockSize, 0, stream>>>(
old_mask_data, mask_data, next_positions, batch_beam_size, current_length);
}
template <typename T>
void GetTempStorageSize(const T *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
} else {
cub::DeviceSegmentedRadixSort::SortPairs(nullptr,
temp_storage_bytes,
d_keys_in,
(T*)nullptr,
d_values_in,
(int*)nullptr,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
}
}
template void GetTempStorageSize(
const float *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
template void GetTempStorageSize(
const half *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
// TODO: merge to one kernel
__global__ void SetupParamsKernel(int* d_values_in,
int* d_offsets,
int batch_size,
int vocab_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
int total_elements = batch_size * vocab_size;
if (index < total_elements) {
d_values_in[index] = index % vocab_size;
}
if (index < batch_size + 1) {
d_offsets[index] = index * vocab_size;
}
}
void LaunchSetupParamsKernel(int* d_values_in,
int* d_offsets,
int batch_size,
int vocab_size,
cudaStream_t stream) {
int total_elements = batch_size * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
SetupParamsKernel<<<gridSize, blockSize, 0, stream>>>(d_values_in,
d_offsets,
batch_size,
vocab_size);
}
template <typename T>
void LaunchSortPairs(void *d_temp_storage,
size_t temp_storage_bytes,
const T *d_keys_in,
T *d_keys_out,
const int *d_values_in,
int *d_values_out,
int num_items,
int num_segments,
int *d_offsets,
cudaStream_t stream,
bool is_descending) {
if (is_descending) {
cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
} else {
cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage,
temp_storage_bytes,
d_keys_in,
d_keys_out,
d_values_in,
d_values_out,
num_items,
num_segments,
d_offsets,
d_offsets + 1,
0,
sizeof(T) * 8,
stream);
}
}
template void LaunchSortPairs(void *d_temp_storage,
size_t temp_storage_bytes,
const float *d_keys_in,
float *d_keys_out,
const int *d_values_in,
int *d_values_out,
int num_items,
int num_segments,
int *d_offsets,
cudaStream_t stream,
bool is_descending);
template void LaunchSortPairs(void *d_temp_storage,
size_t temp_storage_bytes,
const half *d_keys_in,
half *d_keys_out,
const int *d_values_in,
int *d_values_out,
int num_items,
int num_segments,
int *d_offsets,
cudaStream_t stream,
bool is_descending);
template <typename T>
__global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in,
const int* d_sorted_indices,
T* d_logits_in_out,
float top_p_threshold,
float filter_value,
int batch_size,
int vocab_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= batch_size * vocab_size) {
return;
}
int vocab_idx = index % vocab_size;
int batch_id = index / vocab_size;
int start_index = batch_id * vocab_size;
int count = vocab_idx;
float sum = 0.0f;
while (count >= 0) {
sum += d_sorted_logits_in[start_index];
++start_index;
--count;
}
if (sum > top_p_threshold) {
// Shift the indices to the right by one according to the custom implementation.
int shifted_index = index + 1;
if (shifted_index % vocab_size != 0) {
int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index];
d_logits_in_out[original_index] = (T)filter_value;
}
}
}
template <typename T>
__global__ void FilterLogitsKernel(float* d_sorted_logits_in,
const int* d_sorted_indices,
T* d_logits_in_out,
float top_p_threshold,
float filter_value,
int min_tokens_to_keep,
int batch_size,
int vocab_size) {
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= batch_size * vocab_size) {
return;
}
int vocab_idx = index % vocab_size;
int batch_id = index / vocab_size;
int start_index = batch_id * vocab_size;
int count = vocab_idx;
float sum = 0.0f;
// TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities.
while (count >= 0) {
sum += d_sorted_logits_in[start_index];
++start_index;
--count;
}
if (sum <= top_p_threshold) {
if (index % vocab_size + min_tokens_to_keep < vocab_size) {
int original_index = batch_id * vocab_size + d_sorted_indices[index];
d_logits_in_out[original_index] = (T)filter_value;
}
}
}
template <typename T>
void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
const int* d_sorted_indices,
T* d_logits_in_out,
float top_p,
float filter_value,
int min_tokens_to_keep,
int batch_size,
int vocab_size,
cudaStream_t stream,
bool is_descending) {
int total_elements = batch_size * vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
if (is_descending) {
FilterLogitsKernelCustom<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
d_sorted_indices,
d_logits_in_out,
top_p,
filter_value,
batch_size,
vocab_size);
} else {
FilterLogitsKernel<<<gridSize, blockSize, 0, stream>>>(d_sorted_logits_in,
d_sorted_indices,
d_logits_in_out,
1 - top_p,
filter_value,
min_tokens_to_keep,
batch_size,
vocab_size);
}
}
template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
const int* d_sorted_indices,
float* d_logits_in_out,
float top_p,
float filter_value,
int min_tokens_to_keep,
int batch_size,
int vocab_size,
cudaStream_t stream,
bool is_descending);
template void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
const int* d_sorted_indices,
half* d_logits_in_out,
float top_p,
float filter_value,
int min_tokens_to_keep,
int batch_size,
int vocab_size,
cudaStream_t stream,
bool is_descending);
// Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu
template <typename scalar_t, typename accscalar_t>
__global__ void sampleMultinomialOnce(int64_t* dest,
int distributions,
int categories,
scalar_t* sampled,
scalar_t* dist,
int stride_dist, // dist->stride(0)
int stride_categories, // dist->stride(1)
int* d_presence_mask) {
extern __shared__ unsigned char my_smem[];
__shared__ bool found;
__shared__ unsigned foundPos;
accscalar_t *smem = reinterpret_cast<accscalar_t *>(my_smem);
accscalar_t accZero = static_cast<accscalar_t>(0);
scalar_t zero = static_cast<scalar_t>(0);
for (int curDist = blockIdx.x;
curDist < distributions; curDist += gridDim.x) {
// Assume sum = 1 in Top P sampling as the input is softmaxed.
accscalar_t sum = 1;
// Broadcast sum and sample value
if (threadIdx.x == 0) {
// Make sure the sum of our distribution didn't overflow
// CUDA_KERNEL_ASSERT(!_isinf(val));
// CUDA_KERNEL_ASSERT(sum > accZero);
foundPos = 0;
smem[0] = sum;
smem[1] = sampled[curDist];
}
__syncthreads();
sum = smem[0];
scalar_t sample = static_cast<scalar_t>(smem[1]);
__syncthreads();
if (sum == accZero) {
// Choose the first element
if (threadIdx.x == 0) {
dest[curDist] = 0;
}
continue;
}
int chunks = (categories + (int)blockDim.x - 1) / blockDim.x;
accscalar_t prevHighProb = accZero;
found = false;
for (int chunk = 0; chunk < chunks && !found; ++chunk) {
// All threads in bounds load a value
int cat = chunk * blockDim.x + threadIdx.x;
accscalar_t dist_val = cat < categories ?
static_cast<accscalar_t>(dist[curDist * stride_dist + cat * stride_categories]) / sum :
accZero;
smem[threadIdx.x] = dist_val;
__syncthreads();
// Perform an inclusive prefix sum of the shared memory contents
for (int offset = 1; offset < blockDim.x; offset *= 2) {
accscalar_t val = accZero;
if (threadIdx.x >= offset) {
val = smem[threadIdx.x - offset] + smem[threadIdx.x];
}
__syncthreads();
if (threadIdx.x >= offset) {
smem[threadIdx.x] = val;
}
__syncthreads();
}
// Each thread will check to see if the sample falls in its bucket
scalar_t curBucket =
static_cast<scalar_t>(smem[threadIdx.x] + prevHighProb);
scalar_t prevBucket = static_cast<scalar_t>(
threadIdx.x == 0 ? prevHighProb
: smem[threadIdx.x - 1] + prevHighProb);
bool inBucket =
(cat < categories) &&
(!(sample >= curBucket) &&
(sample >= prevBucket) &&
(dist_val > zero));
if (inBucket) {
// We're done; we have the sample
// Torch indices are 1-based
atomicMax(&foundPos, cat);
found = true;
}
// Store the previous scan's high value for future use
prevHighProb = prevHighProb + smem[blockDim.x - 1];
__syncthreads();
}
if (threadIdx.x == 0) {
if (found) {
dest[curDist] = foundPos;
} else {
// This should address a rare bug where we don't select a valid index. This likely occurs when
// due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but
// and our uniform sample is greater than this value. In this case we likely have unitialized memory
// in dest[curDist]. So basically we will loop through the distribution and pick the largest index
// where the distribution is non-zero. This is obviously terribly inefficient, but due to the
// rarity in which this occurs, this should not be an issue.
for (int cat = categories - 1; cat >= 0; --cat) {
if (dist[curDist * stride_dist + cat * stride_categories] > zero) {
dest[curDist] = cat;
break;
}
}
}
}
}
// update presence mask
int index = blockIdx.x * blockDim.x + threadIdx.x;
if (index >= distributions * categories) {
return;
}
int dist_idx = index / categories;
int cat_idx = index % categories;
if (dest[dist_idx] == cat_idx) {
d_presence_mask[index] = 1;
}
}
// Only support n_sample = 1
void TorchMultinomialKernelLauncher(float* d_input,
float* d_sampled,
int64_t* d_output,
int batch_size,
int vocab_size,
int* d_presence_mask,
cudaStream_t stream)
{
// Store the props in class variables
int device;
cudaGetDevice(&device);
cudaDeviceProp props;
cudaGetDeviceProperties(&props, device);
int numSM = props.multiProcessorCount;
int maxThreads = props.maxThreadsPerBlock;
int warp_size = 32; //at::cuda::warp_size();
int requiredWarps = (vocab_size + warp_size - 1) / warp_size;
int requiredThreads = std::min(maxThreads, requiredWarps * warp_size);
int requiredShared = requiredThreads * sizeof(float);
dim3 block(requiredThreads);
dim3 grid(std::min(batch_size, numSM * 4));
sampleMultinomialOnce<float, float>
<<<grid, block, requiredShared, stream>>>(d_output,
batch_size,
vocab_size,
d_sampled,
d_input,
vocab_size,
1,
d_presence_mask);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,114 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include <stdint.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
namespace onnxruntime {
namespace contrib {
namespace cuda {
void LaunchInitKernel(
float* beam_scores,
int batch_size,
int num_beams,
cudaStream_t stream);
template <typename T>
void LaunchAddProbsKernel(T* log_probs,
T* cum_log_probs,
const int batch_size,
const int num_beams,
const int vocab_size,
cudaStream_t stream);
template <typename T>
void LaunchLogitsProcessKernel(
T* next_token_scores,
const int* vocab_mask,
const int* prefix_vocab_mask,
int* presence_mask,
float presence_penalty,
float temperature,
int batch_size,
int num_beams,
int vocab_size,
int padded_vocab_size,
int demote_token_id,
int32_t* sequences,
int max_sequence_length,
int current_sequence_length,
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream);
void LaunchNextTokenKernel(const int64_t* next_token_indices,
int32_t* next_indices,
int32_t* next_tokens,
int batch_size,
int top_k,
int vocab_size,
cudaStream_t stream);
void LaunchUpdateGptKernel(const int32_t* old_mask_data,
int32_t* mask_data,
int32_t* next_positions,
int batch_beam_size,
int current_length,
cudaStream_t stream);
template <typename T>
void GetTempStorageSize(const T *d_keys_in,
const int* d_values_in,
int* d_offsets,
int num_items,
int num_segments,
cudaStream_t stream,
bool is_descending,
size_t& temp_storage_bytes);
void LaunchSetupParamsKernel(int* d_values_in,
int* d_offsets,
int batch_size,
int vocab_size,
cudaStream_t stream);
template <typename T>
void LaunchSortPairs(void *d_temp_storage,
size_t temp_storage_bytes,
const T *d_keys_in,
T *d_keys_out,
const int *d_values_in,
int *d_values_out,
int num_items,
int num_segments,
int *d_offsets,
cudaStream_t stream,
bool is_descending);
template <typename T>
void LaunchFilterLogitsKernel(float* d_sorted_logits_in,
const int* d_sorted_indices,
T* d_logits_in_out,
float top_p,
float filter_value,
int min_tokens_to_keep,
int batch_size,
int vocab_size,
cudaStream_t stream,
bool is_descending);
void TorchMultinomialKernelLauncher(float* d_input,
float* d_sampled,
int64_t* d_output,
int batch_size,
int vocab_size,
int* d_presence_mask,
cudaStream_t stream);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -4,19 +4,25 @@
#include <utility>
#include <memory>
#include "core/providers/shared_library/provider_api.h"
#include "core/providers/cuda/cuda_kernel.h"
#include "core/providers/cuda/math/topk_impl.h"
#include "core/providers/cuda/math/softmax.h"
#include "core/providers/cuda/shared_inc/accumulation_type.h"
#include "core/framework/ort_value.h"
#include "contrib_ops/cuda/bert/transformer_cuda_common.h"
#include <cuda_runtime.h>
#include "contrib_ops/cuda/transformers/beam_search_impl.h"
#include "contrib_ops/cuda/transformers/generation_cuda_impl.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
#include "contrib_ops/cuda/transformers/beam_search_topk.h"
#include "core/providers/cuda/nvtx_profile.h"
#include "core/providers/cuda/nvtx_profile_context.h"
#include "sampling_cuda_helper.h"
#ifdef DEBUG_GENERATION
#include <iostream>
#endif
namespace onnxruntime {
namespace concurrency {
@ -128,7 +134,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
cudaStream_t stream = ort_stream ? static_cast<cudaStream_t>(ort_stream->GetHandle()) : nullptr;
auto pinned_buffer = IAllocator::MakeUniquePtr<void>(pinned_allocator, total_bytes);
char* pinned_data = static_cast<char*>(pinned_buffer.get());
// Copy tensors to one pinned memory buffer (so that we only need copy to GPU once)
char* destination = pinned_data;
for (auto& input : inputs) {
@ -145,16 +150,13 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
"AddToFeeds: An implementation for the input type ",
dataType, " is not supported yet");
}
// Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs.
destination += bytes;
}
}
if (!buffer) {
buffer = provider->GetScratchBuffer<char>(total_bytes, ort_stream, WaitCudaNotificationOnDevice);
}
char* gpu_data = buffer.get();
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream));
@ -164,7 +166,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider,
CUDA_RETURN_IF_ERROR(cudaEventCreate(&isCopyDone));
CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream));
CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone));
// TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call.
const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info();
for (auto& input : inputs) {
@ -253,7 +254,7 @@ Status ProcessLogits(const OrtValue& logits, //
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* ort_stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
@ -363,6 +364,9 @@ Status ProcessLogits(const OrtValue& logits, //
next_token_scores.data(),
parameters->vocab_mask.data(),
step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only.
nullptr, // parameters->presence_mask.data(),
parameters->presence_penalty,
parameters->temperature,
parameters->batch_size,
parameters->num_beams,
parameters->vocab_size,
@ -504,11 +508,13 @@ template <typename T>
Status GreedySearchProcessLogits(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state, // buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
@ -519,7 +525,6 @@ Status GreedySearchProcessLogits(
#endif
ORT_UNUSED_PARAMETER(logits_processors);
#ifndef DEBUG_GENERATION
ORT_UNUSED_PARAMETER(dumper);
#endif
@ -596,6 +601,13 @@ Status GreedySearchProcessLogits(
cudaMemcpyHostToDevice, cuda_stream));
}
// Copy parameters->presence_mask to sampling_state->presence_mask
gsl::span<int>& presence_mask = sampling_state->d_presence_mask;
if (step == 1 && parameters->presence_mask.data() != nullptr) {
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(),
sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream));
}
// TODO(hasesh): Can we avoid the const_cast by changing the interface of
// GreedySearchProcessLogits() to take in a non-const OrtValue for logits
// as this is the only place we will ever use the logits and it may be reasonable
@ -605,11 +617,15 @@ Status GreedySearchProcessLogits(
: reinterpret_cast<CudaT*>(next_token_scores.data()),
parameters->vocab_mask.data(),
step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only.
parameters->presence_mask.data() ? presence_mask.data() : nullptr,
parameters->presence_penalty,
parameters->temperature,
parameters->batch_size,
parameters->num_beams,
parameters->vocab_size,
is_reuse_logits_buffer ? padded_vocab_size : parameters->vocab_size,
(parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1,
(parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length)
? parameters->eos_token_id : -1,
reinterpret_cast<int32_t*>(sequences_buffer.get()),
parameters->max_length,
current_sequence_length,
@ -629,6 +645,19 @@ Status GreedySearchProcessLogits(
// TODO(wy): support output_scores in greedy search
ORT_UNUSED_PARAMETER(output_scores);
if (do_sampling) {
ORT_RETURN_IF_ERROR(SamplingCudaHelper::Sample(allocator,
cuda_stream,
next_token_scores,
sampling_state,
greedy_state,
parameters,
step,
dumper));
return Status::OK();
}
// next_tokens = torch.argmax(scores, dim=-1)
int64_t next_token_scores_dims[] = {static_cast<int64_t>(batch_size),
is_reuse_logits_buffer ? padded_vocab_size : vocab_size};
@ -657,8 +686,8 @@ Status GreedySearchProcessLogits(
*topk_scores, *topk_indices));
#ifdef DEBUG_GENERATION
dumper->Print("topk_scores", *(topk_scores.get()));
dumper->Print("topk_indices", *(topk_indices.get()));
dumper->Print("topk_scores", *(topk_scores.get()));
dumper->Print("topk_indices", *(topk_indices.get()));
#endif
const int64_t* next_token_indices = topk_indices->Data<int64_t>();
@ -996,7 +1025,7 @@ template Status ProcessLogits<float>(
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
transformers::IBeamScorer* beam_scorer,
const transformers::IBeamSearchParameters* parameters,
const transformers::IGenerationParameters* parameters,
int step,
Stream* ort_stream,
const transformers::IConsoleDumper* dumper);
@ -1004,11 +1033,13 @@ template Status ProcessLogits<float>(
template Status GreedySearchProcessLogits<float>(
const OrtValue& logits,
transformers::IGreedySearchState<float>* greedy_state,
transformers::ISamplingState<float>* sampling_state,
transformers::ISequences* sequences,
AllocatorPtr& allocator,
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
const transformers::IBeamSearchParameters* parameters,
const transformers::IGenerationParameters* parameters,
bool do_sampling,
int step,
Stream* ort_stream,
const transformers::IConsoleDumper* dumper);
@ -1063,7 +1094,7 @@ template Status ProcessLogits<MLFloat16>(
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
transformers::IBeamScorer* beam_scorer,
const transformers::IBeamSearchParameters* parameters,
const transformers::IGenerationParameters* parameters,
int step,
Stream* ort_stream,
const transformers::IConsoleDumper* dumper);
@ -1071,11 +1102,13 @@ template Status ProcessLogits<MLFloat16>(
template Status GreedySearchProcessLogits<MLFloat16>(
const OrtValue& logits,
transformers::IGreedySearchState<MLFloat16>* greedy_state,
transformers::ISamplingState<MLFloat16>* sampling_state,
transformers::ISequences* sequences,
AllocatorPtr& allocator,
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
const transformers::IBeamSearchParameters* parameters,
const transformers::IGenerationParameters* parameters,
bool do_sampling,
int step,
Stream* ort_stream,
const transformers::IConsoleDumper* dumper);

View file

@ -55,7 +55,7 @@ Status ProcessLogits(const OrtValue& logits, //
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper); // tensor dumper
@ -63,11 +63,13 @@ Status ProcessLogits(const OrtValue& logits, //
template <typename T>
Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState<T>* greedy_state, // state
transformers::ISamplingState<T>* sampling_state,// sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
const transformers::IBeamSearchParameters* parameters, // parameters
const transformers::IGenerationParameters* parameters, // parameters
bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper); // tensor dumper

View file

@ -0,0 +1,68 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/cuda_execution_provider.h"
#include "contrib_ops/cuda/transformers/sampling.h"
#include "contrib_ops/cuda/transformers/generation_device_helper.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
ONNX_OPERATOR_KERNEL_EX(
Sampling,
kMSDomain,
1,
kCudaExecutionProvider,
(*KernelDefBuilder::Create())
.InputMemoryType(OrtMemTypeCPUInput, 0) // 'input_ids' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 1) // 'max_length' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 2) // 'min_length' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 3) // 'repetition_penalty' needs to be on CPU
.InputMemoryType(OrtMemTypeCPUInput, 6) // 'custom_attention_mask' needs to be on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU
.OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'logits_to_debug' output on CPU
.TypeConstraint("T", {DataTypeImpl::GetTensorType<float>(),
DataTypeImpl::GetTensorType<MLFloat16>()}),
Sampling);
transformers::CudaTensorConsoleDumper g_cuda_dumper_sampling;
Sampling::Sampling(const OpKernelInfo& info)
: onnxruntime::contrib::transformers::Sampling(info) {
SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds,
GenerationCudaDeviceHelper::TopK,
GenerationCudaDeviceHelper::DeviceCopy<float>,
GenerationCudaDeviceHelper::GreedySearchProcessLogits<float>,
GenerationCudaDeviceHelper::GreedySearchProcessLogits<MLFloat16>,
GenerationCudaDeviceHelper::InitGreedyState<float>,
GenerationCudaDeviceHelper::InitGreedyState<MLFloat16>);
SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds<float>,
GenerationCudaDeviceHelper::UpdateGptFeeds<MLFloat16>);
SetConsoleDumper(&g_cuda_dumper_sampling);
}
Status Sampling::ComputeInternal(OpKernelContext* context) const {
return onnxruntime::contrib::transformers::Sampling::Compute(context);
}
Status Sampling::Compute(OpKernelContext* context) const {
auto s = ComputeInternal(context);
if (s.IsOK()) {
auto err = cudaGetLastError();
if (err != cudaSuccess) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err));
}
}
return s;
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,26 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "contrib_ops/cpu/transformers/sampling.h"
namespace onnxruntime {
class SessionState;
namespace contrib {
namespace cuda {
class Sampling final : public onnxruntime::contrib::transformers::Sampling {
public:
Sampling(const OpKernelInfo& info);
Status Compute(OpKernelContext* context) const override;
private:
Status ComputeInternal(OpKernelContext* context) const;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,186 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cpu/tensor/utils.h"
#include "contrib_ops/cpu/transformers/generation_shared.h"
#include "core/providers/cuda/math/softmax.h"
#ifdef DEBUG_GENERATION
#include <iostream>
#endif
namespace onnxruntime {
namespace contrib {
namespace SamplingCudaHelper {
template <typename T>
Status Sample(AllocatorPtr& allocator,
cudaStream_t cuda_stream,
gsl::span<T>& next_token_scores,
transformers::ISamplingState<T>* sampling_state,
transformers::IGreedySearchState<T>* greedy_state,
const transformers::IGenerationParameters* parameters,
int step,
const transformers::IConsoleDumper* dumper) {
ORT_UNUSED_PARAMETER(dumper);
typedef typename ToCudaType<T>::MappedType CudaT;
gsl::span<int>& d_index_in = sampling_state->d_index_in;
gsl::span<int>& d_offset = sampling_state->d_offset;
BufferUniquePtr& storage_buffer = sampling_state->storage_buffer;
size_t& temp_storage_bytes = sampling_state->temp_storage_bytes;
bool is_descending = parameters->custom_sampling;
if (step == 1) {
cuda::GetTempStorageSize<CudaT>(reinterpret_cast<CudaT*>(next_token_scores.data()),
d_index_in.data(),
d_offset.data(),
parameters->batch_size * parameters->vocab_size,
parameters->batch_size,
cuda_stream,
is_descending,
temp_storage_bytes);
cuda::LaunchSetupParamsKernel(d_index_in.data(),
d_offset.data(),
parameters->batch_size,
parameters->vocab_size,
cuda_stream);
#ifdef DEBUG_GENERATION
dumper->Print("d_offset_buffer", d_offset.data(), parameters->batch_size + 1, 1);
#endif
void* temp_storage = allocator->Alloc(sampling_state->temp_storage_bytes);
BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator));
storage_buffer = std::move(temp_storage_buffer);
}
gsl::span<T>& d_sorted_score = sampling_state->d_sorted_score;
gsl::span<int>& d_index_out = sampling_state->d_index_out;
#ifdef DEBUG_GENERATION
dumper->Print("temp_storage_bytes", sampling_state->temp_storage_bytes, true);
#endif
cuda::LaunchSortPairs<CudaT>(storage_buffer.get(),
temp_storage_bytes,
reinterpret_cast<CudaT*>(next_token_scores.data()),
reinterpret_cast<CudaT*>(d_sorted_score.data()),
d_index_in.data(),
d_index_out.data(),
parameters->batch_size * parameters->vocab_size,
parameters->batch_size,
d_offset.data(),
cuda_stream,
is_descending);
#ifdef DEBUG_GENERATION
dumper->Print("d_sorted_score_buffer",
reinterpret_cast<T*>(d_sorted_score.data()),
parameters->batch_size,
parameters->vocab_size);
dumper->Print("d_index_buffer_in", d_index_in.data(), parameters->batch_size, parameters->vocab_size);
dumper->Print("d_index_buffer_out", d_index_out.data(), parameters->batch_size, parameters->vocab_size);
#endif
gsl::span<float>& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score;
dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
d_sorted_softmaxed_score.data(),
reinterpret_cast<CudaT*>(d_sorted_score.data()),
parameters->vocab_size,
parameters->vocab_size,
parameters->vocab_size,
parameters->batch_size);
#ifdef DEBUG_GENERATION
dumper->Print("d_sorted_softmaxed_score_buffer",
d_sorted_softmaxed_score.data(),
parameters->batch_size,
parameters->vocab_size);
#endif
cuda::LaunchFilterLogitsKernel<CudaT>(d_sorted_softmaxed_score.data(),
d_index_out.data(),
reinterpret_cast<CudaT*>(next_token_scores.data()),
parameters->top_p,
parameters->filter_value,
parameters->min_tokens_to_keep,
parameters->batch_size,
parameters->vocab_size,
cuda_stream,
is_descending);
#ifdef DEBUG_GENERATION
dumper->Print("next_token_scores after filtering logits",
reinterpret_cast<T*>(next_token_scores.data()),
parameters->batch_size,
parameters->vocab_size);
#endif
gsl::span<float>& d_softmaxed_score = sampling_state->d_softmaxed_score;
dispatch_blockwise_softmax_forward<CudaT, float, float, false>(cuda_stream,
d_softmaxed_score.data(),
reinterpret_cast<CudaT*>(next_token_scores.data()),
parameters->vocab_size,
parameters->vocab_size,
parameters->vocab_size,
parameters->batch_size);
#ifdef DEBUG_GENERATION
dumper->Print("d_softmaxed_score_buffer",
d_softmaxed_score.data(),
parameters->batch_size,
parameters->vocab_size);
#endif
// Multinomial sampling
gsl::span<float>& d_sampled = sampling_state->d_sampled;
gsl::span<float>& h_sampled_all = sampling_state->h_sampled_all;
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(),
h_sampled_all.data() + (step - 1) * parameters->batch_size,
sizeof(float) * parameters->batch_size,
cudaMemcpyHostToDevice,
cuda_stream));
#ifdef DEBUG_GENERATION
dumper->Print("d_sampled", d_sampled.data(), parameters->batch_size, 1);
#endif
gsl::span<int64_t>& d_indices = sampling_state->d_indices;
gsl::span<int>& presence_mask = sampling_state->d_presence_mask;
cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(),
d_sampled.data(),
d_indices.data(),
parameters->batch_size,
parameters->vocab_size,
presence_mask.data(),
cuda_stream);
#ifdef DEBUG_GENERATION
dumper->Print("d_indices", d_indices.data(), parameters->batch_size, 1);
#endif
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(),
sampling_state->d_indices.data(),
greedy_state->next_tokens_cpu.size_bytes(),
cudaMemcpyDeviceToHost,
cuda_stream));
CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state->h_softmaxed_score.data(),
sampling_state->d_softmaxed_score.data(),
sampling_state->h_softmaxed_score.size_bytes(),
cudaMemcpyDeviceToHost,
cuda_stream));
CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream));
return Status::OK();
}
} // namespace SamplingCudaHelper
} // namespace contrib
} // namespace onnxruntime

View file

@ -510,6 +510,13 @@ void GreedySearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) {
sequences_shape.add_dim()->set_dim_value(batch_size);
sequences_shape.add_dim()->set_dim_value(max_length_value);
updateOutputShape(ctx, 0, sequences_shape);
if (ctx.getNumOutputs() > 1) {
ONNX_NAMESPACE::TensorShapeProto logits_to_debug_shape;
logits_to_debug_shape.add_dim()->set_dim_value(batch_size);
logits_to_debug_shape.add_dim();
updateOutputShape(ctx, 1, logits_to_debug_shape);
}
}
constexpr const char* Gelu_ver1_doc =
@ -1114,6 +1121,48 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1,
GreedySearchShapeInference(ctx);
}));
ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1,
OpSchema()
.SetDoc("Greedy Sampling for text generation.")
.Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT)
.Attr("pad_token_id", "The id of the padding token", AttributeProto::INT)
.Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast<int64_t>(-1))
.Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("temperature", "The value used to module the next token probabilities.", AttributeProto::FLOAT, 1.0f)
.Attr("top_p",
"If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.",
AttributeProto::FLOAT, 0.0f)
.Attr("filter_value", "All filtered values will be set to this float value.", AttributeProto::FLOAT, -1e20f)
.Attr("min_tokens_to_keep", "Minimumber of tokens we keep per batch example in the output.", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("presence_penalty", "Presence penalty for custom sampling", AttributeProto::FLOAT, 0.0f)
.Attr("custom", "If 1 custom sampling logic", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("model_type", "Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast<int64_t>(0))
.Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE)
.Attr("init_decoder",
"The subgraph for the first decoding run. It will be called once before `decoder` subgraph. "
"This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs",
AttributeProto::GRAPH, OPTIONAL_VALUE)
.Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH)
.Attr("vocab_size",
"Size of the vocabulary. "
"If not provided, it will be inferred from the decoder subgraph's output shape",
AttributeProto::INT, static_cast<int64_t>(-1))
.Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I")
.Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I")
.Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional)
.Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional)
.Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional)
.Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional)
.Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional)
.Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I")
.Output(1, "filtered_logits", "Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional)
.TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
GreedySearchShapeInference(ctx);
}));
ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1,
OpSchema()
.Input(0, "X", "input", "T")

View file

@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer);
@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul)>());
fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer)>());

View file

@ -33,6 +33,7 @@
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
#include "contrib_ops/cpu/transformers/beam_search.h"
#include "contrib_ops/cpu/transformers/greedy_search.h"
#include "contrib_ops/cpu/transformers/sampling.h"
#ifdef ENABLE_ATEN
#include "contrib_ops/cpu/aten_ops/aten_op.h"
#endif
@ -248,6 +249,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU {
subgraph_session_state);
}
void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) override { p->contrib::transformers::Sampling::Init(info); }
Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); }
Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); }
#ifdef ENABLE_ATEN
Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); }
#endif

View file

@ -9,6 +9,7 @@ class AttentionBase;
namespace transformers {
class BeamSearch;
class GreedySearch;
class Sampling;
}
} // namespace contrib
@ -174,6 +175,10 @@ struct ProviderHostCPU {
const std::string& attribute_name,
const SessionState& subgraph_session_state) = 0;
virtual void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) = 0;
virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0;
virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0;
#ifdef ENABLE_ATEN
virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0;
#endif

View file

@ -207,13 +207,13 @@ template <typename T, typename IndexType = int64_t>
using EigenVector = Eigen::TensorMap<Eigen::Tensor<T, 1, Eigen::RowMajor, IndexType>>;
template <typename OutputType>
static Status MultinomialCompute(OpKernelContext* ctx,
const Tensor& X,
const int64_t batch_size,
const int64_t num_classes,
const int64_t num_samples,
std::default_random_engine& generator,
Tensor& Y) {
Status MultinomialComputeShared(AllocatorPtr& alloc,
const Tensor& X,
const int64_t batch_size,
const int64_t num_classes,
const int64_t num_samples,
std::default_random_engine& generator,
Tensor& Y) {
if (!utils::HasType<EnabledMultinomialOutputTypes, OutputType>()) {
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build.");
}
@ -227,8 +227,6 @@ static Status MultinomialCompute(OpKernelContext* ctx,
Matrix<OutputType> output = Matrix<OutputType>(Y.MutableData<OutputType>(), Y_dims);
// BEGIN create temporary tensor
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
auto cdf_data = static_cast<double*>(alloc->Alloc(SafeInt<size_t>(sizeof(double)) * num_classes));
BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(std::move(alloc)));
Eigen::array<int64_t, 1> cdf_dims = {{num_classes}};
@ -271,6 +269,20 @@ static Status MultinomialCompute(OpKernelContext* ctx,
return Status::OK();
}
template <typename OutputType>
static Status MultinomialCompute(OpKernelContext* ctx,
const Tensor& X,
const int64_t batch_size,
const int64_t num_classes,
const int64_t num_samples,
std::default_random_engine& generator,
Tensor& Y) {
// BEGIN create temporary tensor
AllocatorPtr alloc;
ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc));
return MultinomialComputeShared<OutputType>(alloc, X, batch_size, num_classes, num_samples, generator, Y);
}
Status Multinomial::Compute(OpKernelContext* ctx) const {
const auto* tensor_pointer = ctx->Input<Tensor>(0);
if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch");
@ -408,4 +420,12 @@ void GenerateData(std::default_random_engine& generator, TDistribution distribut
}
}
template Status MultinomialComputeShared<int64_t>(AllocatorPtr& alloc,
const Tensor& X,
const int64_t batch_size,
const int64_t num_classes,
const int64_t num_samples,
std::default_random_engine& generator,
Tensor& Y);
} // namespace onnxruntime

View file

@ -13,6 +13,15 @@
namespace onnxruntime {
template <typename OutputType>
Status MultinomialComputeShared(AllocatorPtr& alloc,
const Tensor& X,
const int64_t batch_size,
const int64_t num_classes,
const int64_t num_samples,
std::default_random_engine& generator,
Tensor& Y);
class RandomNormal final : public OpKernel {
public:
RandomNormal(const OpKernelInfo& info) : OpKernel(info) {

View file

@ -33,6 +33,7 @@
#include "contrib_ops/cpu/bert/longformer_attention_base.h"
#include "contrib_ops/cpu/transformers/beam_search.h"
#include "contrib_ops/cpu/transformers/greedy_search.h"
#include "contrib_ops/cpu/transformers/sampling.h"
#ifdef ENABLE_ATEN
#include "contrib_ops/cpu/aten_ops/aten_op.h"
#endif
@ -601,6 +602,15 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
return g_host_cpu.GreedySearch__SetupSubgraphExecutionInfo(this, session_state, attribute_name,
subgraph_session_state);
}
void Sampling::Init(const OpKernelInfo& info) { g_host_cpu.Sampling__Init(this, info); }
Status Sampling::Compute(OpKernelContext* ctx) const { return g_host_cpu.Sampling__Compute(this, ctx); }
Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name,
const SessionState& subgraph_session_state) {
return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); }
} // namespace transformers
#ifdef ENABLE_ATEN

View file

@ -25,6 +25,9 @@ Example 4: convert MT5 model with external data file like mt5-base-beamsearch.on
Example 5: convert gpt2 model with greedy search:
python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1
Example 6: convert gpt2 model with sampling:
python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6
"""
import argparse
@ -72,6 +75,7 @@ logger = logging.getLogger("")
class GenerationType(Enum):
BEAMSEARCH = "beam_search"
GREEDYSEARCH = "greedy_search"
SAMPLING = "sampling"
def __str__(self):
return self.value
@ -261,6 +265,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
)
model_group.set_defaults(custom_attention_mask=False)
model_group.add_argument(
"--presence_mask",
required=False,
action="store_true",
help="Presence mask for custom sampling",
)
model_group.set_defaults(presence_mask=False)
beam_parameters_group = parser.add_argument_group(
"Beam search parameters not stored in the output model, for testing parity and performance"
)
@ -295,6 +307,54 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace:
help="Positive. >1 to penalize and <1 to encourage.",
)
beam_parameters_group.add_argument(
"--temperature",
type=float,
required=False,
default=1.0,
help="The value used to module the next token probabilities.",
)
beam_parameters_group.add_argument(
"--top_p",
type=float,
required=False,
default=1.0,
help="Top P for sampling",
)
beam_parameters_group.add_argument(
"--filter_value",
type=float,
required=False,
default=-float("Inf"),
help="Filter value for Top P sampling",
)
beam_parameters_group.add_argument(
"--min_tokens_to_keep",
type=int,
required=False,
default=1,
help="Minimumber of tokens we keep per batch example in the output.",
)
beam_parameters_group.add_argument(
"--presence_penalty",
type=float,
required=False,
default=0.0,
help="presence penalty for custom sampling.",
)
beam_parameters_group.add_argument(
"--custom",
type=int,
required=False,
default=0,
help="If 1 customized top P logic is applied",
)
beam_parameters_group.add_argument(
"--vocab_size",
type=int,
@ -1201,18 +1261,20 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
args (argparse.Namespace): arguments parsed from command line
"""
is_gpt2: bool = args.model_type == "gpt2"
is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH
is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH
is_sampling: bool = generation_type == GenerationType.SAMPLING
past_present_share_buffer: bool = args.past_present_share_buffer and is_greedysearch
logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}")
if is_greedysearch:
if is_greedysearch or is_sampling:
if not is_gpt2:
raise NotImplementedError("Currently only gpt2 with greedy search is supported")
raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported")
if args.output_sequences_scores:
raise NotImplementedError("output_sequences_scores currently is not supported in greedy search")
raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling")
if args.output_token_scores:
raise NotImplementedError("output_token_scores currently is not supported in greedy search")
raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling")
if is_gpt2:
if args.decoder_onnx and os.path.exists(args.decoder_onnx):
@ -1243,7 +1305,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
args.pad_vocab_size
and args.precision == Precision.FLOAT16
and is_gpt2
and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
and (is_beamsearch or is_greedysearch or is_sampling)
):
logger.info(
f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. "
@ -1257,11 +1319,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
gpt2_init_decoder_generated = False
gpt2_init_decoder_onnx_path = None
if (
args.separate_gpt2_decoder_for_init_run
and is_gpt2
and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH)
):
if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling):
logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ")
gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format(
@ -1331,8 +1389,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
else:
verify_t5_decoder_subgraph(decoder_model.graph, args.precision)
inputs = (
[
inputs = None
if is_beamsearch:
inputs = [
"input_ids",
"max_length",
"min_length",
@ -1341,14 +1400,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
"length_penalty",
"repetition_penalty",
]
if not is_greedysearch
else [
elif is_greedysearch or is_sampling:
inputs = [
"input_ids",
"max_length",
"min_length",
"repetition_penalty",
]
)
if args.vocab_mask:
inputs.append("vocab_mask")
@ -1365,6 +1423,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
else:
inputs.append("")
if is_sampling and args.custom and args.presence_mask:
inputs.append("presence_mask")
outputs = ["sequences"]
if args.output_sequences_scores:
outputs.append("sequences_scores")
@ -1373,40 +1434,60 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores"
outputs.append("scores")
node = (
onnx.helper.make_node(
node = None
if is_beamsearch:
node = onnx.helper.make_node(
"BeamSearch",
inputs=inputs,
outputs=outputs,
name=f"BeamSearch_{args.model_type}",
)
if not is_greedysearch
else onnx.helper.make_node(
elif is_greedysearch:
node = onnx.helper.make_node(
"GreedySearch",
inputs=inputs,
outputs=outputs,
name=f"GreedySearch_{args.model_type}",
)
)
elif is_sampling:
node = onnx.helper.make_node(
"Sampling",
inputs=inputs,
outputs=outputs,
name=f"Sampling_{args.model_type}",
)
node.domain = "com.microsoft"
attr_to_extend = (
[
attr_to_extend = None
if is_beamsearch:
attr_to_extend = [
onnx.helper.make_attribute("eos_token_id", eos_token_id),
onnx.helper.make_attribute("pad_token_id", pad_token_id),
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0),
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
]
if not is_greedysearch
else [
elif is_greedysearch:
attr_to_extend = [
onnx.helper.make_attribute("eos_token_id", eos_token_id),
onnx.helper.make_attribute("pad_token_id", pad_token_id),
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
]
)
elif is_sampling:
attr_to_extend = [
onnx.helper.make_attribute("eos_token_id", eos_token_id),
onnx.helper.make_attribute("pad_token_id", pad_token_id),
onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1),
onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
onnx.helper.make_attribute("temperature", args.temperature),
onnx.helper.make_attribute("top_p", args.top_p),
onnx.helper.make_attribute("filter_value", args.filter_value),
onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep),
onnx.helper.make_attribute("custom", args.custom),
onnx.helper.make_attribute("presence_penalty", args.presence_penalty),
]
# Explicitly pass in the vocab size via an attribute
if logits_matmul_weight_padded:
@ -1481,8 +1562,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
graph_inputs = (
[
graph_inputs = None
if is_beamsearch:
graph_inputs = [
input_ids,
max_length,
min_length,
@ -1491,14 +1573,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
length_penalty,
repetition_penalty,
]
if not is_greedysearch
else [
elif is_greedysearch or is_sampling:
graph_inputs = [
input_ids,
max_length,
min_length,
repetition_penalty,
]
)
if args.vocab_mask:
vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size])
@ -1516,37 +1597,41 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati
)
graph_inputs.append(attention_mask)
if args.custom and args.presence_mask:
presence_mask = onnx.helper.make_tensor_value_info(
"presence_mask", TensorProto.INT32, ["batch_size", vocab_size]
)
graph_inputs.append(presence_mask)
# graph outputs
sequences = (
onnx.helper.make_tensor_value_info(
sequences = None
if is_beamsearch:
sequences = onnx.helper.make_tensor_value_info(
"sequences",
TensorProto.INT32,
["batch_size", "num_return_sequences", "max_length"],
)
if not is_greedysearch
else onnx.helper.make_tensor_value_info(
elif is_greedysearch or is_sampling:
sequences = onnx.helper.make_tensor_value_info(
"sequences",
TensorProto.INT32,
["batch_size", "max_length"],
)
)
sequences_scores = onnx.helper.make_tensor_value_info(
"sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
)
scores = onnx.helper.make_tensor_value_info(
"scores",
TensorProto.FLOAT,
["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
)
graph_outputs = [sequences]
if args.output_sequences_scores:
sequences_scores = onnx.helper.make_tensor_value_info(
"sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"]
)
graph_outputs.append(sequences_scores)
if args.output_token_scores:
scores = onnx.helper.make_tensor_value_info(
"scores",
TensorProto.FLOAT,
["max_length - sequence_length", "batch_size", "num_beams", vocab_size],
)
graph_outputs.append(scores)
new_graph = onnx.helper.make_graph(
@ -2085,6 +2170,10 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None
is_greedy = args.num_beams == 1 and args.num_return_sequences == 1
if args.model_type == "gpt2" and is_greedy:
if args.top_p > 0.0 and args.top_p < 1.0:
convert_generation_model(args, GenerationType.SAMPLING)
logger.info("The test for gpt2_sampling onnx model is not implemented yet")
return
convert_generation_model(args, GenerationType.GREEDYSEARCH)
else:
convert_generation_model(args)