diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 2784a1c431..23ea679519 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -75,10 +75,13 @@ set(contrib_ops_excluded_files
"transformers/beam_search.h"
"transformers/generation_device_helper.cc"
"transformers/generation_device_helper.h"
- "transformers/beam_search_impl.cu"
- "transformers/beam_search_impl.h"
+ "transformers/generation_cuda_impl.cu"
+ "transformers/generation_cuda_impl.h"
"transformers/greedy_search.cc"
"transformers/greedy_search.h"
+ "transformers/sampling.cc"
+ "transformers/sampling.h"
+ "transformers/sampling_cuda_helper.h"
"transformers/dump_cuda_tensor.cc"
"transformers/dump_cuda_tensor.h"
"conv_transpose_with_dynamic_pads.cc"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 76e67a31c9..4fc0d4c3e1 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -74,6 +74,7 @@ Do not modify directly.*
* com.microsoft.RestorePadding
* com.microsoft.Rfft
* com.microsoft.SampleOp
+ * com.microsoft.Sampling
* com.microsoft.SkipLayerNormalization
* com.microsoft.Snpe
* com.microsoft.SparseToDenseMatMul
@@ -3810,6 +3811,89 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.Sampling**
+
+ Greedy Sampling for text generation.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- custom : int
+- If 1 custom sampling logic
+- decoder : graph (required)
+- Decoder subgraph to execute in a loop.
+- decoder_start_token_id : int
+- The id of the token that indicates decoding starts.
+- encoder : graph
+- The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+- eos_token_id : int (required)
+- The id of the end-of-sequence token
+- filter_value : float
+- All filtered values will be set to this float value.
+- init_decoder : graph
+- The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs
+- min_tokens_to_keep : int
+- Minimumber of tokens we keep per batch example in the output.
+- model_type : int
+- Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart
+- no_repeat_ngram_size : int
+- no repeat ngrams size
+- pad_token_id : int (required)
+- The id of the padding token
+- presence_penalty : float
+- Presence penalty for custom sampling
+- temperature : float
+- The value used to module the next token probabilities.
+- top_p : float
+- If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
+- vocab_size : int
+- Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
+
+
+#### Inputs (2 - 8)
+
+
+- input_ids : I
+- The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)
+- max_length : I
+- The maximum length of the sequence to be generated. Shape is (1)
+- min_length (optional) : I
+- The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+- repetition_penalty (optional) : T
+- The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+- vocab_mask (optional) : I
+- Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+- prefix_vocab_mask (optional) : I
+- Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
+- attention_mask (optional) : I
+- Custom attention mask. Shape is (batch_size, sequence_length)
+- presence_mask (optional) : I
+- Presence penalty mask. Shape is (batch_size, vocab_size)
+
+
+#### Outputs (1 - 2)
+
+
+- sequences : I
+- Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)
+- filtered_logits (optional) : T
+- Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)
+
+
+#### Type Constraints
+
+
+- T : tensor(float)
+- Constrain input and output types to float tensors.
+- I : tensor(int32)
+- Constrain to integer types
+
+
+
### **com.microsoft.SkipLayerNormalization**
Skip and Layer Normalization Fusion
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index e1789154c7..e86305de29 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -438,6 +438,7 @@ Do not modify directly.*
|QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
|Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)|
|SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)|
+|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)|
|SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)|
@@ -797,6 +798,7 @@ Do not modify directly.*
|RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)|
|RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)|
|SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)|
|TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
index 3f459dff99..a04ef0d71b 100644
--- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc
@@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch);
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range);
@@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
index 1cb2c2050b..868356f70a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc
@@ -61,16 +61,17 @@ void BeamSearch::Init(const OpKernelInfo& info) {
parameters_.ParseFromAttributes(info);
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
- ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
- parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
+ ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt ||
+ parameters_.model_type == IGenerationParameters::kModelTypeT5);
ONNX_NAMESPACE::GraphProto proto;
- if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
+
+ if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK());
}
- if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
@@ -87,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
- if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
@@ -113,8 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state,
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}
-
- } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) {
+ } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) {
if (attribute_name == "encoder") {
ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
"SetupSubgraphExecutionInfo should only be called once for each subgraph.");
@@ -167,7 +167,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
// Make a copy of parameters since we will update it based on inputs later
BeamSearchParameters parameters = parameters_;
- if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32
BeamSearchGpt impl{
*ctx_internal,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
index d909e77e0f..490a68d240 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h
@@ -191,7 +191,8 @@ Status BeamSearchBase::CheckInputs(const OpKernelContextInternal& context) {
context.Input(0), // input_ids
context.Input(7), // vocab_mask
context.Input(8), // prefix_vocab_mask
- context.Input(9))); // attention_mask
+ context.Input(9), // attention_mask
+ nullptr)); // presence_mask
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
index 0269efd319..bd3a72e989 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc
@@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const {
}
void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) {
- model_type = static_cast(info.GetAttrOrDefault("model_type", IBeamSearchParameters::kModelTypeGpt));
+ model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeGpt));
early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1;
eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1));
pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1));
@@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
batch_size = static_cast(dims[0]);
// For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1
- sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1;
+ sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1;
auto* max_length_tensor = context->Input(1);
max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : kMaxSequenceLength;
@@ -71,10 +71,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) {
void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) {
// Override vocab_size using the inferred shape from the decoder subgraph ONLY IF
// the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch)
- if (vocab_size == -1) {
+ if (vocab_size == -1 || vocab_size == 0) {
vocab_size = vocabulary_size;
}
-
num_heads = heads;
head_size = hidden_size_per_head;
num_layers = layers;
diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
index 1a3a87bd3f..0cb2b39976 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h
@@ -10,7 +10,7 @@ namespace onnxruntime {
namespace contrib {
namespace transformers {
-struct BeamSearchParameters : public IBeamSearchParameters {
+struct BeamSearchParameters : public IGenerationParameters {
Status Validate() const;
int BatchBeamSize() const { return batch_size * num_beams; }
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
index 1eb0d0634f..7f04f936a6 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h
@@ -87,7 +87,8 @@ class GenerateBase {
const Tensor* input_ids,
const Tensor* vocab_mask,
const Tensor* prefix_vocab_mask,
- const Tensor* attention_mask) const {
+ const Tensor* attention_mask,
+ const Tensor* presence_mask) const {
const auto& dims = input_ids->Shape().GetDims();
if (dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
@@ -149,6 +150,28 @@ class GenerateBase {
}
}
+ if (presence_mask != nullptr) {
+ const auto& dims_presence = presence_mask->Shape().GetDims();
+ if (dims_presence.size() != 2) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size());
+ }
+
+ // presence_mask first dimension should be same as the first dimension of input_ids
+ if (static_cast(dims_presence[0]) != static_cast(dims[0])) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "input_ids and presence_mask must have the same batch_size");
+ }
+
+ if (static_cast(dims_presence[1]) != parameters->vocab_size) {
+ return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
+ "Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]);
+ }
+
+ // store prefix vocab mask in parameters.
+ parameters->presence_mask = presence_mask->DataAsSpan();
+ }
+
return Status::OK();
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
index d62eccb84a..518afa294a 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc
@@ -6,11 +6,13 @@
#include
#include "core/providers/cpu/math/top_k.h"
#include "core/providers/cpu/math/softmax_shared.h"
+#include "core/providers/cpu/generator/random.h"
#include "core/common/safeint.h"
#include "core/common/gsl.h"
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_scorer.h"
#include "contrib_ops/cpu/transformers/generation_device_helper.h"
+#include "contrib_ops/cpu/transformers/sampling_cpu_helper.h"
#include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h"
#include "contrib_ops/cpu/transformers/subgraph_gpt.h"
@@ -175,6 +177,13 @@ Status CreateGptInputs(
// Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length)
// TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance.
+ if (num_beams == 1) {
+ expanded_input_ids = std::move(input_ids);
+ expanded_position_ids = std::move(position_ids);
+ expanded_attention_mask = std::move(attention_mask);
+ return Status::OK();
+ }
+
ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids);
ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids);
ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask);
@@ -243,7 +252,7 @@ Status ProcessLogits(const OrtValue& logits, //
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
- const transformers::IBeamSearchParameters* parameters, // parameters
+ const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
@@ -400,17 +409,16 @@ template
Status GreedySearchProcessLogits(
const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState* greedy_state, // state
+ transformers::ISamplingState* sampling_state, // sampling_state
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
- const transformers::IBeamSearchParameters* parameters, // parameters
+ const transformers::IGenerationParameters* parameters, // parameters
+ bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper) { // tensor dumper
-#ifndef DEBUG_GENERATION
- ORT_UNUSED_PARAMETER(dumper);
-#endif
int batch_size = parameters->batch_size;
int vocab_size = parameters->vocab_size;
@@ -448,6 +456,18 @@ Status GreedySearchProcessLogits(
dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size);
#endif
+ if (do_sampling) {
+ ORT_RETURN_IF_ERROR(SamplingCpuHelper::Sample(allocator,
+ thread_pool,
+ next_token_scores,
+ sampling_state,
+ greedy_state,
+ parameters,
+ dumper));
+
+ return Status::OK();
+ }
+
// next_tokens = torch.argmax(scores, dim=-1)
int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size};
TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2);
@@ -460,28 +480,27 @@ Status GreedySearchProcessLogits(
next_token_scores_value);
const Tensor& input = next_token_scores_value.Get();
- constexpr int axis = 1;
constexpr unsigned top_k = 1;
+ constexpr int axis = 1;
constexpr bool largest = true;
constexpr bool sorted = false;
Tensor topk_scores;
Tensor topk_indices;
- ORT_RETURN_IF_ERROR(
- TopK(&input,
- axis,
- top_k,
- largest,
- sorted,
- allocator,
- stream,
- thread_pool,
- topk_scores,
- topk_indices));
+ ORT_RETURN_IF_ERROR(TopK(&input,
+ axis,
+ top_k,
+ largest,
+ sorted,
+ allocator,
+ stream,
+ thread_pool,
+ topk_scores,
+ topk_indices));
#ifdef DEBUG_GENERATION
- dumper->Print("topk_scores", topk_scores);
- dumper->Print("topk_indices", topk_indices);
+ dumper->Print("topk_scores", topk_scores);
+ dumper->Print("topk_indices", topk_indices);
#endif
gsl::span next_token_indices = topk_indices.DataAsSpan();
@@ -829,7 +848,7 @@ template Status ProcessLogits(
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
transformers::IBeamScorer* beam_scorer,
- const transformers::IBeamSearchParameters* parameters,
+ const transformers::IGenerationParameters* parameters,
int step,
Stream* stream,
const transformers::IConsoleDumper* dumper);
@@ -837,11 +856,13 @@ template Status ProcessLogits(
template Status GreedySearchProcessLogits(
const OrtValue& logits,
transformers::IGreedySearchState* greedy_state,
+ transformers::ISamplingState* sampling_state,
transformers::ISequences* sequences,
AllocatorPtr& allocator,
onnxruntime::concurrency::ThreadPool* thread_pool,
transformers::ILogitsProcessorList* logits_processors,
- const transformers::IBeamSearchParameters* parameters,
+ const transformers::IGenerationParameters* parameters,
+ bool do_sampling,
int step,
Stream* ort_stream,
const transformers::IConsoleDumper* dumper);
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
index 67dd55ca2a..16b0c4d6a3 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h
@@ -83,7 +83,7 @@ using ProcessLogitsFunc = std::function; // tensor dumper
@@ -92,11 +92,13 @@ template
using GreedySearchProcessLogitsFunc = std::function* greedy_state, // state
+ transformers::ISamplingState* sampling_state, // sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
- const transformers::IBeamSearchParameters* parameters, // parameters
+ const transformers::IGenerationParameters* parameters, // parameters
+ bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* ort_stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper)>; // tensor dumper
@@ -203,7 +205,7 @@ Status ProcessLogits(const OrtValue& logits, //
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
transformers::IBeamScorer* beam_scorer, // beam scorer
- const transformers::IBeamSearchParameters* parameters, // parameters
+ const transformers::IGenerationParameters* parameters, // parameters
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper); // tensor dumper
@@ -211,11 +213,13 @@ Status ProcessLogits(const OrtValue& logits, //
template
Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph
transformers::IGreedySearchState* greedy_state, // state
+ transformers::ISamplingState* sampling_state, // sampling buffers
transformers::ISequences* sequences, // sequences
AllocatorPtr& allocator, // default allocator
onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only)
transformers::ILogitsProcessorList* logits_processors, // logits processors
- const transformers::IBeamSearchParameters* parameters, // parameters
+ const transformers::IGenerationParameters* parameters, // parameters
+ bool do_sampling, // whether to do sampling
int step, // iteration counter
Stream* stream, // cuda stream (for CUDA only)
const transformers::IConsoleDumper* dumper); // tensor dumper
diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
index edbebcc81a..4cc5bf380f 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h
@@ -4,6 +4,7 @@
#pragma once
#include
+#include
#include "core/common/gsl.h"
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
@@ -33,7 +34,7 @@ struct IBeamSearchState {
gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size)
gsl::span remaining_scores; // portion of scores that is available for appending next token scores.
gsl::span topk_buffer; // temp buffer for topk computation, including:
- // 1st stage needs:
+ // 1st stage needs:
// temp score: (batch_size * num_beams * parts_vocab, 2 * num_beams)
// temp token: (batch_size * num_beams * parts_vocab, 2 * num_beams)
// 2nd stage needs:
@@ -65,6 +66,28 @@ struct IGreedySearchState {
gsl::span next_tokens; // shape (batch_size)
};
+template
+struct ISamplingState {
+ gsl::span d_index_in;
+ gsl::span d_index_out;
+ gsl::span d_offset;
+ gsl::span d_sorted_score;
+ gsl::span d_sorted_softmaxed_score;
+ gsl::span d_softmaxed_score;
+ gsl::span h_softmaxed_score;
+ gsl::span d_sampled;
+ gsl::span h_sampled_all;
+ gsl::span d_indices;
+ gsl::span d_presence_mask;
+
+ BufferUniquePtr storage_buffer;
+ size_t temp_storage_bytes;
+ std::default_random_engine generator;
+
+ gsl::span sorted_scores;
+ gsl::span cumulative_probs;
+};
+
class ISequences {
public:
virtual ~ISequences() {}
@@ -96,7 +119,7 @@ class IBeamScorer {
Tensor* output_sequence_scores) = 0;
};
-struct IBeamSearchParameters {
+struct IGenerationParameters {
static constexpr int kModelTypeGpt = 0;
static constexpr int kModelTypeT5 = 1;
@@ -120,6 +143,7 @@ struct IBeamSearchParameters {
gsl::span vocab_mask;
gsl::span prefix_vocab_mask;
+ gsl::span presence_mask;
// Parameters from outputs.
bool output_scores; // whether scores existed in output
@@ -129,6 +153,15 @@ struct IBeamSearchParameters {
int num_heads;
int head_size;
int num_layers;
+
+ // Parameters for TopK/TopP sampling.
+ float presence_penalty;
+ float filter_value;
+ float temperature = 1.0f;
+ float top_p = 0.0f;
+ int seed = 0;
+ int min_tokens_to_keep = 1;
+ bool custom_sampling = false;
};
class IConsoleDumper {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc
index 01747d71cb..a33d03738e 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc
@@ -81,16 +81,15 @@ void GreedySearch::Init(const OpKernelInfo& info) {
parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size);
// Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
- ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt ||
- parameters_.model_type == IBeamSearchParameters::kModelTypeT5);
+ ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt);
ONNX_NAMESPACE::GraphProto proto;
- if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) {
+ if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
// Make sure the encoder sub-graph attribute is present for the T5 model.
ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK());
}
- if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) {
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
// Check if the init_decoder sub-graph attribute is present for the GPT2 model.
if (info.GetAttr("init_decoder", &proto).IsOK()) {
has_init_decoder_ = true;
@@ -105,7 +104,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
const std::string& attribute_name,
const SessionState& subgraph_session_state) {
const auto& node = Node();
- if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { // GPT-2
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2
if (attribute_name == "decoder") {
ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
@@ -131,8 +130,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat
init_run_gpt_subgraph_ = std::move(res.second);
init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
}
-
- } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { // encoder-decoder like T5
+ } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5
ORT_THROW("Not Implemented");
// if (attribute_name == "encoder") {
// ORT_ENFORCE(t5_encoder_subgraph_ == nullptr,
@@ -186,7 +184,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
if (parameters_.model_type == 0) { // GPT-2
// Subgraph has constraint that the output is either float or float16
if (!gpt_subgraph_->IsOutputFloat16()) {
- GreedySearchGpt impl{
+ GreedySearchGpt impl{
*ctx_internal,
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
@@ -207,7 +205,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const {
return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
} else {
- GreedySearchGpt impl{
+ GreedySearchGpt impl{
*ctx_internal,
has_init_decoder_ ? init_run_decoder_session_state : nullptr,
has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr,
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h
index 46f242400f..e0e611a2c3 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h
@@ -21,7 +21,6 @@ namespace transformers {
using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel
-// bugbug: refactor
class GreedySearch : public IControlFlowKernel {
public:
explicit GreedySearch(const OpKernelInfo& info)
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
index 408d2655b2..724db62219 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#pragma once
+#include
#include
#include "contrib_ops/cpu/transformers/generation_shared.h"
#include "contrib_ops/cpu/transformers/generate_impl_base.h"
@@ -11,6 +12,63 @@ namespace contrib {
namespace transformers {
+template
+struct SamplingState : public ISamplingState {
+ void Init(AllocatorPtr allocator,
+ AllocatorPtr cpu_allocator,
+ int batch_size,
+ int vocab_size,
+ int max_iter,
+ int seed,
+ bool is_cuda) {
+ int total_count = batch_size * vocab_size;
+
+ this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count));
+
+ this->generator = std::default_random_engine{gsl::narrow_cast(seed)};
+
+ if (is_cuda) {
+ this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count));
+ this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count));
+ this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1));
+ this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count));
+ this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count));
+ this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count));
+ this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size));
+ this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter));
+ this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size));
+ this->temp_storage_bytes = 0;
+ // TODO: Do not allocate this buffer if there's no presence_mask
+ this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count));
+
+ std::uniform_real_distribution distribution(0.0, 1.0);
+ static_cast(distribution(this->generator));
+ for (size_t i = 0; i < this->h_sampled_all.size(); ++i) {
+ this->h_sampled_all[i] = distribution(this->generator);
+ }
+ } else {
+ // TODO: Some buffer can be reused for CPU
+ this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count));
+ this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count));
+ }
+ }
+
+ private:
+ BufferUniquePtr d_index_in_buffer_;
+ BufferUniquePtr d_index_out_buffer_;
+ BufferUniquePtr d_offset_buffer_;
+ BufferUniquePtr d_sorted_score_buffer_;
+ BufferUniquePtr d_sorted_softmaxed_score_buffer_;
+ BufferUniquePtr d_softmaxed_score_buffer_;
+ BufferUniquePtr h_softmaxed_score_buffer_;
+ BufferUniquePtr d_sampled_buffer_;
+ BufferUniquePtr h_sampled_all_buffer_;
+ BufferUniquePtr d_indices_buffer_;
+ BufferUniquePtr d_presence_mask_buffer_;
+ BufferUniquePtr sorted_scores_buffer_;
+ BufferUniquePtr cumulative_probs_buffer_;
+};
+
template
struct GreedySearchState : public IGreedySearchState {
Sequences sequences;
@@ -68,7 +126,7 @@ struct GreedySearchState : public IGreedySearchState {
};
// Base class of gready search implementation that is common for both GPT-2 and Bart/T5.
-template
+template
class GreedySearchBase : public GenerateBase {
public:
GreedySearchBase(OpKernelContextInternal& context,
@@ -76,7 +134,7 @@ class GreedySearchBase : public GenerateBase {
concurrency::ThreadPool* thread_pool,
Stream* ort_stream,
IConsoleDumper* cuda_dumper,
- GreedySearchParameters& params,
+ ParametersT& params,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func)
@@ -105,23 +163,25 @@ class GreedySearchBase : public GenerateBase {
Status GenerateNextToken(const OrtValue& logits,
gsl::span& next_tokens,
GreedySearchState& greedy_state,
+ ISamplingState& sampling_state,
int counter,
int eos_token_id);
// Calculate scores from logits, then apply filtering and select next token for each beam.
Status ProcessLogits(const OrtValue& logits, // logits output of subgraph
GreedySearchState& greedy_state,
+ ISamplingState& sampling_state,
AllocatorPtr& allocator,
int counter);
- GreedySearchParameters* parameters_;
+ ParametersT* parameters_;
// Device specific functions
GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_func_;
};
-template
-Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) {
+template
+Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) {
// Input shapes:
// input_ids : (batch_size, sequence_length)
// vocab_mask : (vocab_size) or nullptr
@@ -129,13 +189,14 @@ Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context)
context.Input(0), // input_ids
context.Input(4), // vocab_mask
context.Input(5), // prefix_vocab_mask
- nullptr)); // attention_mask
+ context.Input(6), // attention_mask
+ context.Input(7))); // presence_mask
return Status::OK();
}
-template
-Status GreedySearchBase::Initialize() {
+template
+Status GreedySearchBase::Initialize() {
ORT_RETURN_IF_ERROR(this->context_.GetTempSpaceAllocator(&this->temp_space_allocator_));
ORT_RETURN_IF_ERROR(CheckScalarInput("max_length", 1, true));
@@ -155,26 +216,29 @@ Status GreedySearchBase::Initialize() {
return Status::OK();
}
-template
-Status GreedySearchBase::ProcessLogits(
+template
+Status GreedySearchBase::ProcessLogits(
const OrtValue& logits,
GreedySearchState& greedy_state,
+ ISamplingState& sampling_state,
AllocatorPtr& allocator,
int counter) {
- return process_logits_func_(logits, &greedy_state, &(greedy_state.sequences), allocator,
- this->thread_pool_, &this->logits_processors_,
- parameters_, counter, this->ort_stream_, this->GetConsoleDumper());
+ bool use_sampling = std::is_same::value;
+ return process_logits_func_(logits, &greedy_state, &sampling_state, &(greedy_state.sequences), allocator,
+ this->thread_pool_, &this->logits_processors_, parameters_,
+ use_sampling, counter, this->ort_stream_, this->GetConsoleDumper());
}
-template
-Status GreedySearchBase::GenerateNextToken(
+template
+Status GreedySearchBase::GenerateNextToken(
const OrtValue& logits,
gsl::span& next_tokens,
GreedySearchState& greedy_state,
+ ISamplingState& sampling_state,
int counter,
int eos_token_id) {
// Process logits to get next token scores
- ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, this->temp_space_allocator_, counter));
+ ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, sampling_state, this->temp_space_allocator_, counter));
next_tokens = greedy_state.next_tokens;
for (size_t i = 0; i < next_tokens.size(); i++) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h
index cf12fe6ba1..dfcc309274 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h
@@ -24,8 +24,8 @@ std::pair> CreateGptSubgraphAndUpdateParame
} // namespace gpt_details
// Greedy search implementation for GPT-2 model.
-template
-class GreedySearchGpt : public GreedySearchBase {
+template
+class GreedySearchGpt : public GreedySearchBase {
public:
GreedySearchGpt(OpKernelContextInternal& context,
const SessionState* init_run_decoder_session_state,
@@ -35,7 +35,7 @@ class GreedySearchGpt : public GreedySearchBase {
concurrency::ThreadPool* thread_pool,
Stream* ort_stream,
IConsoleDumper* cuda_dumper,
- GreedySearchParameters& params,
+ ParametersT& params,
const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
@@ -43,15 +43,15 @@ class GreedySearchGpt : public GreedySearchBase {
const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func,
const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func,
const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func)
- : GreedySearchBase(context,
- decoder_session_state,
- thread_pool,
- ort_stream,
- cuda_dumper,
- params,
- topk_func,
- process_logits_func,
- device_copy_func),
+ : GreedySearchBase(context,
+ decoder_session_state,
+ thread_pool,
+ ort_stream,
+ cuda_dumper,
+ params,
+ topk_func,
+ process_logits_func,
+ device_copy_func),
init_run_decoder_session_state_(init_run_decoder_session_state),
init_run_gpt_subgraph_(init_run_gpt_subgraph),
gpt_subgraph_(gpt_subgraph),
@@ -94,8 +94,8 @@ class GreedySearchGpt : public GreedySearchBase {
GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_;
};
-template
-Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths,
+template
+Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths,
OrtValue& expanded_input_ids,
std::vector& feeds,
IAllocatorUniquePtr& buffer) {
@@ -134,8 +134,8 @@ Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengt
this->parameters_->max_length);
}
-template
-Status GreedySearchGpt::UpdateFeeds(
+template
+Status GreedySearchGpt::UpdateFeeds(
const std::vector& last_outputs,
std::vector& next_inputs,
int current_length,
@@ -161,11 +161,11 @@ Status GreedySearchGpt::UpdateFeeds(
);
}
-template
-Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager,
- const FeedsFetchesManager& feeds_fetches_manager) {
+template
+Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager,
+ const FeedsFetchesManager& feeds_fetches_manager) {
auto status = Status::OK();
- const GreedySearchParameters* parameters = this->parameters_;
+ const ParametersT* parameters = this->parameters_;
// Allocate output tensors.
int64_t sequences_dims[] = {parameters->batch_size, parameters->max_length};
@@ -184,6 +184,17 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fet
parameters->max_length,
this->IsCuda());
+ SamplingState sampling_state;
+ if (std::is_same::value) {
+ sampling_state.Init(this->temp_space_allocator_,
+ this->cpu_allocator_,
+ static_cast(parameters->BatchBeamSize()),
+ static_cast(parameters->vocab_size),
+ static_cast(parameters->max_length - parameters->sequence_length),
+ parameters->seed,
+ this->IsCuda());
+ }
+
IAllocatorUniquePtr buffer;
OrtValue expanded_input_ids_in_cpu;
ORT_RETURN_IF_ERROR(CreateInitialFeeds(greedy_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer));
@@ -276,6 +287,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fet
ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits,
next_tokens,
greedy_state,
+ sampling_state,
iteration_counter,
parameters->eos_token_id));
@@ -324,6 +336,27 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fet
gsl::copy(sequence_source, batch_output);
}
+#ifdef DEBUG_GENERATION
+ // Debug the one step filtered logits for sampling
+ int64_t filtered_logits_dims[] = {parameters->batch_size, parameters->vocab_size};
+ TensorShape filtered_logits_shape(&filtered_logits_dims[0],
+ sizeof(filtered_logits_dims) / sizeof(filtered_logits_dims[0]));
+ Tensor* filtered_logits = this->context_.Output(1, filtered_logits_shape);
+ if (filtered_logits != nullptr) {
+ gsl::span filtered_logits_span = filtered_logits->MutableDataAsSpan();
+ for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) {
+ auto batch_output = filtered_logits_span.subspan(
+ static_cast(batch_id) * parameters->vocab_size,
+ parameters->vocab_size);
+ gsl::span batch_filtered_logits = gsl::make_span(sampling_state.h_softmaxed_score.data() +
+ batch_id * parameters->vocab_size,
+ parameters->vocab_size);
+
+ gsl::copy(batch_filtered_logits, batch_output);
+ }
+ }
+#endif
+
return status;
}
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
index 2f1e657c8e..d0641fedf9 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc
@@ -6,8 +6,12 @@
#include "core/common/narrow.h"
#include "core/common/safeint.h"
#include "core/common/span_utils.h"
+#include "core/providers/cpu/math/softmax_shared.h"
#include "contrib_ops/cpu/transformers/logits_processor.h"
#include "contrib_ops/cpu/transformers/dump_tensor.h"
+#include
+#include
+#include
namespace onnxruntime {
namespace contrib {
@@ -187,6 +191,53 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/,
#endif
}
+template
+TemperatureLogitsProcessor::TemperatureLogitsProcessor(float temperature) : temperature_(temperature) {
+}
+
+template
+void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/,
+ NextTokenScores& next_token_scores) {
+ if (temperature_ == 1.0f) {
+ return;
+ }
+
+ T* p = next_token_scores.scores.data();
+ for (size_t i = 0; i < next_token_scores.scores.size(); i++) {
+ *p /= temperature_;
+ ++p;
+ }
+
+#ifdef DEBUG_GENERATION
+ DumpScores("TemperatureLogitsProcessor", next_token_scores);
+#endif
+}
+
+template
+PresencePenaltyLogitsProcessor::PresencePenaltyLogitsProcessor(const gsl::span& presence_mask,
+ float presence_penalty)
+ : presence_mask_(presence_mask), presence_penalty_(presence_penalty) {
+}
+
+template
+void PresencePenaltyLogitsProcessor::Process(const ISequences*,
+ NextTokenScores& next_token_scores) {
+ if (presence_penalty_ == 0.0f) {
+ return;
+ }
+
+ assert(!presence_mask_.empty());
+
+ T* p = next_token_scores.scores.data();
+ for (size_t i = 0; i < next_token_scores.scores.size(); i++) {
+ *p -= presence_mask_[i] * presence_penalty_;
+ }
+
+#ifdef DEBUG_GENERATION
+ DumpScores("PresencePenaltyLogitsProcessor", next_token_scores);
+#endif
+}
+
void LogitsProcessorList::Init(const BeamSearchParameters& parameters) {
LogitsProcessorInitImpl(parameters);
}
@@ -195,6 +246,10 @@ void LogitsProcessorList::Init(const GreedySearchParameters& parameters) {
LogitsProcessorInitImpl(parameters);
}
+void LogitsProcessorList::Init(const SamplingParameters& parameters) {
+ LogitsProcessorInitImpl(parameters);
+}
+
void LogitsProcessorList::Process(const ISequences* sequences,
gsl::span& next_token_scores,
int step) {
diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
index 1a2fba19bf..4a516474c8 100644
--- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
+++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h
@@ -7,6 +7,7 @@
#include "contrib_ops/cpu/transformers/sequences.h"
#include "contrib_ops/cpu/transformers/beam_search_parameters.h"
#include "contrib_ops/cpu/transformers/greedy_search_parameters.h"
+#include "contrib_ops/cpu/transformers/sampling_parameters.h"
#include "contrib_ops/cpu/transformers/generation_shared.h"
namespace onnxruntime {
@@ -96,11 +97,53 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor {
const int batch_size_;
};
+template
+class TemperatureLogitsProcessor : public ILogitsProcessor {
+ public:
+ TemperatureLogitsProcessor(float temperature);
+
+ void Process(const ISequences* sequences,
+ NextTokenScores& next_token_scores) override;
+
+ private:
+ float temperature_;
+};
+
+// template
+// class TopPLogitsProcessor : public ILogitsProcessor {
+// public:
+// TopPLogitsProcessor(float top_p, float filter_value,
+// onnxruntime::concurrency::ThreadPool* thread_pool);
+
+// void Process(const ISequences* sequences,
+// NextTokenScores& next_token_scores) override;
+
+// private:
+// float top_p_;
+// float filter_value_;
+// onnxruntime::concurrency::ThreadPool* thread_pool_;
+// };
+
+template
+class PresencePenaltyLogitsProcessor : public ILogitsProcessor {
+ public:
+ PresencePenaltyLogitsProcessor(const gsl::span& presence_mask,
+ float presence_penalty);
+
+ void Process(const ISequences* sequences,
+ NextTokenScores& next_token_scores) override;
+
+ private:
+ gsl::span presence_mask_;
+ float presence_penalty_;
+};
+
class LogitsProcessorList : public ILogitsProcessorList {
public:
LogitsProcessorList() = default;
void Init(const BeamSearchParameters& parameters);
void Init(const GreedySearchParameters& parameters);
+ void Init(const SamplingParameters& parameters);
void Process(const ISequences* sequences, gsl::span& next_token_scores, int step);
private:
@@ -140,6 +183,19 @@ class LogitsProcessorList : public ILogitsProcessorList {
processor_list_.push_back(min_length_processor_.get());
}
+ if (parameters.temperature > 0) {
+ temperature_processor_ = std::make_unique>(parameters.temperature);
+ processor_list_.push_back(temperature_processor_.get());
+ }
+
+ if (!parameters.presence_mask.empty()) {
+ presence_penalty_processor_ = std::make_unique<
+ PresencePenaltyLogitsProcessor
+ >(parameters.presence_mask,
+ parameters.presence_penalty);
+ processor_list_.push_back(presence_penalty_processor_.get());
+ }
+
batch_beam_size_ = parameters.BatchBeamSize();
vocab_size_ = parameters.vocab_size;
}
@@ -153,6 +209,8 @@ class LogitsProcessorList : public ILogitsProcessorList {
std::unique_ptr> vocab_mask_processor_;
std::unique_ptr> prefix_vocab_mask_processor_;
std::unique_ptr> min_length_processor_;
+ std::unique_ptr> temperature_processor_;
+ std::unique_ptr> presence_penalty_processor_;
};
} // namespace transformers
diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc
new file mode 100644
index 0000000000..a9b2db40e9
--- /dev/null
+++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc
@@ -0,0 +1,175 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+// there's no way to use a raw pointer as the copy destination with std::copy_n
+// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset
+// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning.
+#ifdef _MSC_VER
+#pragma warning(push)
+#pragma warning(disable : 4996)
+#endif
+
+#include "core/framework/op_kernel_context_internal.h"
+#include "core/framework/utils.h"
+#include "contrib_ops/cpu/transformers/sampling.h"
+#include "contrib_ops/cpu/transformers/logits_processor.h"
+#include "contrib_ops/cpu/transformers/sequences.h"
+#include "contrib_ops/cpu/transformers/dump_tensor.h"
+#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h"
+
+using namespace ONNX_NAMESPACE;
+using namespace onnxruntime::common;
+
+namespace onnxruntime {
+namespace contrib {
+
+#define REGISTER_KERNEL_TYPED(T) \
+ ONNX_OPERATOR_TYPED_KERNEL_EX( \
+ Sampling, \
+ kMSDomain, \
+ 1, \
+ T, \
+ kCpuExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("T", DataTypeImpl::GetTensorType()), \
+ transformers::Sampling);
+
+REGISTER_KERNEL_TYPED(float)
+
+namespace transformers {
+
+void Sampling::Init(const OpKernelInfo& info) {
+ parameters_.ParseFromAttributes(info);
+ parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size);
+
+ // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5)
+ ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt);
+
+ ONNX_NAMESPACE::GraphProto proto;
+ if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) {
+ // Make sure the encoder sub-graph attribute is present for the T5 model.
+ ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK());
+ }
+
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) {
+ // Check if the init_decoder sub-graph attribute is present for the GPT2 model.
+ if (info.GetAttr("init_decoder", &proto).IsOK()) {
+ has_init_decoder_ = true;
+ }
+ }
+
+ // Make sure the decoder sub-graph attribute is present for all model types.
+ ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK());
+}
+
+Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state,
+ const std::string& attribute_name,
+ const SessionState& subgraph_session_state) {
+ const auto& node = Node();
+ if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2
+ if (attribute_name == "decoder") {
+ ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+ auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
+ subgraph_session_state, parameters_);
+
+ auto status = res.first;
+ if (!status.IsOK()) {
+ return status;
+ }
+
+ gpt_subgraph_ = std::move(res.second);
+ decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager();
+ } else if (attribute_name == "init_decoder") {
+ ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph.");
+ auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name,
+ subgraph_session_state, parameters_);
+
+ auto status = res.first;
+ if (!status.IsOK()) {
+ return status;
+ }
+
+ init_run_gpt_subgraph_ = std::move(res.second);
+ init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager();
+ }
+ } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5
+ ORT_THROW("Not Implemented");
+ }
+
+ return Status::OK();
+}
+
+Status Sampling::Compute(OpKernelContext* ctx) const {
+ auto* ctx_internal = static_cast