diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index bf98394d1c..600eb50648 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -181,7 +181,7 @@ void LaunchLogitsProcessKernel( float repetition_penalty, int no_repeat_ngram_size, cudaStream_t stream) { - int total_elements = batch_size * num_beams * vocab_size; + int total_elements = batch_size * num_beams * padded_vocab_size; constexpr int blockSize = 256; const int gridSize = (total_elements + blockSize - 1) / blockSize; LogitsProcessKernel<<>>(