mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-21 21:52:11 +00:00
initialize cache_indir explicitly in beamsearch with encoder decoder model (#15667)
This commit is contained in:
parent
e1755541cc
commit
d00197aaa7
7 changed files with 38 additions and 7 deletions
|
|
@ -264,6 +264,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
|
||||
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
|
||||
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
|
||||
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
|
||||
|
|
@ -285,6 +286,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
|
||||
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
|
||||
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
|
||||
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
|
||||
process_logits_fp16_func_,
|
||||
init_beam_state_fp16_func_,
|
||||
|
|
@ -312,7 +314,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
|
||||
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
|
||||
nullptr,
|
||||
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
|
||||
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
|
||||
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
|
||||
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
|
||||
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
|
||||
|
|
@ -323,8 +326,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
|
||||
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
|
||||
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
|
||||
nullptr,
|
||||
0};
|
||||
cuda_device_prop_,
|
||||
cuda_device_arch_};
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
|
||||
|
|
@ -333,7 +336,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
|
||||
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
|
||||
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
|
||||
nullptr,
|
||||
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
|
||||
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
|
||||
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
|
||||
process_logits_fp16_func_,
|
||||
init_beam_state_fp16_func_,
|
||||
|
|
@ -344,8 +348,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
|
|||
expand_buffer_int32_func_,
|
||||
expand_buffer_float_func_,
|
||||
expand_buffer_float16_func_,
|
||||
nullptr,
|
||||
0};
|
||||
cuda_device_prop_,
|
||||
cuda_device_arch_};
|
||||
|
||||
ORT_RETURN_IF_ERROR(impl.Initialize());
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ class BeamSearch : public IControlFlowKernel {
|
|||
// device helpers that is same for both GPT and encoder-decoder models.
|
||||
void SetDeviceHelpers(
|
||||
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
|
||||
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
|
||||
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const GenerationDeviceHelper::TopkFunc& topk_func,
|
||||
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
|
||||
|
|
@ -54,6 +55,7 @@ class BeamSearch : public IControlFlowKernel {
|
|||
const GenerationDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
|
||||
const GenerationDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
|
||||
reorder_past_state_func_ = reorder_past_state_func;
|
||||
init_cache_indir_func_ = init_cache_indir_func;
|
||||
add_to_feeds_func_ = add_to_feeds_func;
|
||||
topk_func_ = topk_func;
|
||||
device_copy_func_ = device_copy_func;
|
||||
|
|
@ -91,6 +93,7 @@ class BeamSearch : public IControlFlowKernel {
|
|||
private:
|
||||
// Device specific functions
|
||||
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
|
||||
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
|
||||
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
GenerationDeviceHelper::TopkFunc topk_func_;
|
||||
GenerationDeviceHelper::DeviceCopyFunc<float> device_copy_func_;
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
|
|||
BeamSearchParameters& params,
|
||||
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
|
||||
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
|
||||
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
|
||||
const GenerationDeviceHelper::TopkFunc& topk_func,
|
||||
const GenerationDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
|
||||
const GenerationDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
|
||||
|
|
@ -50,6 +51,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
|
|||
add_to_feeds_func_(add_to_feeds_func),
|
||||
init_beam_state_func_(init_beam_state_func),
|
||||
reorder_past_state_func_(reorder_past_state_func),
|
||||
init_cache_indir_func_(init_cache_indir_func),
|
||||
create_encoder_inputs_func_(create_encoder_inputs_func),
|
||||
update_decoder_feeds_func_(update_decoder_feeds_func),
|
||||
expand_buffer_int32_func_(expand_buffer_int32_func),
|
||||
|
|
@ -80,6 +82,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
|
|||
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
|
||||
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
|
||||
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
|
||||
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
|
||||
GenerationDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
|
||||
GenerationDeviceHelper::UpdateDecoderFeedsFunc<T> update_decoder_feeds_func_;
|
||||
GenerationDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
|
||||
|
|
@ -284,6 +287,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
|
|||
beam_state.staging_for_past_state_reorder,
|
||||
this->ort_stream_));
|
||||
}
|
||||
size_t cache_indir_input_offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast<size_t>(decoder_subgraph_.num_layers) + 2;
|
||||
ORT_RETURN_IF_ERROR(init_cache_indir_func_(*decoder_feeds[cache_indir_input_offset].GetMutable<Tensor>(), this->ort_stream_));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -302,7 +307,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
|
|||
dumper->Print("", decoder_feeds[offset]);
|
||||
dumper->Print("beam_width", offset + 1, true);
|
||||
dumper->Print("", decoder_feeds[offset + 1]);
|
||||
dumper->Print("past_sequence_length", offset + 2, true);
|
||||
dumper->Print("cache_redir", offset + 2, true);
|
||||
dumper->Print("", decoder_feeds[offset + 2]);
|
||||
#endif
|
||||
|
||||
|
|
|
|||
|
|
@ -39,6 +39,10 @@ using ReorderPastStateFunc = std::function<Status(
|
|||
Tensor& past_state_staging,
|
||||
Stream* stream)>; // cublasHandle_t
|
||||
|
||||
using InitCacheIndirFunc = std::function<Status(
|
||||
Tensor& cache_indir,
|
||||
Stream* stream)>;
|
||||
|
||||
using TopkFunc = std::function<Status(
|
||||
const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
|
|
|
|||
|
|
@ -36,6 +36,7 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper;
|
|||
BeamSearch::BeamSearch(const OpKernelInfo& info)
|
||||
: onnxruntime::contrib::transformers::BeamSearch(info) {
|
||||
SetDeviceHelpers(GenerationCudaDeviceHelper::ReorderPastState,
|
||||
GenerationCudaDeviceHelper::InitCacheIndir,
|
||||
GenerationCudaDeviceHelper::AddToFeeds,
|
||||
GenerationCudaDeviceHelper::TopK,
|
||||
GenerationCudaDeviceHelper::DeviceCopy<float>,
|
||||
|
|
|
|||
|
|
@ -102,6 +102,16 @@ Status ReorderPastState(
|
|||
&transpose_output_shape_override);
|
||||
}
|
||||
|
||||
Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {
|
||||
ORT_ENFORCE(stream);
|
||||
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
|
||||
|
||||
// Initialize the cache_indir tensor to all 0s
|
||||
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(cache_indir.MutableDataRaw(), 0, cache_indir.SizeInBytes(), cuda_stream));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
Stream* stream,
|
||||
|
|
|
|||
|
|
@ -28,6 +28,10 @@ Status ReorderPastState(
|
|||
Tensor& past_state_staging,
|
||||
Stream* stream);
|
||||
|
||||
Status InitCacheIndir(
|
||||
Tensor& cache_indir,
|
||||
Stream* stream);
|
||||
|
||||
Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
|
||||
AllocatorPtr allocator,
|
||||
Stream* stream,
|
||||
|
|
|
|||
Loading…
Reference in a new issue