[CUDA] Fix BeamSearchTest.DummyT5WithSequenceInputIds test failure in Windows (#23596)

### Description
BeamSearchTest.DummyT5WithSequenceInputIds failed on Windows because
early stopping was triggered. The cause is that state.early_stopping_ is
interpreted as true in the CUDA kernel at some point, even though printf
still shows its value as false. The root cause is unknown.

Updating the code to use early_stopping as a template parameter seems to
work around the issue.

Other changes: 
* Add some debug code (will not be built into the binary unless
DEBUG_GENERATION is defined) to assist debugging the beam search scorer in
CUDA.
* Enable DummyT5WithSequenceInputIds test in CI. This test was not run
in Windows CUDA CI pipeline previously.

### Motivation and Context

Fix a unit test BeamSearchTest.DummyT5WithSequenceInputIds failure in
Windows.
This commit is contained in:
Tianlei Wu 2025-02-06 13:15:09 -08:00 committed by GitHub
parent d981b153d3
commit 2c2ff4aef9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 181 additions and 90 deletions

View file

@ -8,7 +8,7 @@
#include <gsl/gsl>
#include "core/framework/allocator.h"
#include "core/framework/ort_value.h"
#include "contrib_ops/cpu/utils/debug_macros.h"
#include "contrib_ops/cpu/utils/console_dumper.h"
namespace onnxruntime {

View file

@ -1,11 +1,10 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// cub.cuh includes device/dispatch_radix_sort.cuh which has assignment in conditional expressions
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4706)
#pragma warning(disable : 4706)
#endif
#include <cub/cub.cuh>
#if defined(_MSC_VER)
@ -323,8 +322,8 @@ __device__ void BeamHypotheses::Output(
int top_k,
int max_length,
int pad_token_id,
int32_t* sequences, // buffer of shape (num_return_sequences, max_length)
T* sequences_scores) // buffer of shape (num_return_sequences) or empty
int32_t* sequences, // buffer of shape (num_return_sequences, max_length)
T* sequences_scores) // buffer of shape (num_return_sequences) or empty
{
// Copy the top_k beams into the sequences
for (int index = 0; index < top_k; index++) {
@ -343,6 +342,7 @@ __device__ void BeamHypotheses::Output(
}
}
template <bool early_stopping>
__global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
BeamScorerState& state,
const int32_t* sequences_buffer,
@ -358,24 +358,26 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
// Sequences shape is (batch_size * num_beams, total_sequence_length)
// It contains word ID of whole sequence generated so far.
// It is different from subgraph input_ids, which only need one word when past state is not empty.
int batch = threadIdx.x;
int batch_start = batch * state.num_beams_;
const int batch = threadIdx.x;
const int num_beams = state.num_beams_;
const int batch_start = batch * num_beams;
cuda::BeamHypotheses& beam_hyp = beam_hyps_[batch];
if (!beam_hyp.done_) {
// Next tokens for this sentence.
size_t beam_idx = 0;
size_t top_k = 2 * state.num_beams_;
size_t top_k = 2 * num_beams;
for (size_t j = 0; j < top_k; j++) {
int32_t next_token = next_tokens[batch * top_k + j];
float next_score = next_scores[batch * top_k + j];
int32_t next_index = next_indices[batch * top_k + j];
int batch_beam_idx = batch_start + next_index;
// Add to generated hypotheses if end of sentence.
if ((state.eos_token_id_ >= 0) && (next_token == state.eos_token_id_)) {
bool is_beam_token_worse_than_top_num_beams = (j >= state.num_beams_);
bool is_beam_token_worse_than_top_num_beams = (j >= num_beams);
if (is_beam_token_worse_than_top_num_beams) {
continue;
}
@ -396,21 +398,27 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
}
// Once the beam for next step is full, don't add more tokens to it.
if (beam_idx == state.num_beams_)
if (beam_idx == num_beams)
break;
}
// Check if we are done so that we can save a pad step if all(done)
if (beam_hyp.beams_used_ == state.num_beams_) {
if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) {
beam_hyp.done_ = true;
if (atomicAdd(&state.not_done_count_, -1) == 1)
state_cpu.not_done_count_ = 0; // Update the CPU side
if (beam_hyp.beams_used_ == num_beams) {
if constexpr (!early_stopping) {
float best_sum_logprobs = *std::max_element(next_scores + batch_start, next_scores + batch_start + top_k);
if (beam_hyp.CanImprove(best_sum_logprobs, sequence_length)) {
return;
}
}
beam_hyp.done_ = true;
if (atomicAdd(&(state.not_done_count_), -1) == 1) {
state_cpu.not_done_count_ = 0; // Update the CPU side
}
}
} else {
// Pad the batch.
for (size_t beam_idx = 0; beam_idx < state.num_beams_; beam_idx++) {
for (size_t beam_idx = 0; beam_idx < num_beams; beam_idx++) {
next_beam_scores_[batch_start + beam_idx] = 0.0f;
next_beam_tokens_[batch_start + beam_idx] = state.pad_token_id_;
next_beam_indices_[batch_start + beam_idx] = 0;
@ -418,6 +426,31 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
}
}
#ifdef DEBUG_GENERATION
// Debug-only kernel: prints the fields of `state` from device code.
// Every thread that executes the kernel prints once; the caller in this file launches it as <<<1, 1>>>.
// NOTE(review): `state` is a reference parameter, so it must refer to memory the device can
// dereference - confirm against how the state object is allocated.
__global__ void DumpBeamScorerState(const BeamScorerState& state) {
  constexpr bool kPrintedFromCpu = false;  // this dump is produced by device code
  state.Print(kPrintedFromCpu);
}
// Debug-only helper: dumps both copies of the beam scorer state.
//   state_cpu: copy printed directly by the calling host thread.
//   state:     copy printed by a single-thread kernel launched on `stream`.
//   stream:    CUDA stream used for the dump kernel launch.
// Waits for the device to become idle before returning, which also flushes device-side printf output.
// Launch and synchronization failures are reported on stdout instead of being silently dropped.
void DumpBeamScorerStates(const BeamScorerState& state_cpu, const BeamScorerState& state, cudaStream_t stream) {
  state_cpu.Print(true);

  DumpBeamScorerState<<<1, 1, 0, stream>>>(state);

  // Peek rather than get: the pending error (if any) stays visible to later checks in the caller.
  const cudaError_t launch_status = cudaPeekAtLastError();
  const cudaError_t sync_status = cudaDeviceSynchronize();

  if (launch_status != cudaSuccess) {
    printf("DumpBeamScorerStates: CUDA error pending after kernel launch: %s\n", cudaGetErrorString(launch_status));
  }
  if (sync_status != cudaSuccess) {
    printf("DumpBeamScorerStates: cudaDeviceSynchronize failed: %s\n", cudaGetErrorString(sync_status));
  }
}
// Debug-only kernel: prints one BeamHypotheses entry from device code.
// Every thread that executes the kernel prints once; the caller in this file launches it as <<<1, 1>>>.
// NOTE(review): `beam_hyp` is a reference parameter, so it must refer to memory the device can
// dereference - confirm against how the hypotheses buffer is allocated.
__global__ void DumpBeamSearchScorer(const BeamHypotheses& beam_hyp) {
  const BeamHypotheses* const entry = &beam_hyp;
  entry->Print();
}
// Debug-only helper: prints every entry of `beam_hyps`, one single-thread kernel launch per entry.
//   beam_hyps: span of hypotheses to dump; each element is handed to the dump kernel by reference.
//   stream:    CUDA stream used for the dump kernel launches.
// The device is synchronized after each launch so that the device-side output of entry i is flushed
// before the host prints the header of entry i + 1.
// Launch and synchronization failures are reported on stdout and stop the dump, instead of being
// silently dropped and repeated for every remaining entry.
void DumpBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, cudaStream_t stream) {
  printf("\n BeamHypotheses of size %zu: \n", beam_hyps.size());
  for (size_t i = 0; i < beam_hyps.size(); i++) {
    printf("\n [%zu]:\n", i);
    DumpBeamSearchScorer<<<1, 1, 0, stream>>>(beam_hyps[i]);

    // Peek rather than get: the pending error (if any) stays visible to later checks in the caller.
    const cudaError_t launch_status = cudaPeekAtLastError();
    const cudaError_t sync_status = cudaDeviceSynchronize();

    if (launch_status != cudaSuccess || sync_status != cudaSuccess) {
      const cudaError_t reported = (launch_status != cudaSuccess) ? launch_status : sync_status;
      printf("DumpBeamHypotheses: CUDA error while dumping entry %zu: %s\n", i, cudaGetErrorString(reported));
      break;  // the same error would be reported again for every remaining entry
    }
  }
}
#endif
void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
BeamScorerState& state,
gsl::span<const int32_t> sequences,
@ -431,18 +464,46 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
gsl::span<const int32_t> next_tokens,
gsl::span<const int32_t> next_indices,
cudaStream_t stream) {
BeamSearchScorer_Process<<<1, state_cpu.batch_size_, 0, stream>>>(state_cpu,
state,
sequences.data(),
sequence_length,
beam_hyps.data(),
next_beam_scores.data(),
next_beam_tokens.data(),
next_beam_indices.data(),
hypothesis_buffer.data(),
next_scores.data(),
next_tokens.data(),
next_indices.data());
#ifdef DEBUG_GENERATION
printf("\n Before BeamSearchScorer_Process: \n");
DumpBeamHypotheses(beam_hyps, stream);
DumpBeamScorerStates(state_cpu, state, stream);
#endif
if (state_cpu.early_stopping_) {
BeamSearchScorer_Process<true><<<1, state_cpu.batch_size_, 0, stream>>>(state_cpu,
state,
sequences.data(),
sequence_length,
beam_hyps.data(),
next_beam_scores.data(),
next_beam_tokens.data(),
next_beam_indices.data(),
hypothesis_buffer.data(),
next_scores.data(),
next_tokens.data(),
next_indices.data());
} else {
BeamSearchScorer_Process<false><<<1, state_cpu.batch_size_, 0, stream>>>(state_cpu,
state,
sequences.data(),
sequence_length,
beam_hyps.data(),
next_beam_scores.data(),
next_beam_tokens.data(),
next_beam_indices.data(),
hypothesis_buffer.data(),
next_scores.data(),
next_tokens.data(),
next_indices.data());
}
#ifdef DEBUG_GENERATION
cudaDeviceSynchronize();
printf("\n After BeamSearchScorer_Process: \n");
DumpBeamHypotheses(beam_hyps, stream);
DumpBeamScorerStates(state_cpu, state, stream);
#endif
}
__global__ void BeamSearchScorer_AppendNextTokenToSequences1(BeamScorerState& state,
@ -600,14 +661,14 @@ template <typename T>
void LaunchBeamSearchScoreCopy(gsl::span<const float> final_scores,
gsl::span<T> output_scores,
cudaStream_t stream) {
ORT_ENFORCE(final_scores.size() == output_scores.size());
constexpr unsigned ThreadPerBlock = 256;
unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1))/ ThreadPerBlock);
ORT_ENFORCE(final_scores.size() == output_scores.size());
constexpr unsigned ThreadPerBlock = 256;
unsigned num_blocks = (unsigned)((final_scores.size() + (ThreadPerBlock - 1)) / ThreadPerBlock);
typedef typename ToCudaType<float>::MappedType CudaT;
typedef typename ToCudaType<float>::MappedType CudaT;
FloatConvertAndCopyKernel<<<num_blocks, ThreadPerBlock, 0, stream>>>(
final_scores.data(), (CudaT*)output_scores.data(), final_scores.size());
FloatConvertAndCopyKernel<<<num_blocks, ThreadPerBlock, 0, stream>>>(
final_scores.data(), (CudaT*)output_scores.data(), final_scores.size());
}
template void LaunchBeamSearchScoreCopy(gsl::span<const float> final_scores,
@ -1444,15 +1505,14 @@ void ReorderPastStatesKernelLauncher(void* out_buffer,
template <typename T>
__global__ void CopyCrossQKSingleDecodeStepKernel(
T* target, // shape [batchxbeam, layer_head_pair_count, max_length, frame]
T* target, // shape [batchxbeam, layer_head_pair_count, max_length, frame]
T** qk_layer_pointers,
int token_index,
int num_layers,
int num_heads,
const int* cross_qk_layer_head_pairs,
int frames,
int max_length
) {
int max_length) {
const int pair = blockIdx.x;
const int layer_head_pair_count = gridDim.x;
const int bbm = blockIdx.y;
@ -1464,7 +1524,7 @@ __global__ void CopyCrossQKSingleDecodeStepKernel(
T* src = qk_layer_pointers[layer] + ((int64_t)bbm * num_heads + head) * frames;
for (int tid = threadIdx.x; tid < frames; tid += blockDim.x) {
target[tid] = src[tid]; // use vectorized read write in future if needed
target[tid] = src[tid]; // use vectorized read write in future if needed
}
}
@ -1479,8 +1539,7 @@ void LaunchCopyCrossQKSingleDecodeStep(
int cross_qk_layer_head_pair_count,
const int* cross_qk_layer_head_pairs,
int frames,
int max_length
) {
int max_length) {
dim3 block(512);
dim3 grid(cross_qk_layer_head_pair_count, batchxbeam);
typedef typename ToCudaType<float>::MappedType CudaT;
@ -1493,11 +1552,9 @@ void LaunchCopyCrossQKSingleDecodeStep(
num_heads,
cross_qk_layer_head_pairs,
frames,
max_length
);
max_length);
}
template <typename T>
__global__ void CopyDecoderCrossQKAllStepsKernel(
int context_decoding_len,
@ -1505,11 +1562,10 @@ __global__ void CopyDecoderCrossQKAllStepsKernel(
int num_return_sequences,
int max_length,
int frames_of_k,
const T* cross_qk_buffer_data, // [batch, num_beams, layer_head_pair_count, max_length, frames]
T* cross_qk_output, // [batch, num_return_sequences, layer_head_pair_count, total_decoding_length, frames]
const int* cache_indir_data, // [batch, num_beams, max_length]
const int32_t* beam_indices
) {
const T* cross_qk_buffer_data, // [batch, num_beams, layer_head_pair_count, max_length, frames]
T* cross_qk_output, // [batch, num_return_sequences, layer_head_pair_count, total_decoding_length, frames]
const int* cache_indir_data, // [batch, num_beams, max_length]
const int32_t* beam_indices) {
const int pair = blockIdx.y;
const int layer_head_pair_count = gridDim.y;
const int total_decoding_length = gridDim.x;
@ -1522,15 +1578,15 @@ __global__ void CopyDecoderCrossQKAllStepsKernel(
const int src_beam = beam_indices[batch * num_beams + ret_seq_id] % num_beams;
const int64_t offset_in_cache = ((int64_t)batch * num_beams + src_beam) * max_length + token_decoding_index + context_decoding_len;
int bm_mapped = ((num_beams <= 1) ? 0: ((token_decoding_index == total_decoding_length - 1) ? ret_seq_id : cache_indir_data[offset_in_cache]));
int bm_mapped = ((num_beams <= 1) ? 0 : ((token_decoding_index == total_decoding_length - 1) ? ret_seq_id : cache_indir_data[offset_in_cache]));
int bi_src = batch * num_beams + bm_mapped;
T* target = cross_qk_output +
(((int64_t)br * layer_head_pair_count + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k;
T* target = cross_qk_output +
(((int64_t)br * layer_head_pair_count + (int64_t)pair) * total_decoding_length + token_decoding_index) * frames_of_k;
const T* src = cross_qk_buffer_data +
((int64_t)bi_src * layer_head_pair_count * max_length + (int64_t)pair * max_length + token_decoding_index) * frames_of_k;
((int64_t)bi_src * layer_head_pair_count * max_length + (int64_t)pair * max_length + token_decoding_index) * frames_of_k;
for (int tid = threadIdx.x; tid < frames_of_k; tid += blockDim.x) {
target[tid] = src[tid]; // use vectorized read write in future if needed
target[tid] = src[tid]; // use vectorized read write in future if needed
}
}
@ -1548,8 +1604,7 @@ void LaunchFinalizeCrossQK(
float* cross_qk_output,
int num_return_sequences,
const int* cache_indir_data,
const int32_t* beam_indices
) {
const int32_t* beam_indices) {
int64_t br = (int64_t)batch_size * num_return_sequences;
ORT_ENFORCE(br < 65536L && cross_qk_layer_head_pair_count < 65536);
const int total_decoding_length = iteration_number - 1;
@ -1558,15 +1613,15 @@ void LaunchFinalizeCrossQK(
typedef typename ToCudaType<float>::MappedType CudaT;
CopyDecoderCrossQKAllStepsKernel<<<grid, block, 0, stream>>>(
context_decoding_len,
num_beams,
num_return_sequences,
max_length,
frames_of_k,
(const CudaT*)cross_qk_buffer_data,
(CudaT*)cross_qk_output,
cache_indir_data,
beam_indices);
context_decoding_len,
num_beams,
num_return_sequences,
max_length,
frames_of_k,
(const CudaT*)cross_qk_buffer_data,
(CudaT*)cross_qk_output,
cache_indir_data,
beam_indices);
}
template <int ElementsPerThreads>
@ -1575,12 +1630,11 @@ __global__ void ForceDecodingIdsKernel(
const int vocab_size,
const int32_t* force_ids,
int id_len,
int step
) {
int step) {
const int num_beams = gridDim.y;
const int beam = blockIdx.y;
const int batch = blockIdx.z;
beam_scores += (((int64_t)batch * num_beams + beam)* vocab_size); // move to (batch, beam)
beam_scores += (((int64_t)batch * num_beams + beam) * vocab_size); // move to (batch, beam)
const int32_t id_wanted = force_ids[((int64_t)batch * id_len) + step];
if (id_wanted < 0 || id_wanted >= vocab_size) return;
@ -1588,7 +1642,7 @@ __global__ void ForceDecodingIdsKernel(
const int32_t block_start_id = blockIdx.x * elements_per_block;
int32_t token_id = block_start_id + (int)threadIdx.x;
#pragma unroll
#pragma unroll
for (int elem = 0; elem < ElementsPerThreads; elem++) {
if (token_id < vocab_size) {
beam_scores[token_id] = ((token_id == id_wanted) ? 0.0f : cub::FpLimits<float>::Lowest());
@ -1597,7 +1651,6 @@ __global__ void ForceDecodingIdsKernel(
}
}
void LaunchForceDecodingIds(
float* beam_scores,
const int batch_size,
@ -1606,15 +1659,13 @@ void LaunchForceDecodingIds(
const int32_t* force_ids,
int id_len,
int step,
cudaStream_t stream
) {
cudaStream_t stream) {
dim3 blocks(512);
constexpr int ElementsPerThreads = 4;
unsigned gridx = static_cast<unsigned>((vocab_size + 512 * ElementsPerThreads - 1) / (512 * ElementsPerThreads));
dim3 grids(gridx, num_beams, batch_size);
ForceDecodingIdsKernel<ElementsPerThreads><<<grids, blocks, 0, stream>>>(
beam_scores, vocab_size, force_ids, id_len, step
);
beam_scores, vocab_size, force_ids, id_len, step);
}
template <typename T>
@ -1624,8 +1675,7 @@ __global__ void SaveNoSpeechProbsKernel(
const int batch_size,
const int num_beams,
const int vocab_size,
const int no_speech_token_id
) {
const int no_speech_token_id) {
int b = blockIdx.x * blockDim.x + threadIdx.x;
if (b < batch_size) {
int64_t src_offset = b * num_beams * vocab_size + no_speech_token_id;
@ -1635,20 +1685,19 @@ __global__ void SaveNoSpeechProbsKernel(
template <typename T>
void LaunchSaveNoSpeechProbs(
T* result_no_speech_probs, /* [batch]*/
const float* probs, /* [batch, num_beams, vocab_size]*/
T* result_no_speech_probs, /* [batch]*/
const float* probs, /* [batch, num_beams, vocab_size]*/
const int batch_size,
const int num_beams,
const int vocab_size,
const int no_speech_token_id,
cudaStream_t stream
) {
cudaStream_t stream) {
int tpb = 256;
int bpg = (batch_size + 255) / 256;
typedef typename ToCudaType<T>::MappedType CudaT;
SaveNoSpeechProbsKernel<CudaT><<<bpg, tpb, 0, stream>>>(
(CudaT*)result_no_speech_probs, probs, batch_size, num_beams, vocab_size, no_speech_token_id);
(CudaT*)result_no_speech_probs, probs, batch_size, num_beams, vocab_size, no_speech_token_id);
}
template void LaunchSaveNoSpeechProbs<float>(
@ -1658,8 +1707,7 @@ template void LaunchSaveNoSpeechProbs<float>(
const int num_beams,
const int vocab_size,
const int no_speech_token_id,
cudaStream_t stream
);
cudaStream_t stream);
template void LaunchSaveNoSpeechProbs<MLFloat16>(
MLFloat16* result_no_speech_probs,
@ -1668,8 +1716,7 @@ template void LaunchSaveNoSpeechProbs<MLFloat16>(
const int num_beams,
const int vocab_size,
const int no_speech_token_id,
cudaStream_t stream
);
cudaStream_t stream);
} // namespace cuda
} // namespace contrib

View file

@ -6,6 +6,8 @@
#include <stdint.h>
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <cstdio>
#include "contrib_ops/cpu/transformers/generation_shared.h"
namespace onnxruntime {
namespace contrib {
@ -49,6 +51,21 @@ struct HypothesisScore {
const int32_t* hypothesis;
int hypothesis_length;
float score;
#ifdef DEBUG_GENERATION
// Debug-only: prints the hypothesis length and score, then the token ids on one line.
// Prints "(empty)" instead of token ids when the length is not positive or the pointer is null.
__device__ void Print() const {
  printf("HypothesisScore (hypothesis_length=%d, score=%f) \n", hypothesis_length, score);
  printf(" hypothesis:");

  const bool has_tokens = (hypothesis_length > 0) && (hypothesis != NULL);
  if (!has_tokens) {
    printf("(empty)");
  } else {
    int position = 0;
    while (position < hypothesis_length) {
      printf("%d ", hypothesis[position]);
      ++position;
    }
  }

  printf("\n");
}
#endif
};
struct BeamHypotheses {
@ -71,6 +88,22 @@ struct BeamHypotheses {
int pad_token_id, // pad token
int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length)
T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
#ifdef DEBUG_GENERATION
// Debug-only: prints the scalar members of this object, then calls Print() on each of the
// first `beams_used_` entries of `beams_`, labelled starting from "Beam 1".
__device__ void Print() const {
  const char* const done_text = done_ ? "true" : "false";

  printf("BeamHypotheses:\n");
  printf(" beams_count: %d\n", beams_count_);
  printf(" beams_used: %d\n", beams_used_);
  printf(" length_penalty: %f\n", length_penalty_);
  printf(" done: %s\n", done_text);
  printf(" beams:\n");

  int beam_index = 0;
  while (beam_index < beams_used_) {
    printf(" Beam %d:\n", beam_index + 1);  // labels are 1-based
    beams_[beam_index].Print();
    ++beam_index;
  }
}
#endif
};
struct BeamScorerState {
@ -81,9 +114,23 @@ struct BeamScorerState {
int pad_token_id_;
int eos_token_id_;
bool early_stopping_;
int not_done_count_; // When zero, every batch entry is done (starts at batch_size_)
int not_done_count_; // When zero, every batch entry is done (starts at batch_size_)
int hypothesis_buffer_used_; // Offset of available buffer, or length of used buffer.
#ifdef DEBUG_GENERATION
// Debug-only: prints every member of the scorer state, one per line.
// `is_cpu` only affects the header line (printed as cpu=1 or cpu=0) so that the host-side and
// device-side dumps can be told apart in the output.
__host__ __device__ void Print(bool is_cpu) const {
  const int cpu_flag = static_cast<int>(is_cpu);
  const char* const early_stopping_text = early_stopping_ ? "true" : "false";

  printf("BeamScorerState (cpu=%d) Dump:\n", cpu_flag);
  printf(" batch_size_: %d\n", batch_size_);
  printf(" num_beams_: %d\n", num_beams_);
  printf(" max_length_: %d\n", max_length_);
  printf(" num_return_sequences_: %d\n", num_return_sequences_);
  printf(" pad_token_id_: %d\n", pad_token_id_);
  printf(" eos_token_id_: %d\n", eos_token_id_);
  printf(" early_stopping_: %s\n", early_stopping_text);
  printf(" not_done_count_: %d\n", not_done_count_);
  printf(" hypothesis_buffer_used_: %d\n", hypothesis_buffer_used_);
}
#endif
};
void LaunchInitializeBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, float length_penalty, gsl::span<HypothesisScore> beams, int num_beams, cudaStream_t stream);

View file

@ -423,9 +423,6 @@ TEST(BeamSearchTest, DummyT5WithOuterScopeInitializers) {
}
TEST(BeamSearchTest, DummyT5WithSequenceInputIds) {
#if defined(USE_CUDA) && defined(USE_DML)
SKIP_CUDA_TEST_WITH_DML;
#endif
// dummy_t5_with_sequence_input_ids.onnx model generated using following command:
// python onnxruntime/test/testdata/dummy_t5_generator.py --output-path dummy_t5_with_sequence_input_ids.onnx --sequence-as-input
ModelTester tester(CurrentTestName(), ORT_TSTR("testdata/dummy_t5_with_sequence_input_ids.onnx"));