This commit is contained in:
Tianlei Wu 2025-02-05 13:04:32 -08:00
parent e0e748cfc7
commit 2ac07549cf
3 changed files with 22 additions and 12 deletions

View file

@@ -375,15 +375,17 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
printf("\n >>> BeamSearchScorer_Process \n");
#endif
int batch = threadIdx.x;
int batch_start = batch * state.num_beams_;
const int batch = threadIdx.x;
const int num_beams = state.num_beams_;
const bool early_stopping = state.early_stopping_;
const int batch_start = batch * num_beams;
cuda::BeamHypotheses& beam_hyp = beam_hyps_[batch];
if (!beam_hyp.done_) {
// Next tokens for this sentence.
size_t beam_idx = 0;
size_t top_k = 2 * state.num_beams_;
size_t top_k = 2 * num_beams;
for (size_t j = 0; j < top_k; j++) {
int32_t next_token = next_tokens[batch * top_k + j];
float next_score = next_scores[batch * top_k + j];
@@ -397,7 +399,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
#endif
// Add to generated hypotheses if end of sentence.
if ((state.eos_token_id_ >= 0) && (next_token == state.eos_token_id_)) {
bool is_beam_token_worse_than_top_num_beams = (j >= state.num_beams_);
bool is_beam_token_worse_than_top_num_beams = (j >= num_beams);
if (is_beam_token_worse_than_top_num_beams) {
continue;
}
@@ -418,22 +420,22 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
}
// Once the beam for next step is full, don't add more tokens to it.
if (beam_idx == state.num_beams_)
if (beam_idx == num_beams)
break;
}
// Check if we are done so that we can save a pad step if all(done)
#ifdef DEBUG_GENERATION
printf("\n beam_hyp.beams_used_ == state.num_beams_ is %d\n", beam_hyp.beams_used_ == state.num_beams_ ? 1 : 0);
printf("\n beam_hyp.beams_used_ == num_beams is %d\n", beam_hyp.beams_used_ == num_beams ? 1 : 0);
#endif
if (beam_hyp.beams_used_ == state.num_beams_) {
bool is_done = state.early_stopping_;
if (beam_hyp.beams_used_ == num_beams) {
bool is_done = early_stopping;
#ifdef DEBUG_GENERATION
printf("\n --- state.early_stopping_=%d (batch %d)\n", static_cast<int>(state.early_stopping_), batch);
printf("\n --- early_stopping=%d is_done=%d (batch %d)\n", static_cast<int>(early_stopping), static_cast<int>(is_done), batch);
#endif
if (is_done) {
#ifdef DEBUG_GENERATION
printf("\n --- state.early_stopping_ (batch %d)\n", batch);
printf("\n --- is_done=%d is true due to early_stopping=%d (batch %d)\n", static_cast<int>(is_done), static_cast<int>(early_stopping), batch);
#endif
}
@@ -451,13 +453,13 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
state_cpu.not_done_count_ = 0; // Update the CPU side
#ifdef DEBUG_GENERATION
printf("\n --- BeamSearchScorer_Process updated cpu state for batch %d\n", batch);
printf("\n --- is_done (state.early_stopping_=%d state_cpu.not_done_count_=%d) for batch %d\n", state.early_stopping_ ? 1 : 0, state_cpu.not_done_count_, batch);
printf("\n --- is_done (early_stopping=%d state_cpu.not_done_count_=%d) for batch %d\n", static_cast<int>(early_stopping), state_cpu.not_done_count_, batch);
#endif
}
}
} else {
// Pad the batch.
for (size_t beam_idx = 0; beam_idx < state.num_beams_; beam_idx++) {
for (size_t beam_idx = 0; beam_idx < num_beams; beam_idx++) {
next_beam_scores_[batch_start + beam_idx] = 0.0f;
next_beam_tokens_[batch_start + beam_idx] = state.pad_token_id_;
next_beam_indices_[batch_start + beam_idx] = 0;

View file

@@ -30,6 +30,7 @@
#include "sampling_cuda_helper.h"
#ifdef DEBUG_GENERATION
#include <stdio.h>
#include <iostream>
#endif
@@ -726,8 +727,10 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
gsl::span<const int32_t>& next_tokens,
gsl::span<const int32_t>& next_indices) {
#ifdef DEBUG_GENERATION
cudaDeviceSynchronize();
printf("\n---Process ---\n");
state_cpu_->Print(true);
fflush(stdout);
#endif
cuda::LaunchBeamSearchScorer_Process(*state_cpu_,
@@ -760,8 +763,10 @@ bool CudaBeamSearchScorer::IsDoneLater() const {
CUDA_CALL_THROW(cudaEventSynchronize(event_process_complete_.Get()));
#ifdef DEBUG_GENERATION
cudaDeviceSynchronize();
printf("\n---IsDoneLater ---\n");
state_cpu_->Print(true);
fflush(stdout);
#endif
return state_cpu_->not_done_count_ == 0;

View file

@@ -95,6 +95,7 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool i
} else {
onnxruntime::utils::PrintCpuTensorFull<T>(*data, dim0, dim1);
}
std::cout << std::flush;
}
template <typename T>
@@ -116,6 +117,7 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di
} else {
onnxruntime::utils::PrintCpuTensorFull<T>(*data, dim0, dim1, dim2);
}
std::cout << std::flush;
}
template <typename T>
@@ -144,6 +146,7 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di
onnxruntime::utils::PrintCpuTensorFull<T>((*data) + i * dim1 * dim2 * dim3, dim1, dim2, dim3);
}
}
std::cout << std::flush;
}
void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) {