diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc index e480c10bf5..80cae088d1 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.cc @@ -264,6 +264,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters, add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now + init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits, init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState, @@ -285,6 +286,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters, add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now + init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, process_logits_fp16_func_, init_beam_state_fp16_func_, @@ -312,7 +314,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_, *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters, add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, - nullptr, + reorder_past_state_func_ ? 
reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now + init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK, process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits, init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState, @@ -323,8 +326,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer, expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer, - nullptr, - 0}; + cuda_device_prop_, + cuda_device_arch_}; ORT_RETURN_IF_ERROR(impl.Initialize()); return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_); @@ -333,7 +336,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { *ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_, *t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters, add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds, - nullptr, + reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now + init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now topk_func_ ? 
topk_func_ : GenerationCpuDeviceHelper::TopK, process_logits_fp16_func_, init_beam_state_fp16_func_, @@ -344,8 +348,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const { expand_buffer_int32_func_, expand_buffer_float_func_, expand_buffer_float16_func_, - nullptr, - 0}; + cuda_device_prop_, + cuda_device_arch_}; ORT_RETURN_IF_ERROR(impl.Initialize()); diff --git a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h index 37b94bd8f1..63b0418a12 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search.h @@ -45,6 +45,7 @@ class BeamSearch : public IControlFlowKernel { // device helpers that is same for both GPT and encoder-decoder models. void SetDeviceHelpers( const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, + const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func, const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, const GenerationDeviceHelper::TopkFunc& topk_func, const GenerationDeviceHelper::DeviceCopyFunc& device_copy_func, @@ -54,6 +55,7 @@ class BeamSearch : public IControlFlowKernel { const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_func, const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_fp16_func) { reorder_past_state_func_ = reorder_past_state_func; + init_cache_indir_func_ = init_cache_indir_func; add_to_feeds_func_ = add_to_feeds_func; topk_func_ = topk_func; device_copy_func_ = device_copy_func; @@ -91,6 +93,7 @@ class BeamSearch : public IControlFlowKernel { private: // Device specific functions GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; + GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::TopkFunc topk_func_; GenerationDeviceHelper::DeviceCopyFunc device_copy_func_; diff --git 
a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h index d4021de95e..51e8ae7b13 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h +++ b/onnxruntime/contrib_ops/cpu/transformers/beam_search_impl_t5.h @@ -29,6 +29,7 @@ class BeamSearchT5 : public BeamSearchBase { BeamSearchParameters& params, const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func, const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func, + const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func, const GenerationDeviceHelper::TopkFunc& topk_func, const GenerationDeviceHelper::ProcessLogitsFunc& process_logits_func, const GenerationDeviceHelper::InitBeamStateFunc& init_beam_state_func, @@ -50,6 +51,7 @@ class BeamSearchT5 : public BeamSearchBase { add_to_feeds_func_(add_to_feeds_func), init_beam_state_func_(init_beam_state_func), reorder_past_state_func_(reorder_past_state_func), + init_cache_indir_func_(init_cache_indir_func), create_encoder_inputs_func_(create_encoder_inputs_func), update_decoder_feeds_func_(update_decoder_feeds_func), expand_buffer_int32_func_(expand_buffer_int32_func), @@ -80,6 +82,7 @@ class BeamSearchT5 : public BeamSearchBase { GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_; GenerationDeviceHelper::InitBeamStateFunc init_beam_state_func_; GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_; + GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_; GenerationDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_; GenerationDeviceHelper::UpdateDecoderFeedsFunc update_decoder_feeds_func_; GenerationDeviceHelper::ExpandBufferFunc expand_buffer_int32_func_; @@ -284,6 +287,8 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches beam_state.staging_for_past_state_reorder, this->ort_stream_)); } + size_t cache_indir_input_offset = 
static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast<size_t>(decoder_subgraph_.num_layers) + 2; + ORT_RETURN_IF_ERROR(init_cache_indir_func_(*decoder_feeds[cache_indir_input_offset].GetMutable<Tensor>(), this->ort_stream_)); } } @@ -302,7 +307,7 @@ Status BeamSearchT5::Execute(const FeedsFetchesManager& encoder_feeds_fetches dumper->Print("", decoder_feeds[offset]); dumper->Print("beam_width", offset + 1, true); dumper->Print("", decoder_feeds[offset + 1]); - dumper->Print("past_sequence_length", offset + 2, true); + dumper->Print("cache_redir", offset + 2, true); dumper->Print("", decoder_feeds[offset + 2]); #endif diff --git a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h index 9a0aa241cf..782795eaad 100644 --- a/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cpu/transformers/generation_device_helper.h @@ -39,6 +39,10 @@ using ReorderPastStateFunc = std::function; // cublasHandle_t +using InitCacheIndirFunc = std::function<Status( + Tensor& cache_indir, + Stream* stream)>; + using TopkFunc = std::function, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index f791abb2ce..b51c022c3e 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -102,6 +102,16 @@ Status ReorderPastState( &transpose_output_shape_override); } +Status InitCacheIndir(Tensor& cache_indir, Stream* stream) { + ORT_ENFORCE(stream); + cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle()); + + // Initialize the cache_indir tensor to all 0s + CUDA_RETURN_IF_ERROR(cudaMemsetAsync(cache_indir.MutableDataRaw(), 0, cache_indir.SizeInBytes(), cuda_stream)); + + return Status::OK(); +} + Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted, 
AllocatorPtr allocator, Stream* stream, diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h index ba6a2884f1..5ed956f9a2 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.h @@ -28,6 +28,10 @@ Status ReorderPastState( Tensor& past_state_staging, Stream* stream); +Status InitCacheIndir( + Tensor& cache_indir, + Stream* stream); + Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted, AllocatorPtr allocator, Stream* stream,