From c68044cc4b29b6231967a7ca00b73e3e1b4ea747 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Fri, 31 Mar 2023 08:50:53 -0700 Subject: [PATCH] fix prefast warning for GenerationCudaDeviceHelper::ProcessLogits (#15163) --- .../cuda/transformers/generation_device_helper.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index 67f370665f..cc9068db28 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -455,11 +455,12 @@ Status ProcessLogits(const OrtValue& logits, // if (num_beams <= 32) { constexpr size_t max_parts_of_vocab = 128; + size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams; float* topk_tmp_buffer = beam_state->topk_buffer.data(); float* topk_scores_1st_stage = topk_tmp_buffer; - int32_t* topk_tokens_1st_stage = reinterpret_cast<int32_t*>(topk_scores_1st_stage + batch_beam_size * max_parts_of_vocab * 2 * num_beams); - float* topk_scores_2nd_stage = reinterpret_cast<float*>(topk_tokens_1st_stage + batch_beam_size * max_parts_of_vocab * 2 * num_beams); - int32_t* topk_tokens_2nd_stage = reinterpret_cast<int32_t*>(topk_scores_2nd_stage + batch_beam_size * 2 * num_beams); + int32_t* topk_tokens_1st_stage = reinterpret_cast<int32_t*>(topk_scores_1st_stage + candidate_count * max_parts_of_vocab); + float* topk_scores_2nd_stage = reinterpret_cast<float*>(topk_tokens_1st_stage + candidate_count * max_parts_of_vocab); + int32_t* topk_tokens_2nd_stage = reinterpret_cast<int32_t*>(topk_scores_2nd_stage + candidate_count); cuda::BeamSearchTopK(next_token_scores.data(), batch_size,