fix greedysearch token out of range bug (#14242)

Bug: the last sentence generates token out of vocabulary size.
Cause: total element should be computed with padded vocabulary size.
This commit is contained in:
Yufeng Li 2023-01-12 09:06:05 -08:00 committed by GitHub
parent 5c16e0befb
commit 8f7eb75c3e
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -181,7 +181,7 @@ void LaunchLogitsProcessKernel(
float repetition_penalty,
int no_repeat_ngram_size,
cudaStream_t stream) {
int total_elements = batch_size * num_beams * vocab_size;
int total_elements = batch_size * num_beams * padded_vocab_size;
constexpr int blockSize = 256;
const int gridSize = (total_elements + blockSize - 1) / blockSize;
LogitsProcessKernel<T><<<gridSize, blockSize, 0, stream>>>(