diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 2784a1c431..23ea679519 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -75,10 +75,13 @@ set(contrib_ops_excluded_files "transformers/beam_search.h" "transformers/generation_device_helper.cc" "transformers/generation_device_helper.h" - "transformers/beam_search_impl.cu" - "transformers/beam_search_impl.h" + "transformers/generation_cuda_impl.cu" + "transformers/generation_cuda_impl.h" "transformers/greedy_search.cc" "transformers/greedy_search.h" + "transformers/sampling.cc" + "transformers/sampling.h" + "transformers/sampling_cuda_helper.h" "transformers/dump_cuda_tensor.cc" "transformers/dump_cuda_tensor.h" "conv_transpose_with_dynamic_pads.cc" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 76e67a31c9..4fc0d4c3e1 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -74,6 +74,7 @@ Do not modify directly.* * com.microsoft.RestorePadding * com.microsoft.Rfft * com.microsoft.SampleOp + * com.microsoft.Sampling * com.microsoft.SkipLayerNormalization * com.microsoft.Snpe * com.microsoft.SparseToDenseMatMul @@ -3810,6 +3811,89 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.Sampling** + + Greedy Sampling for text generation. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
custom : int
+
If 1 custom sampling logic
+
decoder : graph (required)
+
Decoder subgraph to execute in a loop.
+
decoder_start_token_id : int
+
The id of the token that indicates decoding starts.
+
encoder : graph
+
The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.
+
eos_token_id : int (required)
+
The id of the end-of-sequence token
+
filter_value : float
+
All filtered values will be set to this float value.
+
init_decoder : graph
+
The subgraph for the first decoding run. It will be called once before `decoder` subgraph. This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs
+
min_tokens_to_keep : int
+
Minimumber of tokens we keep per batch example in the output.
+
model_type : int
+
Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart
+
no_repeat_ngram_size : int
+
no repeat ngrams size
+
pad_token_id : int (required)
+
The id of the padding token
+
presence_penalty : float
+
Presence penalty for custom sampling
+
temperature : float
+
The value used to module the next token probabilities.
+
top_p : float
+
If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.
+
vocab_size : int
+
Size of the vocabulary. If not provided, it will be inferred from the decoder subgraph's output shape
+
+ +#### Inputs (2 - 8) + +
+
input_ids : I
+
The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)
+
max_length : I
+
The maximum length of the sequence to be generated. Shape is (1)
+
min_length (optional) : I
+
The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)
+
repetition_penalty (optional) : T
+
The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)
+
vocab_mask (optional) : I
+
Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)
+
prefix_vocab_mask (optional) : I
+
Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)
+
attention_mask (optional) : I
+
Custom attention mask. Shape is (batch_size, sequence_length)
+
presence_mask (optional) : I
+
Presence penalty mask. Shape is (batch_size, vocab_size)
+
+ +#### Outputs (1 - 2) + +
+
sequences : I
+
Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)
+
filtered_logits (optional) : T
+
Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)
+
+ +#### Type Constraints + +
+
T : tensor(float)
+
Constrain input and output types to float tensors.
+
I : tensor(int32)
+
Constrain to integer types
+
+ + ### **com.microsoft.SkipLayerNormalization** Skip and Layer Normalization Fusion diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index e1789154c7..e86305de29 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -438,6 +438,7 @@ Do not modify directly.* |QuickGelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| |Range|*in* start:**T**
*in* limit:**T**
*in* delta:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(int16), tensor(int32), tensor(int64)| |SampleOp|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(float)| +|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_sum:**T**|1+|**T** = tensor(double), tensor(float)| |SparseToDenseMatMul|*in* A:**T**
*in* B:**T1**
*out* Y:**T1**|1+|**T** = sparse_tensor(double), sparse_tensor(float), sparse_tensor(int32), sparse_tensor(int64), sparse_tensor(uint32), sparse_tensor(uint64)
**T1** = tensor(double), tensor(float), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |Tokenizer|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(string)| @@ -797,6 +798,7 @@ Do not modify directly.* |RemovePadding|*in* input:**T**
*in* sequence_token_count:**M**
*out* output:**T**
*out* token_offset:**M**
*out* cumulated_seq_len:**M**
*out* max_seq_len:**M**|1+|**T** = tensor(float), tensor(float16)| |RestorePadding|*in* input:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Rfft|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|Sampling|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*in* presence_mask:**I**
*out* sequences:**I**
*out* filtered_logits:**T**|1+|**T** = tensor(float), tensor(float16)| |SkipLayerNormalization|*in* input:**T**
*in* skip:**T**
*in* gamma:**T**
*in* beta:**T**
*in* bias:**T**
*out* output:**T**
*out* mean:**U**
*out* inv_std_var:**U**
*out* input_skip_sum:**T**|1+|**T** = tensor(float), tensor(float16)| |TransposeMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |Trilu|*in* X:**T**
*in* k:**tensor(int64)**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc index 3f459dff99..a04ef0d71b 100644 --- a/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc @@ -18,6 +18,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedConv); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, FusedGemm); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, GreedySearch); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, Sampling); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, AttnLSTM); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, string, Tokenizer); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Range); @@ -199,6 +200,7 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index 1cb2c2050b..868356f70a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -61,16 +61,17 @@ void BeamSearch::Init(const OpKernelInfo& info) { parameters_.ParseFromAttributes(info); // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) - ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt || - parameters_.model_type == IBeamSearchParameters::kModelTypeT5); + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt || + parameters_.model_type == IGenerationParameters::kModelTypeT5); ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) { + + if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { // Make sure the encoder sub-graph attribute is present for the T5 model. ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // Check if the init_decoder sub-graph attribute is present for the GPT2 model. if (info.GetAttr("init_decoder", &proto).IsOK()) { has_init_decoder_ = true; @@ -87,7 +88,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) { const auto& node = Node(); - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { if (attribute_name == "decoder") { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, @@ -113,8 +114,7 @@ Status BeamSearch::SetupSubgraphExecutionInfo(const SessionState& session_state, init_run_gpt_subgraph_ = std::move(res.second); init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } - - } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { if (attribute_name == "encoder") { ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); @@ -167,7 +167,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { // Make a copy of parameters since we will update it based on inputs later BeamSearchParameters parameters = parameters_; - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { if (!gpt_subgraph_->IsOutputFloat16()) { // Output float32 BeamSearchGpt impl{ *ctx_internal, diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h index d909e77e0f..490a68d240 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_base.h @@ -191,7 +191,8 @@ Status BeamSearchBase::CheckInputs(const OpKernelContextInternal& context) { context.Input(0), // input_ids context.Input(7), // vocab_mask context.Input(8), // prefix_vocab_mask - context.Input(9))); // attention_mask + context.Input(9), // attention_mask + nullptr)); // presence_mask return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc index 0269efd319..bd3a72e989 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.cc @@ -18,7 +18,7 @@ Status BeamSearchParameters::Validate() const { } void BeamSearchParameters::ParseFromAttributes(const OpKernelInfo& info) { - model_type = static_cast(info.GetAttrOrDefault("model_type", IBeamSearchParameters::kModelTypeGpt)); + model_type = static_cast(info.GetAttrOrDefault("model_type", IGenerationParameters::kModelTypeGpt)); early_stopping = info.GetAttrOrDefault("early_stopping", 0) == 1; eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); @@ -35,7 +35,7 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { batch_size = static_cast(dims[0]); // For T5, output sequence starts with decoder_start_token_id, so its sequence length is 1 - sequence_length = (this->model_type == IBeamSearchParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1; + sequence_length = (this->model_type == IGenerationParameters::kModelTypeGpt) ? static_cast(dims[1]) : 1; auto* max_length_tensor = context->Input(1); max_length = max_length_tensor ? static_cast(*max_length_tensor->Data()) : kMaxSequenceLength; @@ -71,10 +71,9 @@ void BeamSearchParameters::ParseFromInputs(OpKernelContext* context) { void BeamSearchParameters::SetSubgraphParameters(int vocabulary_size, int heads, int hidden_size_per_head, int layers) { // Override vocab_size using the inferred shape from the decoder subgraph ONLY IF // the vocab_size hasn't been explicitly specified by the user (as an attribute of BeamSearch) - if (vocab_size == -1) { + if (vocab_size == -1 || vocab_size == 0) { vocab_size = vocabulary_size; } - num_heads = heads; head_size = hidden_size_per_head; num_layers = layers; diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h index 1a3a87bd3f..0cb2b39976 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_parameters.h @@ -10,7 +10,7 @@ namespace onnxruntime { namespace contrib { namespace transformers { -struct BeamSearchParameters : public IBeamSearchParameters { +struct BeamSearchParameters : public IGenerationParameters { Status Validate() const; int BatchBeamSize() const { return batch_size * num_beams; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h index 1eb0d0634f..7f04f936a6 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generate_impl_base.h @@ -87,7 +87,8 @@ class GenerateBase { const Tensor* input_ids, const Tensor* vocab_mask, const Tensor* prefix_vocab_mask, - const Tensor* attention_mask) const { + const Tensor* attention_mask, + const Tensor* presence_mask) const { const auto& dims = input_ids->Shape().GetDims(); if (dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, @@ -149,6 +150,28 @@ class GenerateBase { } } + if (presence_mask != nullptr) { + const auto& dims_presence = presence_mask->Shape().GetDims(); + if (dims_presence.size() != 2) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'presence_mask' is expected to have 2 dimensions, got ", dims_presence.size()); + } + + // presence_mask first dimension should be same as the first dimension of input_ids + if (static_cast(dims_presence[0]) != static_cast(dims[0])) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "input_ids and presence_mask must have the same batch_size"); + } + + if (static_cast(dims_presence[1]) != parameters->vocab_size) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, + "Input 'presence_mask' shape[1] shall be vocab_size, got ", dims_presence[1]); + } + + // store prefix vocab mask in parameters. + parameters->presence_mask = presence_mask->DataAsSpan(); + } + return Status::OK(); } diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc index d62eccb84a..518afa294a 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.cc @@ -6,11 +6,13 @@ #include #include "core/providers/cpu/math/top_k.h" #include "core/providers/cpu/math/softmax_shared.h" +#include "core/providers/cpu/generator/random.h" #include "core/common/safeint.h" #include "core/common/gsl.h" #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_scorer.h" #include "contrib_ops/cpu/transformers/generation_device_helper.h" +#include "contrib_ops/cpu/transformers/sampling_cpu_helper.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" @@ -175,6 +177,13 @@ Status CreateGptInputs( // Expand (batch_size, sequence_length) to (batch_size * num_beams, sequence_length) // TODO(tianleiwu): Try expand outputs after first subgraph call instead. That may get better performance. + if (num_beams == 1) { + expanded_input_ids = std::move(input_ids); + expanded_position_ids = std::move(position_ids); + expanded_attention_mask = std::move(attention_mask); + return Status::OK(); + } + ExpandInputs(input_ids, num_beams, allocator, expanded_input_ids); ExpandInputs(position_ids, num_beams, allocator, expanded_position_ids); ExpandInputs(attention_mask, num_beams, allocator, expanded_attention_mask); @@ -243,7 +252,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -400,17 +409,16 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingState* sampling_state, // sampling_state transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper -#ifndef DEBUG_GENERATION - ORT_UNUSED_PARAMETER(dumper); -#endif int batch_size = parameters->batch_size; int vocab_size = parameters->vocab_size; @@ -448,6 +456,18 @@ Status GreedySearchProcessLogits( dumper->Print("next_token_scores after logits processor", next_token_scores.data(), batch_size, 1, vocab_size); #endif + if (do_sampling) { + ORT_RETURN_IF_ERROR(SamplingCpuHelper::Sample(allocator, + thread_pool, + next_token_scores, + sampling_state, + greedy_state, + parameters, + dumper)); + + return Status::OK(); + } + // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), vocab_size}; TensorShape next_token_scores_shape(&next_token_scores_dims[0], 2); @@ -460,28 +480,27 @@ Status GreedySearchProcessLogits( next_token_scores_value); const Tensor& input = next_token_scores_value.Get(); - constexpr int axis = 1; constexpr unsigned top_k = 1; + constexpr int axis = 1; constexpr bool largest = true; constexpr bool sorted = false; Tensor topk_scores; Tensor topk_indices; - ORT_RETURN_IF_ERROR( - TopK(&input, - axis, - top_k, - largest, - sorted, - allocator, - stream, - thread_pool, - topk_scores, - topk_indices)); + ORT_RETURN_IF_ERROR(TopK(&input, + axis, + top_k, + largest, + sorted, + allocator, + stream, + thread_pool, + topk_scores, + topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", topk_scores); - dumper->Print("topk_indices", topk_indices); + dumper->Print("topk_scores", topk_scores); + dumper->Print("topk_indices", topk_indices); #endif gsl::span next_token_indices = topk_indices.DataAsSpan(); @@ -829,7 +848,7 @@ template Status ProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, transformers::IBeamScorer* beam_scorer, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, Stream* stream, const transformers::IConsoleDumper* dumper); @@ -837,11 +856,13 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, + transformers::ISamplingState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, + bool do_sampling, int step, Stream* ort_stream, const transformers::IConsoleDumper* dumper); diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 67dd55ca2a..16b0c4d6a3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -83,7 +83,7 @@ using ProcessLogitsFunc = std::function; // tensor dumper @@ -92,11 +92,13 @@ template using GreedySearchProcessLogitsFunc = std::function* greedy_state, // state + transformers::ISamplingState* sampling_state, // sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter Stream* ort_stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper)>; // tensor dumper @@ -203,7 +205,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper @@ -211,11 +213,13 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingState* sampling_state, // sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h index edbebcc81a..4cc5bf380f 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_shared.h @@ -4,6 +4,7 @@ #pragma once #include +#include #include "core/common/gsl.h" #include "core/framework/allocator.h" #include "core/framework/ort_value.h" @@ -33,7 +34,7 @@ struct IBeamSearchState { gsl::span scores; // shape (max_length - sequence_length + 1, batch_size, num_beams * vocab_size) gsl::span remaining_scores; // portion of scores that is available for appending next token scores. gsl::span topk_buffer; // temp buffer for topk computation, including: - // 1st stage needs: + // 1st stage needs: // temp score: (batch_size * num_beams * parts_vocab, 2 * num_beams) // temp token: (batch_size * num_beams * parts_vocab, 2 * num_beams) // 2nd stage needs: @@ -65,6 +66,28 @@ struct IGreedySearchState { gsl::span next_tokens; // shape (batch_size) }; +template +struct ISamplingState { + gsl::span d_index_in; + gsl::span d_index_out; + gsl::span d_offset; + gsl::span d_sorted_score; + gsl::span d_sorted_softmaxed_score; + gsl::span d_softmaxed_score; + gsl::span h_softmaxed_score; + gsl::span d_sampled; + gsl::span h_sampled_all; + gsl::span d_indices; + gsl::span d_presence_mask; + + BufferUniquePtr storage_buffer; + size_t temp_storage_bytes; + std::default_random_engine generator; + + gsl::span sorted_scores; + gsl::span cumulative_probs; +}; + class ISequences { public: virtual ~ISequences() {} @@ -96,7 +119,7 @@ class IBeamScorer { Tensor* output_sequence_scores) = 0; }; -struct IBeamSearchParameters { +struct IGenerationParameters { static constexpr int kModelTypeGpt = 0; static constexpr int kModelTypeT5 = 1; @@ -120,6 +143,7 @@ struct IBeamSearchParameters { gsl::span vocab_mask; gsl::span prefix_vocab_mask; + gsl::span presence_mask; // Parameters from outputs. bool output_scores; // whether scores existed in output @@ -129,6 +153,15 @@ struct IBeamSearchParameters { int num_heads; int head_size; int num_layers; + + // Parameters for TopK/TopP sampling. + float presence_penalty; + float filter_value; + float temperature = 1.0f; + float top_p = 0.0f; + int seed = 0; + int min_tokens_to_keep = 1; + bool custom_sampling = false; }; class IConsoleDumper { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc index 01747d71cb..a33d03738e 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.cc @@ -81,16 +81,15 @@ void GreedySearch::Init(const OpKernelInfo& info) { parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size); // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) - ORT_ENFORCE(parameters_.model_type == IBeamSearchParameters::kModelTypeGpt || - parameters_.model_type == IBeamSearchParameters::kModelTypeT5); + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt); ONNX_NAMESPACE::GraphProto proto; - if (parameters_.model_type != IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { // Make sure the encoder sub-graph attribute is present for the T5 model. ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); } - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // Check if the init_decoder sub-graph attribute is present for the GPT2 model. if (info.GetAttr("init_decoder", &proto).IsOK()) { has_init_decoder_ = true; @@ -105,7 +104,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat const std::string& attribute_name, const SessionState& subgraph_session_state) { const auto& node = Node(); - if (parameters_.model_type == IBeamSearchParameters::kModelTypeGpt) { // GPT-2 + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2 if (attribute_name == "decoder") { ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, @@ -131,8 +130,7 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat init_run_gpt_subgraph_ = std::move(res.second); init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); } - - } else if (parameters_.model_type == IBeamSearchParameters::kModelTypeT5) { // encoder-decoder like T5 + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5 ORT_THROW("Not Implemented"); // if (attribute_name == "encoder") { // ORT_ENFORCE(t5_encoder_subgraph_ == nullptr, @@ -186,7 +184,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { if (parameters_.model_type == 0) { // GPT-2 // Subgraph has constraint that the output is either float or float16 if (!gpt_subgraph_->IsOutputFloat16()) { - GreedySearchGpt impl{ + GreedySearchGpt impl{ *ctx_internal, has_init_decoder_ ? init_run_decoder_session_state : nullptr, has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr, @@ -207,7 +205,7 @@ Status GreedySearch::Compute(OpKernelContext* ctx) const { return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); } else { - GreedySearchGpt impl{ + GreedySearchGpt impl{ *ctx_internal, has_init_decoder_ ? init_run_decoder_session_state : nullptr, has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr, diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h index 46f242400f..e0e611a2c3 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search.h @@ -21,7 +21,6 @@ namespace transformers { using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel -// bugbug: refactor class GreedySearch : public IControlFlowKernel { public: explicit GreedySearch(const OpKernelInfo& info) diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h index 408d2655b2..724db62219 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_base.h @@ -2,6 +2,7 @@ // Licensed under the MIT License. #pragma once +#include #include #include "contrib_ops/cpu/transformers/generation_shared.h" #include "contrib_ops/cpu/transformers/generate_impl_base.h" @@ -11,6 +12,63 @@ namespace contrib { namespace transformers { +template +struct SamplingState : public ISamplingState { + void Init(AllocatorPtr allocator, + AllocatorPtr cpu_allocator, + int batch_size, + int vocab_size, + int max_iter, + int seed, + bool is_cuda) { + int total_count = batch_size * vocab_size; + + this->h_softmaxed_score = AllocateBuffer(cpu_allocator, h_softmaxed_score_buffer_, SafeInt(total_count)); + + this->generator = std::default_random_engine{gsl::narrow_cast(seed)}; + + if (is_cuda) { + this->d_index_in = AllocateBuffer(allocator, d_index_in_buffer_, SafeInt(total_count)); + this->d_index_out = AllocateBuffer(allocator, d_index_out_buffer_, SafeInt(total_count)); + this->d_offset = AllocateBuffer(allocator, d_offset_buffer_, SafeInt(batch_size + 1)); + this->d_sorted_score = AllocateBuffer(allocator, d_sorted_score_buffer_, SafeInt(total_count)); + this->d_sorted_softmaxed_score = AllocateBuffer(allocator, d_sorted_softmaxed_score_buffer_, SafeInt(total_count)); + this->d_softmaxed_score = AllocateBuffer(allocator, d_softmaxed_score_buffer_, SafeInt(total_count)); + this->d_sampled = AllocateBuffer(allocator, d_sampled_buffer_, SafeInt(batch_size)); + this->h_sampled_all = AllocateBuffer(cpu_allocator, h_sampled_all_buffer_, SafeInt(batch_size * max_iter)); + this->d_indices = AllocateBuffer(allocator, d_indices_buffer_, SafeInt(batch_size)); + this->temp_storage_bytes = 0; + // TODO: Do not allocate this buffer if there's no presence_mask + this->d_presence_mask = AllocateBuffer(allocator, d_presence_mask_buffer_, SafeInt(total_count)); + + std::uniform_real_distribution distribution(0.0, 1.0); + static_cast(distribution(this->generator)); + for (size_t i = 0; i < this->h_sampled_all.size(); ++i) { + this->h_sampled_all[i] = distribution(this->generator); + } + } else { + // TODO: Some buffer can be reused for CPU + this->sorted_scores = AllocateBuffer(cpu_allocator, sorted_scores_buffer_, SafeInt(total_count)); + this->cumulative_probs = AllocateBuffer(cpu_allocator, cumulative_probs_buffer_, SafeInt(total_count)); + } + } + + private: + BufferUniquePtr d_index_in_buffer_; + BufferUniquePtr d_index_out_buffer_; + BufferUniquePtr d_offset_buffer_; + BufferUniquePtr d_sorted_score_buffer_; + BufferUniquePtr d_sorted_softmaxed_score_buffer_; + BufferUniquePtr d_softmaxed_score_buffer_; + BufferUniquePtr h_softmaxed_score_buffer_; + BufferUniquePtr d_sampled_buffer_; + BufferUniquePtr h_sampled_all_buffer_; + BufferUniquePtr d_indices_buffer_; + BufferUniquePtr d_presence_mask_buffer_; + BufferUniquePtr sorted_scores_buffer_; + BufferUniquePtr cumulative_probs_buffer_; +}; + template struct GreedySearchState : public IGreedySearchState { Sequences sequences; @@ -68,7 +126,7 @@ struct GreedySearchState : public IGreedySearchState { }; // Base class of gready search implementation that is common for both GPT-2 and Bart/T5. -template +template class GreedySearchBase : public GenerateBase { public: GreedySearchBase(OpKernelContextInternal& context, @@ -76,7 +134,7 @@ class GreedySearchBase : public GenerateBase { concurrency::ThreadPool* thread_pool, Stream* ort_stream, IConsoleDumper* cuda_dumper, - GreedySearchParameters& params, + ParametersT& params, const GenerationDeviceHelper::TopkFunc& topk_func, const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func) @@ -105,23 +163,25 @@ class GreedySearchBase : public GenerateBase { Status GenerateNextToken(const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, + ISamplingState& sampling_state, int counter, int eos_token_id); // Calculate scores from logits, then apply filtering and select next token for each beam. Status ProcessLogits(const OrtValue& logits, // logits output of subgraph GreedySearchState& greedy_state, + ISamplingState& sampling_state, AllocatorPtr& allocator, int counter); - GreedySearchParameters* parameters_; + ParametersT* parameters_; // Device specific functions GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_func_; }; -template -Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) { +template +Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) { // Input shapes: // input_ids : (batch_size, sequence_length) // vocab_mask : (vocab_size) or nullptr @@ -129,13 +189,14 @@ Status GreedySearchBase::CheckInputs(const OpKernelContextInternal& context) context.Input(0), // input_ids context.Input(4), // vocab_mask context.Input(5), // prefix_vocab_mask - nullptr)); // attention_mask + context.Input(6), // attention_mask + context.Input(7))); // presence_mask return Status::OK(); } -template -Status GreedySearchBase::Initialize() { +template +Status GreedySearchBase::Initialize() { ORT_RETURN_IF_ERROR(this->context_.GetTempSpaceAllocator(&this->temp_space_allocator_)); ORT_RETURN_IF_ERROR(CheckScalarInput("max_length", 1, true)); @@ -155,26 +216,29 @@ Status GreedySearchBase::Initialize() { return Status::OK(); } -template -Status GreedySearchBase::ProcessLogits( +template +Status GreedySearchBase::ProcessLogits( const OrtValue& logits, GreedySearchState& greedy_state, + ISamplingState& sampling_state, AllocatorPtr& allocator, int counter) { - return process_logits_func_(logits, &greedy_state, &(greedy_state.sequences), allocator, - this->thread_pool_, &this->logits_processors_, - parameters_, counter, this->ort_stream_, this->GetConsoleDumper()); + bool use_sampling = std::is_same::value; + return process_logits_func_(logits, &greedy_state, &sampling_state, &(greedy_state.sequences), allocator, + this->thread_pool_, &this->logits_processors_, parameters_, + use_sampling, counter, this->ort_stream_, this->GetConsoleDumper()); } -template -Status GreedySearchBase::GenerateNextToken( +template +Status GreedySearchBase::GenerateNextToken( const OrtValue& logits, gsl::span& next_tokens, GreedySearchState& greedy_state, + ISamplingState& sampling_state, int counter, int eos_token_id) { // Process logits to get next token scores - ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, this->temp_space_allocator_, counter)); + ORT_RETURN_IF_ERROR(ProcessLogits(logits, greedy_state, sampling_state, this->temp_space_allocator_, counter)); next_tokens = greedy_state.next_tokens; for (size_t i = 0; i < next_tokens.size(); i++) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h index cf12fe6ba1..dfcc309274 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h +++ b/onnxruntime/contrib_ops/cpu/transformers/greedy_search_impl_gpt.h @@ -24,8 +24,8 @@ std::pair> CreateGptSubgraphAndUpdateParame } // namespace gpt_details // Greedy search implementation for GPT-2 model. -template -class GreedySearchGpt : public GreedySearchBase { +template +class GreedySearchGpt : public GreedySearchBase { public: GreedySearchGpt(OpKernelContextInternal& context, const SessionState* init_run_decoder_session_state, @@ -35,7 +35,7 @@ class GreedySearchGpt : public GreedySearchBase { concurrency::ThreadPool* thread_pool, Stream* ort_stream, IConsoleDumper* cuda_dumper, - GreedySearchParameters& params, + ParametersT& params, const GenerationDeviceHelper::CreateGptInputsFunc& create_inputs_func, const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, const GenerationDeviceHelper::TopkFunc& topk_func, @@ -43,15 +43,15 @@ class GreedySearchGpt : public GreedySearchBase { const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, const GenerationDeviceHelper::UpdateGptFeedsFunc& update_feeds_func) - : GreedySearchBase(context, - decoder_session_state, - thread_pool, - ort_stream, - cuda_dumper, - params, - topk_func, - process_logits_func, - device_copy_func), + : GreedySearchBase(context, + decoder_session_state, + thread_pool, + ort_stream, + cuda_dumper, + params, + topk_func, + process_logits_func, + device_copy_func), init_run_decoder_session_state_(init_run_decoder_session_state), init_run_gpt_subgraph_(init_run_gpt_subgraph), gpt_subgraph_(gpt_subgraph), @@ -94,8 +94,8 @@ class GreedySearchGpt : public GreedySearchBase { GenerationDeviceHelper::UpdateGptFeedsFunc update_feeds_func_; }; -template -Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths, +template +Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengths, OrtValue& expanded_input_ids, std::vector& feeds, IAllocatorUniquePtr& buffer) { @@ -134,8 +134,8 @@ Status GreedySearchGpt::CreateInitialFeeds(gsl::span& sequence_lengt this->parameters_->max_length); } -template -Status GreedySearchGpt::UpdateFeeds( +template +Status GreedySearchGpt::UpdateFeeds( const std::vector& last_outputs, std::vector& next_inputs, int current_length, @@ -161,11 +161,11 @@ Status GreedySearchGpt::UpdateFeeds( ); } -template -Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager, - const FeedsFetchesManager& feeds_fetches_manager) { +template +Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fetches_manager, + const FeedsFetchesManager& feeds_fetches_manager) { auto status = Status::OK(); - const GreedySearchParameters* parameters = this->parameters_; + const ParametersT* parameters = this->parameters_; // Allocate output tensors. int64_t sequences_dims[] = {parameters->batch_size, parameters->max_length}; @@ -184,6 +184,17 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fet parameters->max_length, this->IsCuda()); + SamplingState sampling_state; + if (std::is_same::value) { + sampling_state.Init(this->temp_space_allocator_, + this->cpu_allocator_, + static_cast(parameters->BatchBeamSize()), + static_cast(parameters->vocab_size), + static_cast(parameters->max_length - parameters->sequence_length), + parameters->seed, + this->IsCuda()); + } + IAllocatorUniquePtr buffer; OrtValue expanded_input_ids_in_cpu; ORT_RETURN_IF_ERROR(CreateInitialFeeds(greedy_state.sequence_lengths, expanded_input_ids_in_cpu, feeds, buffer)); @@ -276,6 +287,7 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fet ORT_RETURN_IF_ERROR(this->GenerateNextToken(logits, next_tokens, greedy_state, + sampling_state, iteration_counter, parameters->eos_token_id)); @@ -324,6 +336,27 @@ Status GreedySearchGpt::Execute(const FeedsFetchesManager* init_run_feeds_fet gsl::copy(sequence_source, batch_output); } +#ifdef DEBUG_GENERATION + // Debug the one step filtered logits for sampling + int64_t filtered_logits_dims[] = {parameters->batch_size, parameters->vocab_size}; + TensorShape filtered_logits_shape(&filtered_logits_dims[0], + sizeof(filtered_logits_dims) / sizeof(filtered_logits_dims[0])); + Tensor* filtered_logits = this->context_.Output(1, filtered_logits_shape); + if (filtered_logits != nullptr) { + gsl::span filtered_logits_span = filtered_logits->MutableDataAsSpan(); + for (int batch_id = 0; batch_id < parameters->batch_size; ++batch_id) { + auto batch_output = filtered_logits_span.subspan( + static_cast(batch_id) * parameters->vocab_size, + parameters->vocab_size); + gsl::span batch_filtered_logits = gsl::make_span(sampling_state.h_softmaxed_score.data() + + batch_id * parameters->vocab_size, + parameters->vocab_size); + + gsl::copy(batch_filtered_logits, batch_output); + } + } +#endif + return status; } diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc index 2f1e657c8e..d0641fedf9 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.cc @@ -6,8 +6,12 @@ #include "core/common/narrow.h" #include "core/common/safeint.h" #include "core/common/span_utils.h" +#include "core/providers/cpu/math/softmax_shared.h" #include "contrib_ops/cpu/transformers/logits_processor.h" #include "contrib_ops/cpu/transformers/dump_tensor.h" +#include +#include +#include namespace onnxruntime { namespace contrib { @@ -187,6 +191,53 @@ void PrefixVocabMaskLogitsProcessor::Process(const ISequences* /*sequences*/, #endif } +template +TemperatureLogitsProcessor::TemperatureLogitsProcessor(float temperature) : temperature_(temperature) { +} + +template +void TemperatureLogitsProcessor::Process(const ISequences* /*sequences*/, + NextTokenScores& next_token_scores) { + if (temperature_ == 1.0f) { + return; + } + + T* p = next_token_scores.scores.data(); + for (size_t i = 0; i < next_token_scores.scores.size(); i++) { + *p /= temperature_; + ++p; + } + +#ifdef DEBUG_GENERATION + DumpScores("TemperatureLogitsProcessor", next_token_scores); +#endif +} + +template +PresencePenaltyLogitsProcessor::PresencePenaltyLogitsProcessor(const gsl::span& presence_mask, + float presence_penalty) + : presence_mask_(presence_mask), presence_penalty_(presence_penalty) { +} + +template +void PresencePenaltyLogitsProcessor::Process(const ISequences*, + NextTokenScores& next_token_scores) { + if (presence_penalty_ == 0.0f) { + return; + } + + assert(!presence_mask_.empty()); + + T* p = next_token_scores.scores.data(); + for (size_t i = 0; i < next_token_scores.scores.size(); i++) { + *p -= presence_mask_[i] * presence_penalty_; + } + +#ifdef DEBUG_GENERATION + DumpScores("PresencePenaltyLogitsProcessor", next_token_scores); +#endif +} + void LogitsProcessorList::Init(const BeamSearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } @@ -195,6 +246,10 @@ void LogitsProcessorList::Init(const GreedySearchParameters& parameters) { LogitsProcessorInitImpl(parameters); } +void LogitsProcessorList::Init(const SamplingParameters& parameters) { + LogitsProcessorInitImpl(parameters); +} + void LogitsProcessorList::Process(const ISequences* sequences, gsl::span& next_token_scores, int step) { diff --git a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h index 1a2fba19bf..4a516474c8 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h +++ b/onnxruntime/contrib_ops/cpu/transformers/logits_processor.h @@ -7,6 +7,7 @@ #include "contrib_ops/cpu/transformers/sequences.h" #include "contrib_ops/cpu/transformers/beam_search_parameters.h" #include "contrib_ops/cpu/transformers/greedy_search_parameters.h" +#include "contrib_ops/cpu/transformers/sampling_parameters.h" #include "contrib_ops/cpu/transformers/generation_shared.h" namespace onnxruntime { @@ -96,11 +97,53 @@ class PrefixVocabMaskLogitsProcessor : public ILogitsProcessor { const int batch_size_; }; +template +class TemperatureLogitsProcessor : public ILogitsProcessor { + public: + TemperatureLogitsProcessor(float temperature); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + float temperature_; +}; + +// template +// class TopPLogitsProcessor : public ILogitsProcessor { +// public: +// TopPLogitsProcessor(float top_p, float filter_value, +// onnxruntime::concurrency::ThreadPool* thread_pool); + +// void Process(const ISequences* sequences, +// NextTokenScores& next_token_scores) override; + +// private: +// float top_p_; +// float filter_value_; +// onnxruntime::concurrency::ThreadPool* thread_pool_; +// }; + +template +class PresencePenaltyLogitsProcessor : public ILogitsProcessor { + public: + PresencePenaltyLogitsProcessor(const gsl::span& presence_mask, + float presence_penalty); + + void Process(const ISequences* sequences, + NextTokenScores& next_token_scores) override; + + private: + gsl::span presence_mask_; + float presence_penalty_; +}; + class LogitsProcessorList : public ILogitsProcessorList { public: LogitsProcessorList() = default; void Init(const BeamSearchParameters& parameters); void Init(const GreedySearchParameters& parameters); + void Init(const SamplingParameters& parameters); void Process(const ISequences* sequences, gsl::span& next_token_scores, int step); private: @@ -140,6 +183,19 @@ class LogitsProcessorList : public ILogitsProcessorList { processor_list_.push_back(min_length_processor_.get()); } + if (parameters.temperature > 0) { + temperature_processor_ = std::make_unique>(parameters.temperature); + processor_list_.push_back(temperature_processor_.get()); + } + + if (!parameters.presence_mask.empty()) { + presence_penalty_processor_ = std::make_unique< + PresencePenaltyLogitsProcessor + >(parameters.presence_mask, + parameters.presence_penalty); + processor_list_.push_back(presence_penalty_processor_.get()); + } + batch_beam_size_ = parameters.BatchBeamSize(); vocab_size_ = parameters.vocab_size; } @@ -153,6 +209,8 @@ class LogitsProcessorList : public ILogitsProcessorList { std::unique_ptr> vocab_mask_processor_; std::unique_ptr> prefix_vocab_mask_processor_; std::unique_ptr> min_length_processor_; + std::unique_ptr> temperature_processor_; + std::unique_ptr> presence_penalty_processor_; }; } // namespace transformers diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc new file mode 100644 index 0000000000..a9b2db40e9 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.cc @@ -0,0 +1,175 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// there's no way to use a raw pointer as the copy destination with std::copy_n +// (which gsl::copy uses with span::data() which returns a raw pointer) with the 14.11 toolset +// without generating a 4996 warning. going through an iterator is way too much overhead so turn off the warning. +#ifdef _MSC_VER +#pragma warning(push) +#pragma warning(disable : 4996) +#endif + +#include "core/framework/op_kernel_context_internal.h" +#include "core/framework/utils.h" +#include "contrib_ops/cpu/transformers/sampling.h" +#include "contrib_ops/cpu/transformers/logits_processor.h" +#include "contrib_ops/cpu/transformers/sequences.h" +#include "contrib_ops/cpu/transformers/dump_tensor.h" +#include "contrib_ops/cpu/transformers/greedy_search_impl_gpt.h" + +using namespace ONNX_NAMESPACE; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { + +#define REGISTER_KERNEL_TYPED(T) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + Sampling, \ + kMSDomain, \ + 1, \ + T, \ + kCpuExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("T", DataTypeImpl::GetTensorType()), \ + transformers::Sampling); + +REGISTER_KERNEL_TYPED(float) + +namespace transformers { + +void Sampling::Init(const OpKernelInfo& info) { + parameters_.ParseFromAttributes(info); + parameters_.vocab_size = (parameters_.vocab_size == 0 ? -1 : parameters_.vocab_size); + + // Model_type could be either 0 (GPT-2) or 1 (encoder-decoder like T5) + ORT_ENFORCE(parameters_.model_type == IGenerationParameters::kModelTypeGpt); + + ONNX_NAMESPACE::GraphProto proto; + if (parameters_.model_type != IGenerationParameters::kModelTypeGpt) { + // Make sure the encoder sub-graph attribute is present for the T5 model. + ORT_ENFORCE(info.GetAttr("encoder", &proto).IsOK()); + } + + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { + // Check if the init_decoder sub-graph attribute is present for the GPT2 model. + if (info.GetAttr("init_decoder", &proto).IsOK()) { + has_init_decoder_ = true; + } + } + + // Make sure the decoder sub-graph attribute is present for all model types. + ORT_ENFORCE(info.GetAttr("decoder", &proto).IsOK()); +} + +Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) { + const auto& node = Node(); + if (parameters_.model_type == IGenerationParameters::kModelTypeGpt) { // GPT-2 + if (attribute_name == "decoder") { + ORT_ENFORCE(gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, + subgraph_session_state, parameters_); + + auto status = res.first; + if (!status.IsOK()) { + return status; + } + + gpt_subgraph_ = std::move(res.second); + decoder_feeds_fetches_manager_ = gpt_subgraph_->GetFeedsFetchesManager(); + } else if (attribute_name == "init_decoder") { + ORT_ENFORCE(init_run_gpt_subgraph_ == nullptr, "SetupSubgraphExecutionInfo should only be called once for each subgraph."); + auto res = gpt_details::CreateGptSubgraphAndUpdateParameters(node, session_state, attribute_name, + subgraph_session_state, parameters_); + + auto status = res.first; + if (!status.IsOK()) { + return status; + } + + init_run_gpt_subgraph_ = std::move(res.second); + init_run_decoder_feeds_fetches_manager_ = init_run_gpt_subgraph_->GetFeedsFetchesManager(); + } + } else if (parameters_.model_type == IGenerationParameters::kModelTypeT5) { // encoder-decoder like T5 + ORT_THROW("Not Implemented"); + } + + return Status::OK(); +} + +Status Sampling::Compute(OpKernelContext* ctx) const { + auto* ctx_internal = static_cast(ctx); + + auto* decoder_session_state = ctx_internal->SubgraphSessionState("decoder"); + ORT_ENFORCE(decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); + ORT_ENFORCE(decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + + auto* init_run_decoder_session_state = ctx_internal->SubgraphSessionState("init_decoder"); + if (has_init_decoder_) { + ORT_ENFORCE(init_run_decoder_session_state, "Subgraph SessionState was not found for 'decoder' attribute."); + ORT_ENFORCE(init_run_decoder_feeds_fetches_manager_, "CreateFeedsFetchesManager must be called prior to execution of graph."); + ORT_ENFORCE(init_run_gpt_subgraph_ && gpt_subgraph_ + && init_run_gpt_subgraph_->past_present_share_buffer_ == gpt_subgraph_->past_present_share_buffer_, + "past_present_share_buffer mode must be same for init decoder and decoder subgraphes"); + } + + concurrency::ThreadPool* thread_pool = ctx->GetOperatorThreadPool(); + + // make a copy since we will update the parameters based on inputs later + SamplingParameters parameters = parameters_; + + if (parameters_.model_type == 0) { // GPT-2 + // Subgraph has constraint that the output is either float or float16 + if (!gpt_subgraph_->IsOutputFloat16()) { + GreedySearchGpt impl{ + *ctx_internal, + has_init_decoder_ ? init_run_decoder_session_state : nullptr, + has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr, + *decoder_session_state, + *gpt_subgraph_, + thread_pool, + ctx->GetComputeStream(), + dumper_, + parameters, + GenerationCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, + process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::GreedySearchProcessLogits, + init_greedy_state_func_ ? init_greedy_state_func_ : GenerationCpuDeviceHelper::InitGreedyState, + device_copy_func_ ? device_copy_func_ : GenerationCpuDeviceHelper::DeviceCopy, + update_gpt_feeds_func_ ? update_gpt_feeds_func_ : GenerationCpuDeviceHelper::UpdateGptFeeds}; + ORT_RETURN_IF_ERROR(impl.Initialize()); + + return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); + } else { + GreedySearchGpt impl{ + *ctx_internal, + has_init_decoder_ ? init_run_decoder_session_state : nullptr, + has_init_decoder_ ? init_run_gpt_subgraph_.get() : nullptr, + *decoder_session_state, + *gpt_subgraph_, + thread_pool, + ctx->GetComputeStream(), + dumper_, + parameters, + GenerationCpuDeviceHelper::CreateGptInputs, + add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, + topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, + process_logits_fp16_func_, + init_greedy_state_fp16_func_, + device_copy_func_, + update_gpt_feeds_fp16_func_}; + ORT_RETURN_IF_ERROR(impl.Initialize()); + + return impl.Execute(init_run_decoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); + } + } + + return Status::OK(); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling.h b/onnxruntime/contrib_ops/cpu/transformers/sampling.h new file mode 100644 index 0000000000..ea57ce15e2 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling.h @@ -0,0 +1,103 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include +#include +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cpu/controlflow/utils.h" +#include "contrib_ops/cpu/transformers/subgraph_gpt.h" +#include "contrib_ops/cpu/transformers/generation_device_helper.h" +#include "contrib_ops/cpu/transformers/sampling_parameters.h" + +namespace onnxruntime { +class FeedsFetchesManager; + +namespace contrib { +namespace transformers { + +using namespace onnxruntime::controlflow; // namespace of IControlFlowKernel + +class Sampling : public IControlFlowKernel { + public: + explicit Sampling(const OpKernelInfo& info) + : IControlFlowKernel(info), + decoder_feeds_fetches_manager_(nullptr), + dumper_(nullptr) { + Init(info); + } + + void Init(const OpKernelInfo& info); + + Status Compute(OpKernelContext* ctx) const override; + + Status SetupSubgraphExecutionInfo(const SessionState& session_state, + const std::string& attribute_name, + const SessionState& subgraph_session_state) override; + + protected: + void SetConsoleDumper(IConsoleDumper* dumper) { dumper_ = dumper; } + + // device helpers that is same for both GPT and encoder-decoder models. + void SetDeviceHelpers( + const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, + const GenerationDeviceHelper::TopkFunc& topk_func, + const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, + const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_func, + const GenerationDeviceHelper::GreedySearchProcessLogitsFunc& process_logits_fp16_func, + const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_func, + const GenerationDeviceHelper::InitGreedyStateFunc& init_greedy_state_fp16_func) { + add_to_feeds_func_ = add_to_feeds_func; + topk_func_ = topk_func; + device_copy_func_ = device_copy_func; + process_logits_func_ = process_logits_func; + process_logits_fp16_func_ = process_logits_fp16_func; + init_greedy_state_func_ = init_greedy_state_func; + init_greedy_state_fp16_func_ = init_greedy_state_fp16_func; + } + + void SetDeviceHelpers_Gpt( + const GenerationDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_func, + const GenerationDeviceHelper::UpdateGptFeedsFunc& update_gpt_feeds_fp16_func) { + update_gpt_feeds_func_ = update_gpt_feeds_func; + update_gpt_feeds_fp16_func_ = update_gpt_feeds_fp16_func; + } + + private: + // Device specific functions + GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; + GenerationDeviceHelper::TopkFunc topk_func_; + GenerationDeviceHelper::DeviceCopyFunc device_copy_func_; + + GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_func_; + GenerationDeviceHelper::GreedySearchProcessLogitsFunc process_logits_fp16_func_; + + GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_func_; + GenerationDeviceHelper::InitGreedyStateFunc init_greedy_state_fp16_func_; + + //------------------------------------------------------------ + // Device specific functions for GPT + //------------------------------------------------------------ + GenerationDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_func_; + GenerationDeviceHelper::UpdateGptFeedsFunc update_gpt_feeds_fp16_func_; + + //------------------------------------------------------------ + // Subgraph and FeedsFetchesManager re-used for each subgraph execution. + //------------------------------------------------------------ + std::unique_ptr init_run_gpt_subgraph_; + std::unique_ptr gpt_subgraph_; + + FeedsFetchesManager* decoder_feeds_fetches_manager_; + FeedsFetchesManager* init_run_decoder_feeds_fetches_manager_; + + IConsoleDumper* dumper_; + + SamplingParameters parameters_; + + bool has_init_decoder_ = false; +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h new file mode 100644 index 0000000000..1e3c7035ff --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_cpu_helper.h @@ -0,0 +1,161 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +namespace onnxruntime { +namespace contrib { +namespace SamplingCpuHelper { + +template +void filter_scores(std::vector& sorted_indice, + gsl::span& next_token_score, + const transformers::IGenerationParameters* parameters, + size_t index) { + size_t real_index = sorted_indice[index]; + next_token_score[real_index] = (T)parameters->filter_value; +} + +template +void cumulate_and_filter_custom(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + const transformers::IGenerationParameters* parameters, + std::vector& sorted_indices) { + for (size_t i = 0; i < static_cast(parameters->batch_size); i++) { + size_t offset = i * parameters->vocab_size; + if (cumulative_probs[offset] > parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, 1 + offset); + } + for (size_t j = 1; j < static_cast(parameters->vocab_size) - 1; j++) { + cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; + if (cumulative_probs[j + offset] > parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, j + offset + 1); + } + } + } +} + +template +void cumulate_and_filter(gsl::span& next_token_scores, + gsl::span& cumulative_probs, + const transformers::IGenerationParameters* parameters, + std::vector& sorted_indices) { + for (size_t i = 0; i < static_cast(parameters->batch_size); i++) { + size_t offset = i * parameters->vocab_size; + if (cumulative_probs[offset] <= 1 - parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, offset); + } + for (size_t j = 1; j < static_cast(parameters->vocab_size) - static_cast(parameters->min_tokens_to_keep); j++) { + cumulative_probs[j + offset] += cumulative_probs[j + offset - 1]; + if (cumulative_probs[j + offset] <= 1 - parameters->top_p) { + filter_scores(sorted_indices, next_token_scores, parameters, j + offset); + } + } + } +} + +template +Status Sample(AllocatorPtr& allocator, + onnxruntime::concurrency::ThreadPool* thread_pool, + gsl::span& next_token_scores, + transformers::ISamplingState* sampling_state, + transformers::IGreedySearchState* greedy_state, + const transformers::IGenerationParameters* parameters, + const transformers::IConsoleDumper* dumper) { + ORT_UNUSED_PARAMETER(dumper); + + gsl::span& sorted_scores = sampling_state->sorted_scores; + memcpy(sorted_scores.data(), next_token_scores.data(), next_token_scores.size_bytes()); + std::vector sorted_indices(static_cast(parameters->batch_size) * static_cast(parameters->vocab_size)); + + std::function predicator; + if (parameters->custom_sampling) { + predicator = std::greater(); + } else { + predicator = std::less(); + } + + // TODO: This could be optimized with allocated buffer and handwritten sort algorithm + for (size_t i = 0; i < static_cast(parameters->batch_size); i++) { + auto indices_begin = sorted_indices.begin() + i * parameters->vocab_size; + auto indices_end = sorted_indices.begin() + (i + 1) * parameters->vocab_size; + std::iota(indices_begin, indices_end, 0); + std::sort(indices_begin, indices_end, + [&next_token_scores, &predicator](size_t i1, size_t i2) { + return !predicator(next_token_scores[i1], next_token_scores[i2]); + }); + + std::sort(sorted_scores.begin() + i * parameters->vocab_size, + sorted_scores.begin() + (i + 1) * parameters->vocab_size, + predicator); + } + + gsl::span& cumulative_probs = sampling_state->cumulative_probs; + + ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters->batch_size, + parameters->vocab_size, + sorted_scores.data(), + cumulative_probs.data(), + false, + thread_pool)); + + if (parameters->custom_sampling) { + cumulate_and_filter_custom(next_token_scores, cumulative_probs, parameters, sorted_indices); + } else { + cumulate_and_filter(next_token_scores, cumulative_probs, parameters, sorted_indices); + } + + gsl::span& next_token_probs = sampling_state->h_softmaxed_score; + ORT_RETURN_IF_ERROR(SoftmaxCPU(parameters->batch_size, + parameters->vocab_size, + next_token_scores.data(), + next_token_probs.data(), + false, + thread_pool)); + + // torch.multinomial() + int64_t next_token_probs_dims[] = {static_cast(parameters->batch_size), parameters->vocab_size}; + TensorShape next_token_probs_shape(&next_token_probs_dims[0], 2); + auto element_type = DataTypeImpl::GetType(); + OrtValue next_token_probs_value; + Tensor::InitOrtValue(element_type, + next_token_probs_shape, + next_token_probs.data(), + allocator->Info(), + next_token_probs_value); + const Tensor& input = next_token_probs_value.Get(); + + std::default_random_engine& generator = sampling_state->generator; + + int64_t sampled_idx_dims[] = {static_cast(parameters->batch_size), 1}; + TensorShape sampled_idx_shape(&sampled_idx_dims[0], 2); + + gsl::span& next_token_idx = greedy_state->next_tokens_cpu; + + OrtValue sampled_idx_ov; + Tensor::InitOrtValue(DataTypeImpl::GetType(), + sampled_idx_shape, + next_token_idx.data(), + allocator->Info(), + sampled_idx_ov); + Tensor* sampled_idx = sampled_idx_ov.GetMutable(); + + // Copy the allocator because MultinomialComputeShared() uses move(allocator) + AllocatorPtr allocatortemp = allocator; + ORT_RETURN_IF_ERROR(MultinomialComputeShared(allocatortemp, + input, + parameters->batch_size, + parameters->vocab_size, + 1, + generator, + *sampled_idx)); + // TODO: update presense_mask() +#ifdef DEBUG_GENERATION + dumper->Print("sampled_idx", *sampled_idx); +#endif + + return Status::OK(); +} + +} // namespace SamplingCudaHelper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc new file mode 100644 index 0000000000..537b0d7538 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.cc @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#include "contrib_ops/cpu/transformers/sampling_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +void SamplingParameters::ParseFromAttributes(const OpKernelInfo& info) { + model_type = static_cast(info.GetAttrOrDefault("model_type", 0)); + eos_token_id = static_cast(info.GetAttrOrDefault("eos_token_id", -1)); + pad_token_id = static_cast(info.GetAttrOrDefault("pad_token_id", -1)); + decoder_start_token_id = static_cast(info.GetAttrOrDefault("decoder_start_token_id", -1)); + no_repeat_ngram_size = static_cast(info.GetAttrOrDefault("no_repeat_ngram_size", 0)); + temperature = info.GetAttrOrDefault("temperature", 1.0f); + top_p = info.GetAttrOrDefault("top_p", 0.0f); + filter_value = info.GetAttrOrDefault("filter_value", -std::numeric_limits::infinity()); + min_tokens_to_keep = static_cast(info.GetAttrOrDefault("min_tokens_to_keep", 0)); + presence_penalty = info.GetAttrOrDefault("presence_penalty", 0.0f); + custom_sampling = static_cast(info.GetAttrOrDefault("custom", 0)); + vocab_size = static_cast(info.GetAttrOrDefault("vocab_size", -1)); +} + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h new file mode 100644 index 0000000000..6c0f866f09 --- /dev/null +++ b/onnxruntime/contrib_ops/cpu/transformers/sampling_parameters.h @@ -0,0 +1,19 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "contrib_ops/cpu/transformers/greedy_search_parameters.h" + +namespace onnxruntime { +namespace contrib { +namespace transformers { + +struct SamplingParameters : public GreedySearchParameters { + void ParseFromAttributes(const OpKernelInfo& info); +}; + +} // namespace transformers +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 84e8b0bfc9..385f3bc9d7 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -75,6 +75,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ParametricSoftplus); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ParametricSoftplus); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, Sampling); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, float, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, double, ScaledTanh); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 1, MLFloat16, ScaledTanh); @@ -189,6 +190,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu deleted file mode 100644 index 3b530cef6b..0000000000 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.cu +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#include "core/providers/cuda/cuda_common.h" -#include "core/providers/cuda/cu_inc/common.cuh" -#include "cub/util_type.cuh" -#include "contrib_ops/cuda/transformers/beam_search_impl.h" - -namespace onnxruntime { -namespace contrib { -namespace cuda { -__global__ void InitKernel(float* beam_scores, - int num_beams, - int total_elements) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < total_elements) { - int beam_index = index % num_beams; - beam_scores[index] = beam_index > 0 ? static_cast(-1e9) : 0.0f; - } -} - -void LaunchInitKernel( - float* beam_scores, - int batch_size, - int num_beams, - cudaStream_t stream) { - int total_elements = batch_size * num_beams; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - InitKernel<<>>(beam_scores, num_beams, total_elements); -} - -__global__ void NextTokenKernel(const int64_t* next_token_indices, - int32_t* next_indices, - int32_t* next_tokens, - int vocab_size, - int total_elements) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < total_elements) { - next_indices[index] = next_token_indices[index] / vocab_size; - next_tokens[index] = next_token_indices[index] % vocab_size; - } -} - -void LaunchNextTokenKernel(const int64_t* next_token_indices, - int32_t* next_indices, - int32_t* next_tokens, - int batch_size, - int top_k, - int vocab_size, - cudaStream_t stream) { - int total_elements = batch_size * top_k; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - NextTokenKernel<<>>(next_token_indices, - next_indices, - next_tokens, - vocab_size, - total_elements); -} - -template -__global__ void LogitsProcessKernel( - T* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int num_beams, - int vocab_size, - int padded_vocab_size, - int total_elements, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < total_elements) { - int batch_beam_index = index / padded_vocab_size; - int word_id = index % padded_vocab_size; - - if (word_id >= vocab_size) { - // Set any value within the padding region to the lowest value so that it isn't picked - next_token_scores[index] = cub::FpLimits::Lowest(); - } else { - // RepetitionPenaltyLogitsProcessor - if (repetition_penalty != 1.0f) { - int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length; - bool found = false; - for (int i = 0; i < current_sequence_length; i++) { - if (current_sequence[i] == word_id) { - found = true; - break; - } - } - if (found) { - float score = (float)next_token_scores[index]; - next_token_scores[index] = (T)(score < 0 ? score * repetition_penalty : score / repetition_penalty); - } - } - - // NoRepeatNGramLogitsProcessor - if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) { - int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length; - bool found = false; - for (int i = no_repeat_ngram_size - 1; i < current_sequence_length; i++) { - if (current_sequence[i] == word_id) { // last token of n-gram matched - found = true; - for (int j = 0; j < no_repeat_ngram_size - 1; j++) { // match the remaining N-1 tokens - if (current_sequence[i - j - 1] != current_sequence[current_sequence_length - 1 - j]) { - found = false; - break; - } - } - if (found) { - break; - } - } - } - - if (found) { - next_token_scores[index] = cub::FpLimits::Lowest(); - return; - } - } - - // VocabMaskLogitsProcessor - if (vocab_mask != nullptr && vocab_mask[word_id] == 0) { - next_token_scores[index] = cub::FpLimits::Lowest(); - return; - } - - // PrefixVocabMaskLogitsProcessor - int batch_id = batch_beam_index / num_beams; - if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + word_id] == 0) { - next_token_scores[index] = cub::FpLimits::Lowest(); - return; - } - - // MinLengthLogitsProcessor - if (word_id == demote_token_id) { - next_token_scores[index] = cub::FpLimits::Lowest(); - } - } - } -} - -template -void LaunchLogitsProcessKernel( - T* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream) { - int total_elements = batch_size * num_beams * padded_vocab_size; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - LogitsProcessKernel<<>>( - next_token_scores, - vocab_mask, - prefix_vocab_mask, - num_beams, - vocab_size, - padded_vocab_size, - total_elements, - demote_token_id, - sequences, - max_sequence_length, - current_sequence_length, - repetition_penalty, - no_repeat_ngram_size); -} - -// Instantiation -template void LaunchLogitsProcessKernel( - float* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream); - -template void LaunchLogitsProcessKernel( - half* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream); - -__global__ void AddProbsKernel(float* log_probs, - float* cum_log_probs, - const int vocab_size, - const int total_elements) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - int batch_beam_index = index / vocab_size; - - if (index < total_elements) - log_probs[index] += cum_log_probs[batch_beam_index]; -} - -template -void LaunchAddProbsKernel(T* log_probs, - T* cum_log_probs, - const int batch_size, - const int num_beams, - const int vocab_size, - cudaStream_t stream) { - int total_elements = batch_size * num_beams * vocab_size; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - AddProbsKernel<<>>(log_probs, cum_log_probs, vocab_size, total_elements); -} - -template void LaunchAddProbsKernel( - float* log_probs, - float* cum_log_probs, - const int batch_size, - const int num_beams, - const int vocab_size, - cudaStream_t stream); - -template -__global__ void UpdateGptInputsKernel(const T* old_mask_data, - T* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length) { - int index = blockIdx.x * blockDim.x + threadIdx.x; - if (index < batch_beam_size * current_length) { - // Update attention mask. - int i = index / current_length; - int j = index % current_length; - mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast(1); - - if (next_positions != nullptr) { - // Update sequence length (or next positions). - if (index < batch_beam_size) { - next_positions[index]++; - } - } - } -} - -void LaunchUpdateGptKernel(const int32_t* old_mask_data, - int32_t* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length, - cudaStream_t stream) { - assert(current_length > 0); - int total_elements = batch_beam_size * current_length; - constexpr int blockSize = 256; - const int gridSize = (total_elements + blockSize - 1) / blockSize; - UpdateGptInputsKernel<<>>( - old_mask_data, mask_data, next_positions, batch_beam_size, current_length); -} - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h b/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h deleted file mode 100644 index 9b122ca797..0000000000 --- a/onnxruntime/contrib_ops/cuda/transformers/beam_search_impl.h +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright (c) Microsoft Corporation. All rights reserved. -// Licensed under the MIT License. - -#pragma once - -#include -#include - -namespace onnxruntime { -namespace contrib { -namespace cuda { - -void LaunchInitKernel( - float* beam_scores, - int batch_size, - int num_beams, - cudaStream_t stream); - -template -void LaunchAddProbsKernel(T* log_probs, - T* cum_log_probs, - const int batch_size, - const int num_beams, - const int vocab_size, - cudaStream_t stream); - -template -void LaunchLogitsProcessKernel( - T* next_token_scores, - const int* vocab_mask, - const int* prefix_vocab_mask, - int batch_size, - int num_beams, - int vocab_size, - int padded_vocab_size, - int demote_token_id, - int32_t* sequences, - int max_sequence_length, - int current_sequence_length, - float repetition_penalty, - int no_repeat_ngram_size, - cudaStream_t stream); - -void LaunchNextTokenKernel(const int64_t* next_token_indices, - int32_t* next_indices, - int32_t* next_tokens, - int batch_size, - int top_k, - int vocab_size, - cudaStream_t stream); - -void LaunchUpdateGptKernel(const int32_t* old_mask_data, - int32_t* mask_data, - int32_t* next_positions, - int batch_beam_size, - int current_length, - cudaStream_t stream); - -} // namespace cuda -} // namespace contrib -} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu new file mode 100644 index 0000000000..f35ef8a40b --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -0,0 +1,758 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "cub/util_type.cuh" +#include +#include +#include "contrib_ops/cuda/transformers/generation_cuda_impl.h" + + +namespace onnxruntime { +namespace contrib { +namespace cuda { +__global__ void InitKernel(float* beam_scores, + int num_beams, + int total_elements) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + int beam_index = index % num_beams; + beam_scores[index] = beam_index > 0 ? static_cast(-1e9) : 0.0f; + } +} + +void LaunchInitKernel( + float* beam_scores, + int batch_size, + int num_beams, + cudaStream_t stream) { + int total_elements = batch_size * num_beams; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + InitKernel<<>>(beam_scores, num_beams, total_elements); +} + +__global__ void NextTokenKernel(const int64_t* next_token_indices, + int32_t* next_indices, + int32_t* next_tokens, + int vocab_size, + int total_elements) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + next_indices[index] = next_token_indices[index] / vocab_size; + next_tokens[index] = next_token_indices[index] % vocab_size; + } +} + +void LaunchNextTokenKernel(const int64_t* next_token_indices, + int32_t* next_indices, + int32_t* next_tokens, + int batch_size, + int top_k, + int vocab_size, + cudaStream_t stream) { + int total_elements = batch_size * top_k; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + NextTokenKernel<<>>(next_token_indices, + next_indices, + next_tokens, + vocab_size, + total_elements); +} + +template +__global__ void LogitsProcessKernel( + T* next_token_scores, + const int* vocab_mask, + const int* prefix_vocab_mask, + const int* presence_mask, + float presence_penalty, + float temperature, + int num_beams, + int vocab_size, + int padded_vocab_size, + int total_elements, + int demote_token_id, + int32_t* sequences, + int max_sequence_length, + int current_sequence_length, + float repetition_penalty, + int no_repeat_ngram_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < total_elements) { + int batch_beam_index = index / padded_vocab_size; + int word_id = index % padded_vocab_size; + + if (word_id >= vocab_size) { + // Set any value within the padding region to the lowest value so that it isn't picked + next_token_scores[index] = cub::FpLimits::Lowest(); + } else { + // RepetitionPenaltyLogitsProcessor + if (repetition_penalty != 1.0f) { + int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length; + bool found = false; + for (int i = 0; i < current_sequence_length; i++) { + if (current_sequence[i] == word_id) { + found = true; + break; + } + } + if (found) { + float score = (float)next_token_scores[index]; + next_token_scores[index] = (T)(score < 0 ? score * repetition_penalty : score / repetition_penalty); + } + } + + // NoRepeatNGramLogitsProcessor + if (no_repeat_ngram_size > 0 && current_sequence_length >= no_repeat_ngram_size) { + int32_t* current_sequence = sequences + batch_beam_index * max_sequence_length; + bool found = false; + for (int i = no_repeat_ngram_size - 1; i < current_sequence_length; i++) { + if (current_sequence[i] == word_id) { // last token of n-gram matched + found = true; + for (int j = 0; j < no_repeat_ngram_size - 1; j++) { // match the remaining N-1 tokens + if (current_sequence[i - j - 1] != current_sequence[current_sequence_length - 1 - j]) { + found = false; + break; + } + } + if (found) { + break; + } + } + } + + if (found) { + next_token_scores[index] = cub::FpLimits::Lowest(); + return; + } + } + + // VocabMaskLogitsProcessor + if (vocab_mask != nullptr && vocab_mask[word_id] == 0) { + next_token_scores[index] = cub::FpLimits::Lowest(); + return; + } + + // PrefixVocabMaskLogitsProcessor + int batch_id = batch_beam_index / num_beams; + if (prefix_vocab_mask != nullptr && prefix_vocab_mask[batch_id * vocab_size + word_id] == 0) { + next_token_scores[index] = cub::FpLimits::Lowest(); + return; + } + + // MinLengthLogitsProcessor + if (word_id == demote_token_id) { + next_token_scores[index] = cub::FpLimits::Lowest(); + } + + // PresencePenaltyLogitsProcessor + if (presence_mask != nullptr && presence_mask[index] == 1) { + float score = (float)next_token_scores[index] - presence_penalty; + next_token_scores[index] = (T)score; + } + + // TemperatureLogitsProcessor + if (temperature != 1.0f) { + float score = (float)(next_token_scores[index]); + next_token_scores[index] = (T)(score / temperature); + } + } + } +} + +template +void LaunchLogitsProcessKernel( + T* next_token_scores, + const int* vocab_mask, + const int* prefix_vocab_mask, + int* presence_mask, + float presence_penalty, + float temperature, + int batch_size, + int num_beams, + int vocab_size, + int padded_vocab_size, + int demote_token_id, + int32_t* sequences, + int max_sequence_length, + int current_sequence_length, + float repetition_penalty, + int no_repeat_ngram_size, + cudaStream_t stream) { + int total_elements = batch_size * num_beams * vocab_size; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + LogitsProcessKernel<<>>( + next_token_scores, + vocab_mask, + prefix_vocab_mask, + presence_mask, + presence_penalty, + temperature, + num_beams, + vocab_size, + padded_vocab_size, + total_elements, + demote_token_id, + sequences, + max_sequence_length, + current_sequence_length, + repetition_penalty, + no_repeat_ngram_size); +} + +// Instantiation +template void LaunchLogitsProcessKernel( + float* next_token_scores, + const int* vocab_mask, + const int* prefix_vocab_mask, + int* presence_mask, + float presence_penalty, + float temperature, + int batch_size, + int num_beams, + int vocab_size, + int padded_vocab_size, + int demote_token_id, + int32_t* sequences, + int max_sequence_length, + int current_sequence_length, + float repetition_penalty, + int no_repeat_ngram_size, + cudaStream_t stream); + +template void LaunchLogitsProcessKernel( + half* next_token_scores, + const int* vocab_mask, + const int* prefix_vocab_mask, + int* presence_mask, + float presence_penalty, + float temperature, + int batch_size, + int num_beams, + int vocab_size, + int padded_vocab_size, + int demote_token_id, + int32_t* sequences, + int max_sequence_length, + int current_sequence_length, + float repetition_penalty, + int no_repeat_ngram_size, + cudaStream_t stream); + +__global__ void AddProbsKernel(float* log_probs, + float* cum_log_probs, + const int vocab_size, + const int total_elements) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int batch_beam_index = index / vocab_size; + + if (index < total_elements) + log_probs[index] += cum_log_probs[batch_beam_index]; +} + +template +void LaunchAddProbsKernel(T* log_probs, + T* cum_log_probs, + const int batch_size, + const int num_beams, + const int vocab_size, + cudaStream_t stream) { + int total_elements = batch_size * num_beams * vocab_size; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + AddProbsKernel<<>>(log_probs, cum_log_probs, vocab_size, total_elements); +} + +template void LaunchAddProbsKernel( + float* log_probs, + float* cum_log_probs, + const int batch_size, + const int num_beams, + const int vocab_size, + cudaStream_t stream); + +template +__global__ void UpdateGptInputsKernel(const T* old_mask_data, + T* mask_data, + int32_t* next_positions, + int batch_beam_size, + int current_length) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index < batch_beam_size * current_length) { + // Update attention mask. + int i = index / current_length; + int j = index % current_length; + mask_data[index] = (j < current_length - 1) ? old_mask_data[i * (current_length - 1) + j] : static_cast(1); + + if (next_positions != nullptr) { + // Update sequence length (or next positions). + if (index < batch_beam_size) { + next_positions[index]++; + } + } + } +} + +void LaunchUpdateGptKernel(const int32_t* old_mask_data, + int32_t* mask_data, + int32_t* next_positions, + int batch_beam_size, + int current_length, + cudaStream_t stream) { + assert(current_length > 0); + int total_elements = batch_beam_size * current_length; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + UpdateGptInputsKernel<<>>( + old_mask_data, mask_data, next_positions, batch_beam_size, current_length); +} + +template +void GetTempStorageSize(const T *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes) { + if (is_descending) { + cub::DeviceSegmentedRadixSort::SortPairsDescending(nullptr, + temp_storage_bytes, + d_keys_in, + (T*)nullptr, + d_values_in, + (int*)nullptr, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } else { + cub::DeviceSegmentedRadixSort::SortPairs(nullptr, + temp_storage_bytes, + d_keys_in, + (T*)nullptr, + d_values_in, + (int*)nullptr, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } +} + +template void GetTempStorageSize( + const float *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); + +template void GetTempStorageSize( + const half *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); + +// TODO: merge to one kernel +__global__ void SetupParamsKernel(int* d_values_in, + int* d_offsets, + int batch_size, + int vocab_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + int total_elements = batch_size * vocab_size; + if (index < total_elements) { + d_values_in[index] = index % vocab_size; + } + if (index < batch_size + 1) { + d_offsets[index] = index * vocab_size; + } +} + +void LaunchSetupParamsKernel(int* d_values_in, + int* d_offsets, + int batch_size, + int vocab_size, + cudaStream_t stream) { + int total_elements = batch_size * vocab_size; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + SetupParamsKernel<<>>(d_values_in, + d_offsets, + batch_size, + vocab_size); +} + +template +void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const T *d_keys_in, + T *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending) { + if (is_descending) { + cub::DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } else { + cub::DeviceSegmentedRadixSort::SortPairs(d_temp_storage, + temp_storage_bytes, + d_keys_in, + d_keys_out, + d_values_in, + d_values_out, + num_items, + num_segments, + d_offsets, + d_offsets + 1, + 0, + sizeof(T) * 8, + stream); + } +} + +template void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const float *d_keys_in, + float *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending); + +template void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const half *d_keys_in, + half *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending); + +template +__global__ void FilterLogitsKernelCustom(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p_threshold, + float filter_value, + int batch_size, + int vocab_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= batch_size * vocab_size) { + return; + } + + int vocab_idx = index % vocab_size; + int batch_id = index / vocab_size; + int start_index = batch_id * vocab_size; + + int count = vocab_idx; + float sum = 0.0f; + while (count >= 0) { + sum += d_sorted_logits_in[start_index]; + ++start_index; + --count; + } + + if (sum > top_p_threshold) { + // Shift the indices to the right by one according to the custom implementation. + int shifted_index = index + 1; + if (shifted_index % vocab_size != 0) { + int original_index = batch_id * vocab_size + d_sorted_indices[shifted_index]; + d_logits_in_out[original_index] = (T)filter_value; + } + } +} + +template +__global__ void FilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p_threshold, + float filter_value, + int min_tokens_to_keep, + int batch_size, + int vocab_size) { + int index = blockIdx.x * blockDim.x + threadIdx.x; + + if (index >= batch_size * vocab_size) { + return; + } + + int vocab_idx = index % vocab_size; + int batch_id = index / vocab_size; + int start_index = batch_id * vocab_size; + + int count = vocab_idx; + float sum = 0.0f; + // TODO: Optimization needed. e.g. use CUB::SCAN() for cumulative probabilities. + while (count >= 0) { + sum += d_sorted_logits_in[start_index]; + ++start_index; + --count; + } + + if (sum <= top_p_threshold) { + if (index % vocab_size + min_tokens_to_keep < vocab_size) { + int original_index = batch_id * vocab_size + d_sorted_indices[index]; + d_logits_in_out[original_index] = (T)filter_value; + } + } +} + +template +void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p, + float filter_value, + int min_tokens_to_keep, + int batch_size, + int vocab_size, + cudaStream_t stream, + bool is_descending) { + int total_elements = batch_size * vocab_size; + constexpr int blockSize = 256; + const int gridSize = (total_elements + blockSize - 1) / blockSize; + if (is_descending) { + FilterLogitsKernelCustom<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + top_p, + filter_value, + batch_size, + vocab_size); + } else { + FilterLogitsKernel<<>>(d_sorted_logits_in, + d_sorted_indices, + d_logits_in_out, + 1 - top_p, + filter_value, + min_tokens_to_keep, + batch_size, + vocab_size); + } +} + +template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + float* d_logits_in_out, + float top_p, + float filter_value, + int min_tokens_to_keep, + int batch_size, + int vocab_size, + cudaStream_t stream, + bool is_descending); + +template void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + half* d_logits_in_out, + float top_p, + float filter_value, + int min_tokens_to_keep, + int batch_size, + int vocab_size, + cudaStream_t stream, + bool is_descending); + + +// Ref: https://github.com/pytorch/pytorch/blob/release/1.13/aten/src/ATen/native/cuda/MultinomialKernel.cu +template +__global__ void sampleMultinomialOnce(int64_t* dest, + int distributions, + int categories, + scalar_t* sampled, + scalar_t* dist, + int stride_dist, // dist->stride(0) + int stride_categories, // dist->stride(1) + int* d_presence_mask) { + extern __shared__ unsigned char my_smem[]; + __shared__ bool found; + __shared__ unsigned foundPos; + accscalar_t *smem = reinterpret_cast(my_smem); + accscalar_t accZero = static_cast(0); + scalar_t zero = static_cast(0); + for (int curDist = blockIdx.x; + curDist < distributions; curDist += gridDim.x) { + + // Assume sum = 1 in Top P sampling as the input is softmaxed. + accscalar_t sum = 1; + + // Broadcast sum and sample value + if (threadIdx.x == 0) { + // Make sure the sum of our distribution didn't overflow + // CUDA_KERNEL_ASSERT(!_isinf(val)); + // CUDA_KERNEL_ASSERT(sum > accZero); + foundPos = 0; + smem[0] = sum; + smem[1] = sampled[curDist]; + } + __syncthreads(); + sum = smem[0]; + scalar_t sample = static_cast(smem[1]); + __syncthreads(); + if (sum == accZero) { + // Choose the first element + if (threadIdx.x == 0) { + dest[curDist] = 0; + } + continue; + } + int chunks = (categories + (int)blockDim.x - 1) / blockDim.x; + accscalar_t prevHighProb = accZero; + found = false; + for (int chunk = 0; chunk < chunks && !found; ++chunk) { + // All threads in bounds load a value + int cat = chunk * blockDim.x + threadIdx.x; + accscalar_t dist_val = cat < categories ? + static_cast(dist[curDist * stride_dist + cat * stride_categories]) / sum : + accZero; + smem[threadIdx.x] = dist_val; + __syncthreads(); + // Perform an inclusive prefix sum of the shared memory contents + for (int offset = 1; offset < blockDim.x; offset *= 2) { + accscalar_t val = accZero; + if (threadIdx.x >= offset) { + val = smem[threadIdx.x - offset] + smem[threadIdx.x]; + } + __syncthreads(); + if (threadIdx.x >= offset) { + smem[threadIdx.x] = val; + } + __syncthreads(); + } + // Each thread will check to see if the sample falls in its bucket + scalar_t curBucket = + static_cast(smem[threadIdx.x] + prevHighProb); + scalar_t prevBucket = static_cast( + threadIdx.x == 0 ? prevHighProb + : smem[threadIdx.x - 1] + prevHighProb); + bool inBucket = + (cat < categories) && + (!(sample >= curBucket) && + (sample >= prevBucket) && + (dist_val > zero)); + if (inBucket) { + // We're done; we have the sample + // Torch indices are 1-based + atomicMax(&foundPos, cat); + found = true; + } + // Store the previous scan's high value for future use + prevHighProb = prevHighProb + smem[blockDim.x - 1]; + __syncthreads(); + } + if (threadIdx.x == 0) { + if (found) { + dest[curDist] = foundPos; + } else { + // This should address a rare bug where we don't select a valid index. This likely occurs when + // due to floating point arithmetic rounding errors, our cumulative sum does not add up to 1, but + // and our uniform sample is greater than this value. In this case we likely have unitialized memory + // in dest[curDist]. So basically we will loop through the distribution and pick the largest index + // where the distribution is non-zero. This is obviously terribly inefficient, but due to the + // rarity in which this occurs, this should not be an issue. + for (int cat = categories - 1; cat >= 0; --cat) { + if (dist[curDist * stride_dist + cat * stride_categories] > zero) { + dest[curDist] = cat; + break; + } + } + } + } + } + + // update presence mask + int index = blockIdx.x * blockDim.x + threadIdx.x; + if (index >= distributions * categories) { + return; + } + int dist_idx = index / categories; + int cat_idx = index % categories; + if (dest[dist_idx] == cat_idx) { + d_presence_mask[index] = 1; + } +} + +// Only support n_sample = 1 +void TorchMultinomialKernelLauncher(float* d_input, + float* d_sampled, + int64_t* d_output, + int batch_size, + int vocab_size, + int* d_presence_mask, + cudaStream_t stream) +{ + // Store the props in class variables + int device; + cudaGetDevice(&device); + cudaDeviceProp props; + cudaGetDeviceProperties(&props, device); + + int numSM = props.multiProcessorCount; + int maxThreads = props.maxThreadsPerBlock; + int warp_size = 32; //at::cuda::warp_size(); + int requiredWarps = (vocab_size + warp_size - 1) / warp_size; + int requiredThreads = std::min(maxThreads, requiredWarps * warp_size); + int requiredShared = requiredThreads * sizeof(float); + + dim3 block(requiredThreads); + dim3 grid(std::min(batch_size, numSM * 4)); + + sampleMultinomialOnce + <<>>(d_output, + batch_size, + vocab_size, + d_sampled, + d_input, + vocab_size, + 1, + d_presence_mask); + +} + + + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h new file mode 100644 index 0000000000..0aa16a3b94 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.h @@ -0,0 +1,114 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include +#include +#include + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +void LaunchInitKernel( + float* beam_scores, + int batch_size, + int num_beams, + cudaStream_t stream); + +template +void LaunchAddProbsKernel(T* log_probs, + T* cum_log_probs, + const int batch_size, + const int num_beams, + const int vocab_size, + cudaStream_t stream); + +template +void LaunchLogitsProcessKernel( + T* next_token_scores, + const int* vocab_mask, + const int* prefix_vocab_mask, + int* presence_mask, + float presence_penalty, + float temperature, + int batch_size, + int num_beams, + int vocab_size, + int padded_vocab_size, + int demote_token_id, + int32_t* sequences, + int max_sequence_length, + int current_sequence_length, + float repetition_penalty, + int no_repeat_ngram_size, + cudaStream_t stream); + +void LaunchNextTokenKernel(const int64_t* next_token_indices, + int32_t* next_indices, + int32_t* next_tokens, + int batch_size, + int top_k, + int vocab_size, + cudaStream_t stream); + +void LaunchUpdateGptKernel(const int32_t* old_mask_data, + int32_t* mask_data, + int32_t* next_positions, + int batch_beam_size, + int current_length, + cudaStream_t stream); + +template +void GetTempStorageSize(const T *d_keys_in, + const int* d_values_in, + int* d_offsets, + int num_items, + int num_segments, + cudaStream_t stream, + bool is_descending, + size_t& temp_storage_bytes); + +void LaunchSetupParamsKernel(int* d_values_in, + int* d_offsets, + int batch_size, + int vocab_size, + cudaStream_t stream); + +template +void LaunchSortPairs(void *d_temp_storage, + size_t temp_storage_bytes, + const T *d_keys_in, + T *d_keys_out, + const int *d_values_in, + int *d_values_out, + int num_items, + int num_segments, + int *d_offsets, + cudaStream_t stream, + bool is_descending); + +template +void LaunchFilterLogitsKernel(float* d_sorted_logits_in, + const int* d_sorted_indices, + T* d_logits_in_out, + float top_p, + float filter_value, + int min_tokens_to_keep, + int batch_size, + int vocab_size, + cudaStream_t stream, + bool is_descending); + +void TorchMultinomialKernelLauncher(float* d_input, + float* d_sampled, + int64_t* d_output, + int batch_size, + int vocab_size, + int* d_presence_mask, + cudaStream_t stream); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index ee1b90822b..5377c45176 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -4,19 +4,25 @@ #include #include #include "core/providers/shared_library/provider_api.h" +#include "core/providers/cuda/cuda_kernel.h" #include "core/providers/cuda/math/topk_impl.h" #include "core/providers/cuda/math/softmax.h" #include "core/providers/cuda/shared_inc/accumulation_type.h" #include "core/framework/ort_value.h" #include "contrib_ops/cuda/bert/transformer_cuda_common.h" #include -#include "contrib_ops/cuda/transformers/beam_search_impl.h" +#include "contrib_ops/cuda/transformers/generation_cuda_impl.h" #include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" #include "contrib_ops/cpu/transformers/subgraph_t5_decoder.h" #include "contrib_ops/cpu/transformers/subgraph_gpt.h" #include "contrib_ops/cuda/transformers/beam_search_topk.h" #include "core/providers/cuda/nvtx_profile.h" #include "core/providers/cuda/nvtx_profile_context.h" +#include "sampling_cuda_helper.h" + +#ifdef DEBUG_GENERATION +#include +#endif namespace onnxruntime { namespace concurrency { @@ -128,7 +134,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, cudaStream_t stream = ort_stream ? static_cast(ort_stream->GetHandle()) : nullptr; auto pinned_buffer = IAllocator::MakeUniquePtr(pinned_allocator, total_bytes); char* pinned_data = static_cast(pinned_buffer.get()); - // Copy tensors to one pinned memory buffer (so that we only need copy to GPU once) char* destination = pinned_data; for (auto& input : inputs) { @@ -145,16 +150,13 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, "AddToFeeds: An implementation for the input type ", dataType, " is not supported yet"); } - // Do not need alignment because GPT has int32 inputs (past is empty) and T5 encoder has int64 inputs. destination += bytes; } } - if (!buffer) { buffer = provider->GetScratchBuffer(total_bytes, ort_stream, WaitCudaNotificationOnDevice); } - char* gpu_data = buffer.get(); CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(gpu_data, pinned_data, total_bytes, cudaMemcpyHostToDevice, stream)); @@ -164,7 +166,6 @@ Status AddToFeeds(const IExecutionProvider* execution_provider, CUDA_RETURN_IF_ERROR(cudaEventCreate(&isCopyDone)); CUDA_RETURN_IF_ERROR(cudaEventRecord(isCopyDone, stream)); CUDA_RETURN_IF_ERROR(cudaEventSynchronize(isCopyDone)); - // TODO(tianleiwu): allocate a buffer for subgraph inputs so that we can reuse the buffer in each subgraph call. const OrtMemoryInfo& location = provider->GetAllocator(0, OrtMemTypeDefault)->Info(); for (auto& input : inputs) { @@ -253,7 +254,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter Stream* ort_stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -363,6 +364,9 @@ Status ProcessLogits(const OrtValue& logits, // next_token_scores.data(), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + nullptr, // parameters->presence_mask.data(), + parameters->presence_penalty, + parameters->temperature, parameters->batch_size, parameters->num_beams, parameters->vocab_size, @@ -504,11 +508,13 @@ template Status GreedySearchProcessLogits( const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingState* sampling_state, // buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper) { // tensor dumper @@ -519,7 +525,6 @@ Status GreedySearchProcessLogits( #endif ORT_UNUSED_PARAMETER(logits_processors); - #ifndef DEBUG_GENERATION ORT_UNUSED_PARAMETER(dumper); #endif @@ -596,6 +601,13 @@ Status GreedySearchProcessLogits( cudaMemcpyHostToDevice, cuda_stream)); } + // Copy parameters->presence_mask to sampling_state->presence_mask + gsl::span& presence_mask = sampling_state->d_presence_mask; + if (step == 1 && parameters->presence_mask.data() != nullptr) { + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(presence_mask.data(), parameters->presence_mask.data(), + sizeof(int) * batch_size * vocab_size, cudaMemcpyDeviceToDevice, cuda_stream)); + } + // TODO(hasesh): Can we avoid the const_cast by changing the interface of // GreedySearchProcessLogits() to take in a non-const OrtValue for logits // as this is the only place we will ever use the logits and it may be reasonable @@ -605,11 +617,15 @@ Status GreedySearchProcessLogits( : reinterpret_cast(next_token_scores.data()), parameters->vocab_mask.data(), step > 1 ? nullptr : parameters->prefix_vocab_mask.data(), // prefix vocab mask is applied to first step only. + parameters->presence_mask.data() ? presence_mask.data() : nullptr, + parameters->presence_penalty, + parameters->temperature, parameters->batch_size, parameters->num_beams, parameters->vocab_size, is_reuse_logits_buffer ? padded_vocab_size : parameters->vocab_size, - (parameters->min_length > 0 && current_sequence_length < parameters->min_length) ? parameters->eos_token_id : -1, + (parameters->min_length > 0 && current_sequence_length < parameters->sequence_length + parameters->min_length) + ? parameters->eos_token_id : -1, reinterpret_cast(sequences_buffer.get()), parameters->max_length, current_sequence_length, @@ -629,6 +645,19 @@ Status GreedySearchProcessLogits( // TODO(wy): support output_scores in greedy search ORT_UNUSED_PARAMETER(output_scores); + if (do_sampling) { + ORT_RETURN_IF_ERROR(SamplingCudaHelper::Sample(allocator, + cuda_stream, + next_token_scores, + sampling_state, + greedy_state, + parameters, + step, + dumper)); + + return Status::OK(); + } + // next_tokens = torch.argmax(scores, dim=-1) int64_t next_token_scores_dims[] = {static_cast(batch_size), is_reuse_logits_buffer ? padded_vocab_size : vocab_size}; @@ -657,8 +686,8 @@ Status GreedySearchProcessLogits( *topk_scores, *topk_indices)); #ifdef DEBUG_GENERATION - dumper->Print("topk_scores", *(topk_scores.get())); - dumper->Print("topk_indices", *(topk_indices.get())); + dumper->Print("topk_scores", *(topk_scores.get())); + dumper->Print("topk_indices", *(topk_indices.get())); #endif const int64_t* next_token_indices = topk_indices->Data(); @@ -996,7 +1025,7 @@ template Status ProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, transformers::IBeamScorer* beam_scorer, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, Stream* ort_stream, const transformers::IConsoleDumper* dumper); @@ -1004,11 +1033,13 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, + transformers::ISamplingState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, + bool do_sampling, int step, Stream* ort_stream, const transformers::IConsoleDumper* dumper); @@ -1063,7 +1094,7 @@ template Status ProcessLogits( onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, transformers::IBeamScorer* beam_scorer, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, int step, Stream* ort_stream, const transformers::IConsoleDumper* dumper); @@ -1071,11 +1102,13 @@ template Status ProcessLogits( template Status GreedySearchProcessLogits( const OrtValue& logits, transformers::IGreedySearchState* greedy_state, + transformers::ISamplingState* sampling_state, transformers::ISequences* sequences, AllocatorPtr& allocator, onnxruntime::concurrency::ThreadPool* thread_pool, transformers::ILogitsProcessorList* logits_processors, - const transformers::IBeamSearchParameters* parameters, + const transformers::IGenerationParameters* parameters, + bool do_sampling, int step, Stream* ort_stream, const transformers::IConsoleDumper* dumper); diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index 5623f05a27..755ac75e14 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -55,7 +55,7 @@ Status ProcessLogits(const OrtValue& logits, // onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors transformers::IBeamScorer* beam_scorer, // beam scorer - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper @@ -63,11 +63,13 @@ Status ProcessLogits(const OrtValue& logits, // template Status GreedySearchProcessLogits(const OrtValue& logits, // logits output of subgraph transformers::IGreedySearchState* greedy_state, // state + transformers::ISamplingState* sampling_state,// sampling buffers transformers::ISequences* sequences, // sequences AllocatorPtr& allocator, // default allocator onnxruntime::concurrency::ThreadPool* thread_pool, // thread pool (for CPU only) transformers::ILogitsProcessorList* logits_processors, // logits processors - const transformers::IBeamSearchParameters* parameters, // parameters + const transformers::IGenerationParameters* parameters, // parameters + bool do_sampling, // whether to do sampling int step, // iteration counter Stream* stream, // cuda stream (for CUDA only) const transformers::IConsoleDumper* dumper); // tensor dumper diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.cc b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc new file mode 100644 index 0000000000..a758112f6f --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.cc @@ -0,0 +1,68 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/cuda_execution_provider.h" +#include "contrib_ops/cuda/transformers/sampling.h" +#include "contrib_ops/cuda/transformers/generation_device_helper.h" +#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +ONNX_OPERATOR_KERNEL_EX( + Sampling, + kMSDomain, + 1, + kCudaExecutionProvider, + (*KernelDefBuilder::Create()) + .InputMemoryType(OrtMemTypeCPUInput, 0) // 'input_ids' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 1) // 'max_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 2) // 'min_length' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 3) // 'repetition_penalty' needs to be on CPU + .InputMemoryType(OrtMemTypeCPUInput, 6) // 'custom_attention_mask' needs to be on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 0) // 'sequences' output on CPU + .OutputMemoryType(OrtMemTypeCPUOutput, 1) // 'logits_to_debug' output on CPU + .TypeConstraint("T", {DataTypeImpl::GetTensorType(), + DataTypeImpl::GetTensorType()}), + Sampling); + +transformers::CudaTensorConsoleDumper g_cuda_dumper_sampling; + +Sampling::Sampling(const OpKernelInfo& info) + : onnxruntime::contrib::transformers::Sampling(info) { + SetDeviceHelpers(GenerationCudaDeviceHelper::AddToFeeds, + GenerationCudaDeviceHelper::TopK, + GenerationCudaDeviceHelper::DeviceCopy, + GenerationCudaDeviceHelper::GreedySearchProcessLogits, + GenerationCudaDeviceHelper::GreedySearchProcessLogits, + GenerationCudaDeviceHelper::InitGreedyState, + GenerationCudaDeviceHelper::InitGreedyState); + + SetDeviceHelpers_Gpt(GenerationCudaDeviceHelper::UpdateGptFeeds, + GenerationCudaDeviceHelper::UpdateGptFeeds); + + SetConsoleDumper(&g_cuda_dumper_sampling); +} + +Status Sampling::ComputeInternal(OpKernelContext* context) const { + return onnxruntime::contrib::transformers::Sampling::Compute(context); +} + +Status Sampling::Compute(OpKernelContext* context) const { + auto s = ComputeInternal(context); + + if (s.IsOK()) { + auto err = cudaGetLastError(); + if (err != cudaSuccess) { + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA error ", cudaGetErrorName(err), ":", cudaGetErrorString(err)); + } + } + + return s; +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling.h b/onnxruntime/contrib_ops/cuda/transformers/sampling.h new file mode 100644 index 0000000000..65bee53573 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling.h @@ -0,0 +1,26 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "contrib_ops/cpu/transformers/sampling.h" + +namespace onnxruntime { +class SessionState; + +namespace contrib { +namespace cuda { + +class Sampling final : public onnxruntime::contrib::transformers::Sampling { + public: + Sampling(const OpKernelInfo& info); + + Status Compute(OpKernelContext* context) const override; + + private: + Status ComputeInternal(OpKernelContext* context) const; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h new file mode 100644 index 0000000000..092612e674 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/transformers/sampling_cuda_helper.h @@ -0,0 +1,186 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +#pragma once + +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/tensor/utils.h" +#include "contrib_ops/cpu/transformers/generation_shared.h" +#include "core/providers/cuda/math/softmax.h" + +#ifdef DEBUG_GENERATION +#include +#endif + +namespace onnxruntime { +namespace contrib { +namespace SamplingCudaHelper { + +template +Status Sample(AllocatorPtr& allocator, + cudaStream_t cuda_stream, + gsl::span& next_token_scores, + transformers::ISamplingState* sampling_state, + transformers::IGreedySearchState* greedy_state, + const transformers::IGenerationParameters* parameters, + int step, + const transformers::IConsoleDumper* dumper) { + ORT_UNUSED_PARAMETER(dumper); + typedef typename ToCudaType::MappedType CudaT; + + gsl::span& d_index_in = sampling_state->d_index_in; + gsl::span& d_offset = sampling_state->d_offset; + + BufferUniquePtr& storage_buffer = sampling_state->storage_buffer; + size_t& temp_storage_bytes = sampling_state->temp_storage_bytes; + + bool is_descending = parameters->custom_sampling; + if (step == 1) { + cuda::GetTempStorageSize(reinterpret_cast(next_token_scores.data()), + d_index_in.data(), + d_offset.data(), + parameters->batch_size * parameters->vocab_size, + parameters->batch_size, + cuda_stream, + is_descending, + temp_storage_bytes); + + cuda::LaunchSetupParamsKernel(d_index_in.data(), + d_offset.data(), + parameters->batch_size, + parameters->vocab_size, + cuda_stream); + +#ifdef DEBUG_GENERATION + dumper->Print("d_offset_buffer", d_offset.data(), parameters->batch_size + 1, 1); +#endif + + void* temp_storage = allocator->Alloc(sampling_state->temp_storage_bytes); + BufferUniquePtr temp_storage_buffer(temp_storage, BufferDeleter(allocator)); + storage_buffer = std::move(temp_storage_buffer); + } + + gsl::span& d_sorted_score = sampling_state->d_sorted_score; + gsl::span& d_index_out = sampling_state->d_index_out; + +#ifdef DEBUG_GENERATION + dumper->Print("temp_storage_bytes", sampling_state->temp_storage_bytes, true); +#endif + + cuda::LaunchSortPairs(storage_buffer.get(), + temp_storage_bytes, + reinterpret_cast(next_token_scores.data()), + reinterpret_cast(d_sorted_score.data()), + d_index_in.data(), + d_index_out.data(), + parameters->batch_size * parameters->vocab_size, + parameters->batch_size, + d_offset.data(), + cuda_stream, + is_descending); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sorted_score_buffer", + reinterpret_cast(d_sorted_score.data()), + parameters->batch_size, + parameters->vocab_size); + dumper->Print("d_index_buffer_in", d_index_in.data(), parameters->batch_size, parameters->vocab_size); + dumper->Print("d_index_buffer_out", d_index_out.data(), parameters->batch_size, parameters->vocab_size); +#endif + + gsl::span& d_sorted_softmaxed_score = sampling_state->d_sorted_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream, + d_sorted_softmaxed_score.data(), + reinterpret_cast(d_sorted_score.data()), + parameters->vocab_size, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sorted_softmaxed_score_buffer", + d_sorted_softmaxed_score.data(), + parameters->batch_size, + parameters->vocab_size); +#endif + + cuda::LaunchFilterLogitsKernel(d_sorted_softmaxed_score.data(), + d_index_out.data(), + reinterpret_cast(next_token_scores.data()), + parameters->top_p, + parameters->filter_value, + parameters->min_tokens_to_keep, + parameters->batch_size, + parameters->vocab_size, + cuda_stream, + is_descending); + +#ifdef DEBUG_GENERATION + dumper->Print("next_token_scores after filtering logits", + reinterpret_cast(next_token_scores.data()), + parameters->batch_size, + parameters->vocab_size); +#endif + + gsl::span& d_softmaxed_score = sampling_state->d_softmaxed_score; + dispatch_blockwise_softmax_forward(cuda_stream, + d_softmaxed_score.data(), + reinterpret_cast(next_token_scores.data()), + parameters->vocab_size, + parameters->vocab_size, + parameters->vocab_size, + parameters->batch_size); + +#ifdef DEBUG_GENERATION + dumper->Print("d_softmaxed_score_buffer", + d_softmaxed_score.data(), + parameters->batch_size, + parameters->vocab_size); +#endif + + // Multinomial sampling + gsl::span& d_sampled = sampling_state->d_sampled; + gsl::span& h_sampled_all = sampling_state->h_sampled_all; + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(d_sampled.data(), + h_sampled_all.data() + (step - 1) * parameters->batch_size, + sizeof(float) * parameters->batch_size, + cudaMemcpyHostToDevice, + cuda_stream)); + +#ifdef DEBUG_GENERATION + dumper->Print("d_sampled", d_sampled.data(), parameters->batch_size, 1); +#endif + + gsl::span& d_indices = sampling_state->d_indices; + gsl::span& presence_mask = sampling_state->d_presence_mask; + cuda::TorchMultinomialKernelLauncher(d_softmaxed_score.data(), + d_sampled.data(), + d_indices.data(), + parameters->batch_size, + parameters->vocab_size, + presence_mask.data(), + cuda_stream); + +#ifdef DEBUG_GENERATION + dumper->Print("d_indices", d_indices.data(), parameters->batch_size, 1); +#endif + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(greedy_state->next_tokens_cpu.data(), + sampling_state->d_indices.data(), + greedy_state->next_tokens_cpu.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + + CUDA_RETURN_IF_ERROR(cudaMemcpyAsync(sampling_state->h_softmaxed_score.data(), + sampling_state->d_softmaxed_score.data(), + sampling_state->h_softmaxed_score.size_bytes(), + cudaMemcpyDeviceToHost, + cuda_stream)); + + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(cuda_stream)); + + return Status::OK(); +} + +} // namespace SamplingCudaHelper +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 03a18f1247..15f8599f52 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -510,6 +510,13 @@ void GreedySearchShapeInference(ONNX_NAMESPACE::InferenceContext& ctx) { sequences_shape.add_dim()->set_dim_value(batch_size); sequences_shape.add_dim()->set_dim_value(max_length_value); updateOutputShape(ctx, 0, sequences_shape); + + if (ctx.getNumOutputs() > 1) { + ONNX_NAMESPACE::TensorShapeProto logits_to_debug_shape; + logits_to_debug_shape.add_dim()->set_dim_value(batch_size); + logits_to_debug_shape.add_dim(); + updateOutputShape(ctx, 1, logits_to_debug_shape); + } } constexpr const char* Gelu_ver1_doc = @@ -1114,6 +1121,48 @@ ONNX_MS_OPERATOR_SET_SCHEMA(GreedySearch, 1, GreedySearchShapeInference(ctx); })); +ONNX_MS_OPERATOR_SET_SCHEMA(Sampling, 1, + OpSchema() + .SetDoc("Greedy Sampling for text generation.") + .Attr("eos_token_id", "The id of the end-of-sequence token", AttributeProto::INT) + .Attr("pad_token_id", "The id of the padding token", AttributeProto::INT) + .Attr("decoder_start_token_id", "The id of the token that indicates decoding starts.", AttributeProto::INT, static_cast(-1)) + .Attr("no_repeat_ngram_size", "no repeat ngrams size", AttributeProto::INT, static_cast(0)) + .Attr("temperature", "The value used to module the next token probabilities.", AttributeProto::FLOAT, 1.0f) + .Attr("top_p", + "If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or higher are kept for generation.", + AttributeProto::FLOAT, 0.0f) + .Attr("filter_value", "All filtered values will be set to this float value.", AttributeProto::FLOAT, -1e20f) + .Attr("min_tokens_to_keep", "Minimumber of tokens we keep per batch example in the output.", AttributeProto::INT, static_cast(0)) + .Attr("presence_penalty", "Presence penalty for custom sampling", AttributeProto::FLOAT, 0.0f) + .Attr("custom", "If 1 custom sampling logic", AttributeProto::INT, static_cast(0)) + .Attr("model_type", "Model type: 0 for decoder only like GPT-2; 1 for encoder decoder like Bart", AttributeProto::INT, static_cast(0)) + .Attr("encoder", "The subgraph for initialization of encoder and decoder. It will be called once before decoder subgraph.", AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("init_decoder", + "The subgraph for the first decoding run. It will be called once before `decoder` subgraph. " + "This is relevant only for the GPT2 model. If this attribute is missing, the `decoder` subgraph will be used for all decoding runs", + AttributeProto::GRAPH, OPTIONAL_VALUE) + .Attr("decoder", "Decoder subgraph to execute in a loop.", AttributeProto::GRAPH) + .Attr("vocab_size", + "Size of the vocabulary. " + "If not provided, it will be inferred from the decoder subgraph's output shape", + AttributeProto::INT, static_cast(-1)) + .Input(0, "input_ids", "The sequence used as a prompt for the generation. Shape is (batch_size, sequence_length)", "I") + .Input(1, "max_length", "The maximum length of the sequence to be generated. Shape is (1)", "I") + .Input(2, "min_length", "The minimum length below which the score of eos_token_id is set to -Inf. Shape is (1)", "I", OpSchema::Optional) + .Input(3, "repetition_penalty", "The parameter for repetition penalty. Default value 1.0 means no penalty. Accepts value > 0.0. Shape is (1)", "T", OpSchema::Optional) + .Input(4, "vocab_mask", "Mask of vocabulary. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (vacab_size)", "I", OpSchema::Optional) + .Input(5, "prefix_vocab_mask", "Mask of vocabulary for first step. Words that masked with 0 are not allowed to be generated, and 1 is allowed. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) + .Input(6, "attention_mask", "Custom attention mask. Shape is (batch_size, sequence_length)", "I", OpSchema::Optional) + .Input(7, "presence_mask", "Presence penalty mask. Shape is (batch_size, vocab_size)", "I", OpSchema::Optional) + .Output(0, "sequences", "Word IDs of generated sequences. Shape is (batch_size, max_sequence_length)", "I") + .Output(1, "filtered_logits", "Filtered logits as input to the mutinomial function for debug purpose. Shape is (batch_size, vocab_size)", "T", OpSchema::Optional) + .TypeConstraint("T", {"tensor(float)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("I", {"tensor(int32)"}, "Constrain to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + GreedySearchShapeInference(ctx); + })); + ONNX_MS_OPERATOR_SET_SCHEMA(SampleOp, 1, OpSchema() .Input(0, "X", "input", "T") diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index ff32e568e0..338baf82d3 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -81,6 +81,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RemovePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, RestorePadding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Rfft); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SampleOp); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Sampling); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SkipLayerNormalization); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, SparseToDenseMatMul); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, Tokenizer); @@ -165,6 +166,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc index bb8037fbc2..4013a8f156 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.cc +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.cc @@ -33,6 +33,7 @@ #include "contrib_ops/cpu/bert/longformer_attention_base.h" #include "contrib_ops/cpu/transformers/beam_search.h" #include "contrib_ops/cpu/transformers/greedy_search.h" +#include "contrib_ops/cpu/transformers/sampling.h" #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op.h" #endif @@ -248,6 +249,11 @@ struct ProviderHostCPUImpl : ProviderHostCPU { subgraph_session_state); } + void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) override { p->contrib::transformers::Sampling::Init(info); } + Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) override { return p->contrib::transformers::Sampling::Compute(ctx); } + Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) override { return p->contrib::transformers::Sampling::SetupSubgraphExecutionInfo(session_state, attribute_name, subgraph_session_state); } + + #ifdef ENABLE_ATEN Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) override { return p->ATen::Compute(p_ctx); } #endif diff --git a/onnxruntime/core/providers/cpu/cpu_provider_shared.h b/onnxruntime/core/providers/cpu/cpu_provider_shared.h index 3161f25d40..7bdb101da1 100644 --- a/onnxruntime/core/providers/cpu/cpu_provider_shared.h +++ b/onnxruntime/core/providers/cpu/cpu_provider_shared.h @@ -9,6 +9,7 @@ class AttentionBase; namespace transformers { class BeamSearch; class GreedySearch; +class Sampling; } } // namespace contrib @@ -174,6 +175,10 @@ struct ProviderHostCPU { const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + virtual void Sampling__Init(contrib::transformers::Sampling* p, const OpKernelInfo& info) = 0; + virtual Status Sampling__Compute(const contrib::transformers::Sampling* p, OpKernelContext* ctx) = 0; + virtual Status Sampling__SetupSubgraphExecutionInfo(contrib::transformers::Sampling* p, const SessionState& session_state, const std::string& attribute_name, const SessionState& subgraph_session_state) = 0; + #ifdef ENABLE_ATEN virtual Status ATen__Compute(const contrib::ATen* p, OpKernelContext* p_ctx) = 0; #endif diff --git a/onnxruntime/core/providers/cpu/generator/random.cc b/onnxruntime/core/providers/cpu/generator/random.cc index 3add38239e..84a283c3c8 100644 --- a/onnxruntime/core/providers/cpu/generator/random.cc +++ b/onnxruntime/core/providers/cpu/generator/random.cc @@ -207,13 +207,13 @@ template using EigenVector = Eigen::TensorMap>; template -static Status MultinomialCompute(OpKernelContext* ctx, - const Tensor& X, - const int64_t batch_size, - const int64_t num_classes, - const int64_t num_samples, - std::default_random_engine& generator, - Tensor& Y) { +Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y) { if (!utils::HasType()) { return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Output type not supported in this build."); } @@ -227,8 +227,6 @@ static Status MultinomialCompute(OpKernelContext* ctx, Matrix output = Matrix(Y.MutableData(), Y_dims); // BEGIN create temporary tensor - AllocatorPtr alloc; - ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); auto cdf_data = static_cast(alloc->Alloc(SafeInt(sizeof(double)) * num_classes)); BufferUniquePtr cdf_buffer(cdf_data, BufferDeleter(std::move(alloc))); Eigen::array cdf_dims = {{num_classes}}; @@ -271,6 +269,20 @@ static Status MultinomialCompute(OpKernelContext* ctx, return Status::OK(); } +template +static Status MultinomialCompute(OpKernelContext* ctx, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y) { + // BEGIN create temporary tensor + AllocatorPtr alloc; + ORT_RETURN_IF_ERROR(ctx->GetTempSpaceAllocator(&alloc)); + return MultinomialComputeShared(alloc, X, batch_size, num_classes, num_samples, generator, Y); +} + Status Multinomial::Compute(OpKernelContext* ctx) const { const auto* tensor_pointer = ctx->Input(0); if (tensor_pointer == nullptr) return Status(common::ONNXRUNTIME, common::FAIL, "input count mismatch"); @@ -408,4 +420,12 @@ void GenerateData(std::default_random_engine& generator, TDistribution distribut } } +template Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cpu/generator/random.h b/onnxruntime/core/providers/cpu/generator/random.h index 7f5a761853..dc99328b6c 100644 --- a/onnxruntime/core/providers/cpu/generator/random.h +++ b/onnxruntime/core/providers/cpu/generator/random.h @@ -13,6 +13,15 @@ namespace onnxruntime { +template +Status MultinomialComputeShared(AllocatorPtr& alloc, + const Tensor& X, + const int64_t batch_size, + const int64_t num_classes, + const int64_t num_samples, + std::default_random_engine& generator, + Tensor& Y); + class RandomNormal final : public OpKernel { public: RandomNormal(const OpKernelInfo& info) : OpKernel(info) { diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 9fe1bc59f0..079154079c 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -33,6 +33,7 @@ #include "contrib_ops/cpu/bert/longformer_attention_base.h" #include "contrib_ops/cpu/transformers/beam_search.h" #include "contrib_ops/cpu/transformers/greedy_search.h" +#include "contrib_ops/cpu/transformers/sampling.h" #ifdef ENABLE_ATEN #include "contrib_ops/cpu/aten_ops/aten_op.h" #endif @@ -601,6 +602,15 @@ Status GreedySearch::SetupSubgraphExecutionInfo(const SessionState& session_stat return g_host_cpu.GreedySearch__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } + +void Sampling::Init(const OpKernelInfo& info) { g_host_cpu.Sampling__Init(this, info); } + +Status Sampling::Compute(OpKernelContext* ctx) const { return g_host_cpu.Sampling__Compute(this, ctx); } + +Status Sampling::SetupSubgraphExecutionInfo(const SessionState& session_state, const std::string& attribute_name, + const SessionState& subgraph_session_state) { + return g_host_cpu.Sampling__SetupSubgraphExecutionInfo(this, session_state, attribute_name, subgraph_session_state); } + } // namespace transformers #ifdef ENABLE_ATEN diff --git a/onnxruntime/python/tools/transformers/convert_generation.py b/onnxruntime/python/tools/transformers/convert_generation.py index 0524b9baf2..596c556cfc 100644 --- a/onnxruntime/python/tools/transformers/convert_generation.py +++ b/onnxruntime/python/tools/transformers/convert_generation.py @@ -25,6 +25,9 @@ Example 4: convert MT5 model with external data file like mt5-base-beamsearch.on Example 5: convert gpt2 model with greedy search: python convert_generation.py -m gpt2 --output gpt2_greedy_search.onnx --num_beams 1 --num_return_sequences 1 + +Example 6: convert gpt2 model with sampling: + python convert_generation.py -m gpt2 --output gpt2_sampling.onnx --num_beams 1 --num_return_sequences 1 --top_p 0.6 """ import argparse @@ -72,6 +75,7 @@ logger = logging.getLogger("") class GenerationType(Enum): BEAMSEARCH = "beam_search" GREEDYSEARCH = "greedy_search" + SAMPLING = "sampling" def __str__(self): return self.value @@ -261,6 +265,14 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: ) model_group.set_defaults(custom_attention_mask=False) + model_group.add_argument( + "--presence_mask", + required=False, + action="store_true", + help="Presence mask for custom sampling", + ) + model_group.set_defaults(presence_mask=False) + beam_parameters_group = parser.add_argument_group( "Beam search parameters not stored in the output model, for testing parity and performance" ) @@ -295,6 +307,54 @@ def parse_arguments(argv: Optional[List[str]] = None) -> argparse.Namespace: help="Positive. >1 to penalize and <1 to encourage.", ) + beam_parameters_group.add_argument( + "--temperature", + type=float, + required=False, + default=1.0, + help="The value used to module the next token probabilities.", + ) + + beam_parameters_group.add_argument( + "--top_p", + type=float, + required=False, + default=1.0, + help="Top P for sampling", + ) + + beam_parameters_group.add_argument( + "--filter_value", + type=float, + required=False, + default=-float("Inf"), + help="Filter value for Top P sampling", + ) + + beam_parameters_group.add_argument( + "--min_tokens_to_keep", + type=int, + required=False, + default=1, + help="Minimumber of tokens we keep per batch example in the output.", + ) + + beam_parameters_group.add_argument( + "--presence_penalty", + type=float, + required=False, + default=0.0, + help="presence penalty for custom sampling.", + ) + + beam_parameters_group.add_argument( + "--custom", + type=int, + required=False, + default=0, + help="If 1 customized top P logic is applied", + ) + beam_parameters_group.add_argument( "--vocab_size", type=int, @@ -1201,18 +1261,20 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati args (argparse.Namespace): arguments parsed from command line """ is_gpt2: bool = args.model_type == "gpt2" + is_beamsearch: bool = generation_type == GenerationType.BEAMSEARCH is_greedysearch: bool = generation_type == GenerationType.GREEDYSEARCH + is_sampling: bool = generation_type == GenerationType.SAMPLING past_present_share_buffer: bool = args.past_present_share_buffer and is_greedysearch logger.info(f"**** past_present_share_buffer={past_present_share_buffer}, is_greedysearch={is_greedysearch}") - if is_greedysearch: + if is_greedysearch or is_sampling: if not is_gpt2: - raise NotImplementedError("Currently only gpt2 with greedy search is supported") + raise NotImplementedError("Currently only gpt2 with greedy search/sampling is supported") if args.output_sequences_scores: - raise NotImplementedError("output_sequences_scores currently is not supported in greedy search") + raise NotImplementedError("output_sequences_scores currently is not supported in greedy search/sampling") if args.output_token_scores: - raise NotImplementedError("output_token_scores currently is not supported in greedy search") + raise NotImplementedError("output_token_scores currently is not supported in greedy search/sampling") if is_gpt2: if args.decoder_onnx and os.path.exists(args.decoder_onnx): @@ -1243,7 +1305,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati args.pad_vocab_size and args.precision == Precision.FLOAT16 and is_gpt2 - and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH) + and (is_beamsearch or is_greedysearch or is_sampling) ): logger.info( f"Pad logits MatMul weights for optimal MatMul perf in fp16 on {args.decoder_onnx}. " @@ -1257,11 +1319,7 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati gpt2_init_decoder_generated = False gpt2_init_decoder_onnx_path = None - if ( - args.separate_gpt2_decoder_for_init_run - and is_gpt2 - and (generation_type == GenerationType.BEAMSEARCH or generation_type == GenerationType.GREEDYSEARCH) - ): + if args.separate_gpt2_decoder_for_init_run and is_gpt2 and (is_beamsearch or is_greedysearch or is_sampling): logger.info(f"Creating an initial run GPT2 decoder from {args.decoder_onnx}. ") gpt2_init_decoder_onnx_filename = "gpt2_init_past_{}.onnx".format( @@ -1331,8 +1389,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati else: verify_t5_decoder_subgraph(decoder_model.graph, args.precision) - inputs = ( - [ + inputs = None + if is_beamsearch: + inputs = [ "input_ids", "max_length", "min_length", @@ -1341,14 +1400,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati "length_penalty", "repetition_penalty", ] - if not is_greedysearch - else [ + elif is_greedysearch or is_sampling: + inputs = [ "input_ids", "max_length", "min_length", "repetition_penalty", ] - ) if args.vocab_mask: inputs.append("vocab_mask") @@ -1365,6 +1423,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati else: inputs.append("") + if is_sampling and args.custom and args.presence_mask: + inputs.append("presence_mask") + outputs = ["sequences"] if args.output_sequences_scores: outputs.append("sequences_scores") @@ -1373,40 +1434,60 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati assert args.output_sequences_scores, "--output_token_scores requires --output_sequences_scores" outputs.append("scores") - node = ( - onnx.helper.make_node( + node = None + if is_beamsearch: + node = onnx.helper.make_node( "BeamSearch", inputs=inputs, outputs=outputs, name=f"BeamSearch_{args.model_type}", ) - if not is_greedysearch - else onnx.helper.make_node( + elif is_greedysearch: + node = onnx.helper.make_node( "GreedySearch", inputs=inputs, outputs=outputs, name=f"GreedySearch_{args.model_type}", ) - ) + elif is_sampling: + node = onnx.helper.make_node( + "Sampling", + inputs=inputs, + outputs=outputs, + name=f"Sampling_{args.model_type}", + ) node.domain = "com.microsoft" - attr_to_extend = ( - [ + attr_to_extend = None + if is_beamsearch: + attr_to_extend = [ onnx.helper.make_attribute("eos_token_id", eos_token_id), onnx.helper.make_attribute("pad_token_id", pad_token_id), onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), onnx.helper.make_attribute("early_stopping", 1 if args.early_stopping else 0), onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), ] - if not is_greedysearch - else [ + elif is_greedysearch: + attr_to_extend = [ onnx.helper.make_attribute("eos_token_id", eos_token_id), onnx.helper.make_attribute("pad_token_id", pad_token_id), onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), ] - ) + elif is_sampling: + attr_to_extend = [ + onnx.helper.make_attribute("eos_token_id", eos_token_id), + onnx.helper.make_attribute("pad_token_id", pad_token_id), + onnx.helper.make_attribute("model_type", 0 if args.model_type == "gpt2" else 1), + onnx.helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size), + onnx.helper.make_attribute("temperature", args.temperature), + onnx.helper.make_attribute("top_p", args.top_p), + onnx.helper.make_attribute("filter_value", args.filter_value), + onnx.helper.make_attribute("min_tokens_to_keep", args.min_tokens_to_keep), + onnx.helper.make_attribute("custom", args.custom), + onnx.helper.make_attribute("presence_penalty", args.presence_penalty), + ] # Explicitly pass in the vocab size via an attribute if logits_matmul_weight_padded: @@ -1481,8 +1562,9 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati length_penalty = onnx.helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1]) repetition_penalty = onnx.helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1]) - graph_inputs = ( - [ + graph_inputs = None + if is_beamsearch: + graph_inputs = [ input_ids, max_length, min_length, @@ -1491,14 +1573,13 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati length_penalty, repetition_penalty, ] - if not is_greedysearch - else [ + elif is_greedysearch or is_sampling: + graph_inputs = [ input_ids, max_length, min_length, repetition_penalty, ] - ) if args.vocab_mask: vocab_mask = onnx.helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [vocab_size]) @@ -1516,37 +1597,41 @@ def convert_generation_model(args: argparse.Namespace, generation_type: Generati ) graph_inputs.append(attention_mask) + if args.custom and args.presence_mask: + presence_mask = onnx.helper.make_tensor_value_info( + "presence_mask", TensorProto.INT32, ["batch_size", vocab_size] + ) + graph_inputs.append(presence_mask) + # graph outputs - sequences = ( - onnx.helper.make_tensor_value_info( + sequences = None + if is_beamsearch: + sequences = onnx.helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"], ) - if not is_greedysearch - else onnx.helper.make_tensor_value_info( + elif is_greedysearch or is_sampling: + sequences = onnx.helper.make_tensor_value_info( "sequences", TensorProto.INT32, ["batch_size", "max_length"], ) - ) - - sequences_scores = onnx.helper.make_tensor_value_info( - "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] - ) - - scores = onnx.helper.make_tensor_value_info( - "scores", - TensorProto.FLOAT, - ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], - ) graph_outputs = [sequences] if args.output_sequences_scores: + sequences_scores = onnx.helper.make_tensor_value_info( + "sequences_scores", TensorProto.FLOAT, ["batch_size", "num_return_sequences"] + ) graph_outputs.append(sequences_scores) if args.output_token_scores: + scores = onnx.helper.make_tensor_value_info( + "scores", + TensorProto.FLOAT, + ["max_length - sequence_length", "batch_size", "num_beams", vocab_size], + ) graph_outputs.append(scores) new_graph = onnx.helper.make_graph( @@ -2085,6 +2170,10 @@ def main(argv: Optional[List[str]] = None, sentences: Optional[List[str]] = None is_greedy = args.num_beams == 1 and args.num_return_sequences == 1 if args.model_type == "gpt2" and is_greedy: + if args.top_p > 0.0 and args.top_p < 1.0: + convert_generation_model(args, GenerationType.SAMPLING) + logger.info("The test for gpt2_sampling onnx model is not implemented yet") + return convert_generation_model(args, GenerationType.GREEDYSEARCH) else: convert_generation_model(args)