mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-01 03:45:06 +00:00
fix prefast warning for GenerationCudaDeviceHelper::ProcessLogits (#15163)
This commit is contained in:
parent
c08d6b42e8
commit
c68044cc4b
1 changed file with 4 additions and 3 deletions
|
|
@@ -455,11 +455,12 @@ Status ProcessLogits(const OrtValue& logits, //
|
|||
|
||||
if (num_beams <= 32) {
|
||||
constexpr size_t max_parts_of_vocab = 128;
|
||||
size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
|
||||
float* topk_tmp_buffer = beam_state->topk_buffer.data();
|
||||
float* topk_scores_1st_stage = topk_tmp_buffer;
|
||||
int32_t* topk_tokens_1st_stage = reinterpret_cast<int32_t*>(topk_scores_1st_stage + batch_beam_size * max_parts_of_vocab * 2 * num_beams);
|
||||
float* topk_scores_2nd_stage = reinterpret_cast<float*>(topk_tokens_1st_stage + batch_beam_size * max_parts_of_vocab * 2 * num_beams);
|
||||
int32_t* topk_tokens_2nd_stage = reinterpret_cast<int32_t*>(topk_scores_2nd_stage + batch_beam_size * 2 * num_beams);
|
||||
int32_t* topk_tokens_1st_stage = reinterpret_cast<int32_t*>(topk_scores_1st_stage + candidate_count * max_parts_of_vocab);
|
||||
float* topk_scores_2nd_stage = reinterpret_cast<float*>(topk_tokens_1st_stage + candidate_count * max_parts_of_vocab);
|
||||
int32_t* topk_tokens_2nd_stage = reinterpret_cast<int32_t*>(topk_scores_2nd_stage + candidate_count);
|
||||
|
||||
cuda::BeamSearchTopK(next_token_scores.data(),
|
||||
batch_size,
|
||||
|
|
|
|||
Loading…
Reference in a new issue