This commit is contained in:
Tianlei Wu 2025-02-05 02:03:30 -08:00
parent 4bc025537a
commit 8d6c1556a5
3 changed files with 81 additions and 48 deletions

View file

@ -267,37 +267,6 @@ __global__ void InitializeBeamHypotheses(BeamHypotheses* beam_hyps, int beam_hyp
beam_hyp.done_ = false;
}
// Debug helper: writes the length, score and entries of one HypothesisScore with device-side printf.
// "(empty)" is printed instead of the entry list when the length is not positive or the pointer is null.
__device__ void dump_hypothesis_score(const struct HypothesisScore* hs) {
  const int entry_count = hs->hypothesis_length;
  const auto* entries = hs->hypothesis;

  printf("HypothesisScore Dump:\n");
  printf(" hypothesis_length: %d\n", entry_count);
  printf(" score: %f\n", hs->score);
  printf(" hypothesis: ");

  const bool nothing_to_show = (entry_count <= 0) || (entries == NULL);
  if (nothing_to_show) {
    printf("(empty)");
  } else {
    // Entries are printed in order, each followed by a single space.
    for (int pos = 0; pos != entry_count; ++pos) {
      printf("%d ", entries[pos]);
    }
  }
  printf("\n");
}
// Debug helper: prints the scalar members of a BeamHypotheses, then every occupied beam slot
// (indices 0 .. beams_used_ - 1) through dump_hypothesis_score. Beam labels in the output are 1-based.
__device__ void dump_beam_hypotheses(const struct BeamHypotheses* bh) {
  printf("BeamHypotheses Dump:\n");
  printf(" beams_count_: %d\n", bh->beams_count_);
  printf(" beams_used_: %d\n", bh->beams_used_);
  printf(" length_penalty_: %f\n", bh->length_penalty_);

  const char* done_text = bh->done_ ? "true" : "false";
  printf(" done_: %s\n", done_text);

  // The "beams_" line reports beams_used_ again; it is the header for the per-beam dump below.
  printf(" beams_: %d\n", bh->beams_used_);
  int slot = 0;
  while (slot < bh->beams_used_) {
    printf(" Beam %d:\n", slot + 1);
    dump_hypothesis_score(&bh->beams_[slot]);
    ++slot;
  }
}
// For counts that are typically far less than 256, this will round up the count to the next multiple of 32
// If this winds up being >256 then it uses a block size of 256 and calculates the appropriate grid_size
struct GridBlock32 {
@ -329,7 +298,9 @@ void LaunchInitializeBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps,
__device__ void BeamHypotheses::Add(const int32_t* hypothesis, int hypothesis_length, float sum_logprobs) {
float score = sum_logprobs / pow(static_cast<float>(hypothesis_length), length_penalty_);
#ifdef DEBUG_GENERATION
printf("\n BeamHypotheses::Add (score=%f hypothesis_length=%d sum_logprobs=%f) \n", score, hypothesis_length, sum_logprobs);
#endif
size_t index = beams_used_;
// If the array is full, don't add unless it's better than the worst element
@ -345,15 +316,18 @@ __device__ void BeamHypotheses::Add(const int32_t* hypothesis, int hypothesis_le
beams_[index] = HypothesisScore{hypothesis, hypothesis_length, score};
#ifdef DEBUG_GENERATION
printf("\n BeamHypotheses::Add (index=%d) \n", static_cast<int>(index));
// dump_hypothesis_score(&beams_[index]);
#endif
}
// Reports whether a candidate could still beat the hypothesis stored in the last slot.
//   best_sum_logprobs: summed log-probability of the candidate.
//   current_length:    length used to normalise that sum.
// The candidate score is best_sum_logprobs / current_length^length_penalty_. The result is true when
// the score held at index beams_count_ - 1 is strictly lower than that value.
// NOTE(review): presumably the last slot holds the worst kept hypothesis — confirm against Add().
__device__ bool BeamHypotheses::CanImprove(float best_sum_logprobs, int current_length) const {
  const int last_slot = beams_count_ - 1;
  const float current_score = best_sum_logprobs / pow(static_cast<float>(current_length), length_penalty_);
  const float last_score = beams_[last_slot].score;
  const bool result = last_score < current_score;

#ifdef DEBUG_GENERATION
  printf("\n BeamHypotheses::CanImprove (current_score=%f beams_[%d].score=%f can_improve=%d) \n",
         current_score, last_slot, last_score, static_cast<int>(result));
#endif

  return result;
}
@ -398,8 +372,9 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
// It contains word ID of whole sequence generated so far.
// It is different from subgraph input_ids, which only need one word when past state is not empty.
BeamScorerState& state = *state_gpu;
#ifdef DEBUG_GENERATION
printf("\n >>> BeamSearchScorer_Process \n");
#endif
int batch = threadIdx.x;
int batch_start = batch * state.num_beams_;
@ -417,9 +392,10 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
int batch_beam_idx = batch_start + next_index;
#ifdef DEBUG_GENERATION
printf("\nbatch=%d batch_beam_idx=%d j=%d next_token=%d eos_token_id=%d next_score=%f next_index=%d\n",
batch, batch_beam_idx, static_cast<int>(j), next_token, state.eos_token_id_, next_score, next_index);
#endif
// Add to generated hypotheses if end of sentence.
if ((state.eos_token_id_ >= 0) && (next_token == state.eos_token_id_)) {
bool is_beam_token_worse_than_top_num_beams = (j >= state.num_beams_);
@ -448,15 +424,23 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
}
// Check if we are done so that we can save a pad step if all(done)
#ifdef DEBUG_GENERATION
printf("\n beam_hyp.beams_used_ == state.num_beams_ is %d\n", beam_hyp.beams_used_ == state.num_beams_ ? 1 : 0);
#endif
if (beam_hyp.beams_used_ == state.num_beams_) {
if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) {
bool is_done = state.early_stopping_;
if (!is_done) {
float best_sum_logprobs = *std::max_element(next_scores + batch_start, next_scores + batch_start + top_k);
is_done = !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length);
}
if (is_done) {
beam_hyp.done_ = true;
if (atomicAdd(&state.not_done_count_, -1) == 1)
state_cpu.not_done_count_ = 0; // Update the CPU side
printf("\n --- BeamSearchScorer_Process updated cpu state for batch %d\n", threadIdx.x);
#ifdef DEBUG_GENERATION
printf("\n --- BeamSearchScorer_Process updated cpu state for batch %d\n", batch);
#endif
}
}
} else {
@ -467,8 +451,9 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
next_beam_indices_[batch_start + beam_idx] = 0;
}
}
#ifdef DEBUG_GENERATION
printf("\n <<< BeamSearchScorer_Process \n");
#endif
}
__global__ void DumpBeamScorerState(BeamScorerState* state) {
@ -484,7 +469,7 @@ void DumpBeamScorerStates(BeamScorerState& state_cpu, BeamScorerState* state, cu
}
// Debug kernel: prints the contents of one BeamHypotheses with device-side printf.
// NOTE(review): the two calls below both dump the same object, so as written the output appears
// twice — presumably only one of them is meant to remain; confirm which.
__global__ void DumpBeamSearchScorer(const BeamHypotheses& beam_hyp) {
dump_beam_hypotheses(&beam_hyp);
beam_hyp.Print();
}
void DumpBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, cudaStream_t stream) {
@ -510,15 +495,13 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
gsl::span<const int32_t> next_tokens,
gsl::span<const int32_t> next_indices,
cudaStream_t stream) {
// size_t printfBufferSize = 10 * 1024 * 1024;
// cudaDeviceSetLimit(cudaLimitPrintfFifoSize, printfBufferSize);
// cudaDeviceSynchronize();
#ifdef DEBUG_GENERATION
printf("\n >>> LaunchBeamSearchScorer_Process \n");
// cudaDeviceSynchronize();
DumpBeamHypotheses(beam_hyps, stream);
cudaDeviceSynchronize();
DumpBeamScorerStates(state_cpu, state_gpu, stream);
cudaDeviceSynchronize();
#endif
BeamSearchScorer_Process<<<1, state_cpu.batch_size_, 0, stream>>>(state_cpu,
state_gpu,
@ -533,13 +516,14 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
next_tokens.data(),
next_indices.data());
#ifdef DEBUG_GENERATION
cudaDeviceSynchronize();
DumpBeamHypotheses(beam_hyps, stream);
cudaDeviceSynchronize();
DumpBeamScorerStates(state_cpu, state_gpu, stream);
cudaDeviceSynchronize();
printf("\n <<< LaunchBeamSearchScorer_Process \n");
// cudaDeviceSynchronize();
#endif
}
__global__ void BeamSearchScorer_AppendNextTokenToSequences1(BeamScorerState& state,
@ -575,7 +559,9 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp
gsl::span<int32_t> next_beam_tokens,
gsl::span<int32_t> next_beam_indices,
cudaStream_t stream) {
#ifdef DEBUG_GENERATION
printf("\n >>> LaunchBeamSearchScorer_AppendNextTokenToSequences \n");
#endif
const int max_threads = 512;
int batch_beam_size = state_cpu.batch_size_ * state_cpu.num_beams_;
@ -610,7 +596,9 @@ void LaunchBeamSearchScorer_AppendNextTokenToSequences(BeamScorerState& state_cp
next_sequences.data(),
sequence_length,
next_beam_tokens.data());
#ifdef DEBUG_GENERATION
printf("\n <<< LaunchBeamSearchScorer_AppendNextTokenToSequences \n");
#endif
}
template <typename T>
@ -657,7 +645,10 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
gsl::span<int32_t> output,
gsl::span<T> sequence_scores,
cudaStream_t stream) {
#ifdef DEBUG_GENERATION
printf("\n >>> LaunchBeamSearchScorer_Finalize \n");
#endif
cudaDeviceSynchronize();
BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state,
sequences.data(),
@ -667,7 +658,10 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
output.data(),
sequence_scores.data());
cudaDeviceSynchronize();
#ifdef DEBUG_GENERATION
printf("\n <<< LaunchBeamSearchScorer_Finalize \n");
#endif
}
template void LaunchBeamSearchScorer_Finalize<float>(

View file

@ -7,6 +7,7 @@
#include <cuda_fp16.h>
#include <curand_kernel.h>
#include <cstdio>
#include "contrib_ops/cpu/transformers/generation_shared.h"
namespace onnxruntime {
namespace contrib {
@ -50,6 +51,21 @@ struct HypothesisScore {
const int32_t* hypothesis;
int hypothesis_length;
float score;
#ifdef DEBUG_GENERATION
// Debug helper: prints this hypothesis on the device — a header line with length and score,
// then the entries separated by spaces, or "(empty)" when there is nothing to list.
__device__ void Print() const {
  printf("HypothesisScore (hypothesis_length=%d, score=%f) \n", hypothesis_length, score);
  printf(" hypothesis:");

  const bool has_entries = (hypothesis_length > 0) && (hypothesis != NULL);
  if (!has_entries) {
    printf("(empty)");
  } else {
    int pos = 0;
    while (pos < hypothesis_length) {
      printf("%d ", hypothesis[pos]);
      ++pos;
    }
  }

  printf("\n");
}
#endif
};
struct BeamHypotheses {
@ -72,6 +88,22 @@ struct BeamHypotheses {
int pad_token_id, // pad token
int32_t* sequences, // buffer with pad token, shape (num_return_sequences, max_length)
T* sequences_scores); // buffer for sequence scores, with shape (num_return_sequences)
#ifdef DEBUG_GENERATION
// Debug helper: prints the scalar members of this BeamHypotheses, followed by each occupied
// beam slot (indices 0 .. beams_used_ - 1) via HypothesisScore::Print. Beam labels are 1-based.
__device__ void Print() const {
  const char* done_text = done_ ? "true" : "false";

  printf("BeamHypotheses:\n");
  printf(" beams_count: %d\n", beams_count_);
  printf(" beams_used: %d\n", beams_used_);
  printf(" length_penalty: %f\n", length_penalty_);
  printf(" done: %s\n", done_text);
  printf(" beams:\n");

  int slot = 0;
  while (slot < beams_used_) {
    printf(" Beam %d:\n", slot + 1);
    beams_[slot].Print();
    ++slot;
  }
}
#endif
};
struct BeamScorerState {
@ -85,7 +117,7 @@ struct BeamScorerState {
int not_done_count_; // When zero, every batch entry is done (starts at batch_size_)
int hypothesis_buffer_used_; // Offset of available buffer, or length of used buffer.
// Function to dump the struct data to stdout
#ifdef DEBUG_GENERATION
__host__ __device__ void Print(bool is_cpu) const {
printf("BeamScorerState (cpu=%d) Dump:\n", is_cpu ? 1 : 0);
printf(" batch_size_: %d\n", batch_size_);
@ -98,6 +130,7 @@ struct BeamScorerState {
printf(" not_done_count_: %d\n", not_done_count_);
printf(" hypothesis_buffer_used_: %d\n", hypothesis_buffer_used_);
}
#endif
};
void LaunchInitializeBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, float length_penalty, gsl::span<HypothesisScore> beams, int num_beams, cudaStream_t stream);

View file

@ -524,6 +524,7 @@ Status ProcessLogits(const OrtValue& logits, //
beam_state->remaining_scores = beam_state->remaining_scores.subspan(next_token_scores.size());
}
gsl::span<float> out_scores = beam_state->next_scores;
if (num_beams <= 32) {
constexpr size_t max_parts_of_vocab = 128;
size_t candidate_count = SafeInt<size_t>(batch_beam_size) * 2 * num_beams;
@ -588,16 +589,17 @@ Status ProcessLogits(const OrtValue& logits, //
cuda::LaunchNextTokenKernel(next_token_indices, beam_state->next_indices.data(), beam_state->next_tokens.data(),
batch_size, top_k, vocab_size, cuda_stream);
// BUG?: Copy topk_scores to next_scores
#ifdef DEBUG_GENERATION
dumper->Print("next_scores before scorer", topk_scores->Data<float>(), batch_size, top_k);
dumper->Print("next_tokens before scorer", beam_state->next_tokens.data(), batch_size, top_k);
dumper->Print("next_indices before scorer", beam_state->next_indices.data(), batch_size, top_k);
#endif
// use topk_scores as next_scores
out_scores = gsl::span<float>(topk_scores->MutableData<float>(), batch_size * top_k);
}
// gsl::span doesn't convert from non const to const, so all we're doing here is making each const.
gsl::span<const float> next_scores(beam_state->next_scores.data(), beam_state->next_scores.size());
gsl::span<const float> next_scores(out_scores.data(), out_scores.size());
gsl::span<const int32_t> next_tokens(beam_state->next_tokens.data(), beam_state->next_tokens.size());
gsl::span<const int32_t> next_indices(beam_state->next_indices.data(), beam_state->next_indices.size());
@ -723,8 +725,10 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
gsl::span<const float>& next_scores,
gsl::span<const int32_t>& next_tokens,
gsl::span<const int32_t>& next_indices) {
#ifdef DEBUG_GENERATION
printf("\n---Process ---\n");
state_cpu_->Print(true);
#endif
cuda::LaunchBeamSearchScorer_Process(*state_cpu_,
state_gpu_.get(),
@ -755,8 +759,10 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
// Blocks the calling host thread until event_process_complete_ has completed, then reports
// whether the host-side not_done_count_ has reached zero.
// NOTE(review): presumably the event is recorded after the Process kernel that updates
// state_cpu_->not_done_count_ — confirm at the recording site.
bool CudaBeamSearchScorer::IsDoneLater() const {
  CUDA_CALL_THROW(cudaEventSynchronize(event_process_complete_.Get()));

#ifdef DEBUG_GENERATION
  printf("\n---IsDoneLater ---\n");
  state_cpu_->Print(true);
#endif

  const bool every_batch_done = (state_cpu_->not_done_count_ == 0);
  return every_batch_done;
}