initialize cache_indir explicitly in beamsearch with encoder decoder model (#15667)

This commit is contained in:
Ye Wang 2023-04-25 11:05:21 -07:00 committed by GitHub
parent e1755541cc
commit d00197aaa7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
7 changed files with 38 additions and 7 deletions

View file

@ -264,6 +264,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
@ -285,6 +286,7 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
@ -312,7 +314,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
nullptr,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_func_ ? process_logits_func_ : GenerationCpuDeviceHelper::ProcessLogits<float>,
init_beam_state_func_ ? init_beam_state_func_ : GenerationCpuDeviceHelper::InitBeamState<float>,
@ -323,8 +326,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_int32_func_ ? expand_buffer_int32_func_ : GenerationCpuDeviceHelper::ExpandBuffer<int32_t>,
expand_buffer_float_func_ ? expand_buffer_float_func_ : GenerationCpuDeviceHelper::ExpandBuffer<float>,
expand_buffer_float16_func_ ? expand_buffer_float16_func_ : GenerationCpuDeviceHelper::ExpandBuffer<MLFloat16>,
nullptr,
0};
cuda_device_prop_,
cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());
return impl.Execute(*encoder_feeds_fetches_manager_, *decoder_feeds_fetches_manager_);
@ -333,7 +336,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
*ctx_internal, *encoder_session_state, *decoder_session_state, *t5_encoder_subgraph_,
*t5_decoder_subgraph_, thread_pool, ctx->GetComputeStream(), dumper_, parameters,
add_to_feeds_func_ ? add_to_feeds_func_ : GenerationCpuDeviceHelper::AddToFeeds,
nullptr,
reorder_past_state_func_ ? reorder_past_state_func_ : nullptr, // Only CUDA implementation needs the reorder helper for now
init_cache_indir_func_ ? init_cache_indir_func_ : nullptr, // Only CUDA implementation needs the init cache_indir for now
topk_func_ ? topk_func_ : GenerationCpuDeviceHelper::TopK,
process_logits_fp16_func_,
init_beam_state_fp16_func_,
@ -344,8 +348,8 @@ Status BeamSearch::Compute(OpKernelContext* ctx) const {
expand_buffer_int32_func_,
expand_buffer_float_func_,
expand_buffer_float16_func_,
nullptr,
0};
cuda_device_prop_,
cuda_device_arch_};
ORT_RETURN_IF_ERROR(impl.Initialize());

View file

@ -45,6 +45,7 @@ class BeamSearch : public IControlFlowKernel {
// device helpers that is same for both GPT and encoder-decoder models.
void SetDeviceHelpers(
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::DeviceCopyFunc<float>& device_copy_func,
@ -54,6 +55,7 @@ class BeamSearch : public IControlFlowKernel {
const GenerationDeviceHelper::InitBeamStateFunc<float>& init_beam_state_func,
const GenerationDeviceHelper::InitBeamStateFunc<MLFloat16>& init_beam_state_fp16_func) {
reorder_past_state_func_ = reorder_past_state_func;
init_cache_indir_func_ = init_cache_indir_func;
add_to_feeds_func_ = add_to_feeds_func;
topk_func_ = topk_func;
device_copy_func_ = device_copy_func;
@ -91,6 +93,7 @@ class BeamSearch : public IControlFlowKernel {
private:
// Device specific functions
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::TopkFunc topk_func_;
GenerationDeviceHelper::DeviceCopyFunc<float> device_copy_func_;

View file

@ -29,6 +29,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
BeamSearchParameters& params,
const GenerationDeviceHelper::AddToFeedsFunc& add_to_feeds_func,
const GenerationDeviceHelper::ReorderPastStateFunc& reorder_past_state_func,
const GenerationDeviceHelper::InitCacheIndirFunc& init_cache_indir_func,
const GenerationDeviceHelper::TopkFunc& topk_func,
const GenerationDeviceHelper::ProcessLogitsFunc<T>& process_logits_func,
const GenerationDeviceHelper::InitBeamStateFunc<T>& init_beam_state_func,
@ -50,6 +51,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
add_to_feeds_func_(add_to_feeds_func),
init_beam_state_func_(init_beam_state_func),
reorder_past_state_func_(reorder_past_state_func),
init_cache_indir_func_(init_cache_indir_func),
create_encoder_inputs_func_(create_encoder_inputs_func),
update_decoder_feeds_func_(update_decoder_feeds_func),
expand_buffer_int32_func_(expand_buffer_int32_func),
@ -80,6 +82,7 @@ class BeamSearchT5 : public BeamSearchBase<T> {
GenerationDeviceHelper::AddToFeedsFunc add_to_feeds_func_;
GenerationDeviceHelper::InitBeamStateFunc<T> init_beam_state_func_;
GenerationDeviceHelper::ReorderPastStateFunc reorder_past_state_func_;
GenerationDeviceHelper::InitCacheIndirFunc init_cache_indir_func_;
GenerationDeviceHelper::CreateEncoderInputsFunc create_encoder_inputs_func_;
GenerationDeviceHelper::UpdateDecoderFeedsFunc<T> update_decoder_feeds_func_;
GenerationDeviceHelper::ExpandBufferFunc<int32_t> expand_buffer_int32_func_;
@ -284,6 +287,8 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
beam_state.staging_for_past_state_reorder,
this->ort_stream_));
}
size_t cache_indir_input_offset = static_cast<size_t>(decoder_subgraph_.GetFirstPastInputIndex()) + 4 * static_cast<size_t>(decoder_subgraph_.num_layers) + 2;
ORT_RETURN_IF_ERROR(init_cache_indir_func_(*decoder_feeds[cache_indir_input_offset].GetMutable<Tensor>(), this->ort_stream_));
}
}
@ -302,7 +307,7 @@ Status BeamSearchT5<T>::Execute(const FeedsFetchesManager& encoder_feeds_fetches
dumper->Print("", decoder_feeds[offset]);
dumper->Print("beam_width", offset + 1, true);
dumper->Print("", decoder_feeds[offset + 1]);
dumper->Print("past_sequence_length", offset + 2, true);
dumper->Print("cache_redir", offset + 2, true);
dumper->Print("", decoder_feeds[offset + 2]);
#endif

View file

@ -39,6 +39,10 @@ using ReorderPastStateFunc = std::function<Status(
Tensor& past_state_staging,
Stream* stream)>; // cublasHandle_t
using InitCacheIndirFunc = std::function<Status(
Tensor& cache_indir,
Stream* stream)>;
using TopkFunc = std::function<Status(
const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,

View file

@ -36,6 +36,7 @@ transformers::CudaTensorConsoleDumper g_cuda_dumper;
BeamSearch::BeamSearch(const OpKernelInfo& info)
: onnxruntime::contrib::transformers::BeamSearch(info) {
SetDeviceHelpers(GenerationCudaDeviceHelper::ReorderPastState,
GenerationCudaDeviceHelper::InitCacheIndir,
GenerationCudaDeviceHelper::AddToFeeds,
GenerationCudaDeviceHelper::TopK,
GenerationCudaDeviceHelper::DeviceCopy<float>,

View file

@ -102,6 +102,16 @@ Status ReorderPastState(
&transpose_output_shape_override);
}
Status InitCacheIndir(Tensor& cache_indir, Stream* stream) {
ORT_ENFORCE(stream);
cudaStream_t cuda_stream = reinterpret_cast<cudaStream_t>(stream->GetHandle());
// Initialize the cache_indir tensor to all 0s
CUDA_RETURN_IF_ERROR(cudaMemsetAsync(cache_indir.MutableDataRaw(), 0, cache_indir.SizeInBytes(), cuda_stream));
return Status::OK();
}
Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
Stream* stream,

View file

@ -28,6 +28,10 @@ Status ReorderPastState(
Tensor& past_state_staging,
Stream* stream);
Status InitCacheIndir(
Tensor& cache_indir,
Stream* stream);
Status TopK(const Tensor* input, const int axis, const unsigned k, bool largest, bool sorted,
AllocatorPtr allocator,
Stream* stream,