fix prefast warning for GenerationCudaDeviceHelper::ProcessLogits (#15163)

This commit is contained in:
Yufeng Li 2023-03-31 08:50:53 -07:00 committed by GitHub
parent c08d6b42e8
commit c68044cc4b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -455,11 +455,12 @@ Status ProcessLogits(const OrtValue& logits, //
if (num_beams <= 32) {
constexpr size_t max_parts_of_vocab = 128;
size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
float* topk_tmp_buffer = beam_state->topk_buffer.data();
float* topk_scores_1st_stage = topk_tmp_buffer;
int32_t* topk_tokens_1st_stage = reinterpret_cast<int32_t*>(topk_scores_1st_stage + batch_beam_size * max_parts_of_vocab * 2 * num_beams);
float* topk_scores_2nd_stage = reinterpret_cast<float*>(topk_tokens_1st_stage + batch_beam_size * max_parts_of_vocab * 2 * num_beams);
int32_t* topk_tokens_2nd_stage = reinterpret_cast<int32_t*>(topk_scores_2nd_stage + batch_beam_size * 2 * num_beams);
int32_t* topk_tokens_1st_stage = reinterpret_cast<int32_t*>(topk_scores_1st_stage + candidate_count * max_parts_of_vocab);
float* topk_scores_2nd_stage = reinterpret_cast<float*>(topk_tokens_1st_stage + candidate_count * max_parts_of_vocab);
int32_t* topk_tokens_2nd_stage = reinterpret_cast<int32_t*>(topk_scores_2nd_stage + candidate_count);
cuda::BeamSearchTopK(next_token_scores.data(),
batch_size,