This commit is contained in:
Tianlei Wu 2025-02-04 16:56:59 -08:00
parent 164495fbf0
commit 7fe7ea2f0e
3 changed files with 26 additions and 12 deletions

View file

@ -383,7 +383,7 @@ __device__ void BeamHypotheses::Output(
}
__global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
BeamScorerState& state,
BeamScorerState* state_gpu,
const int32_t* sequences_buffer,
int sequence_length,
BeamHypotheses* beam_hyps_,
@ -397,6 +397,8 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
// Sequences shape is (batch_size * num_beams, total_sequence_length)
// It contains the word IDs of the whole sequence generated so far.
// It differs from the subgraph input_ids, which only need one word when the past state is not empty.
BeamScorerState& state = *state_gpu;
printf("\n >>> BeamSearchScorer_Process \n");
int batch = threadIdx.x;
@ -463,11 +465,11 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
printf("\n <<< BeamSearchScorer_Process \n");
}
__global__ void DumpBeamScorerState(BeamScorerState& state) {
state.Print();
__global__ void DumpBeamScorerState(BeamScorerState* state) {
state->Print();
}
void DumpBeamScorerStates(BeamScorerState& state_cpu, BeamScorerState& state, cudaStream_t stream){
void DumpBeamScorerStates(BeamScorerState& state_cpu, BeamScorerState* state, cudaStream_t stream){
printf("\n state_cpu: \n");
state_cpu.Print();
@ -490,7 +492,7 @@ void DumpBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, cudaStream_t stream
}
void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
BeamScorerState& state,
BeamScorerState* state_gpu,
gsl::span<const int32_t> sequences,
int sequence_length,
gsl::span<BeamHypotheses> beam_hyps,
@ -502,13 +504,19 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
gsl::span<const int32_t> next_tokens,
gsl::span<const int32_t> next_indices,
cudaStream_t stream) {
printf("\n >>> LaunchBeamSearchScorer_Process \n");
size_t printfBufferSize = 10 * 1024 * 1024;
cudaDeviceSetLimit(cudaLimitPrintfFifoSize, printfBufferSize);
cudaDeviceSynchronize();
printf("\n >>> LaunchBeamSearchScorer_Process \n");
cudaDeviceSynchronize();
DumpBeamHypotheses(beam_hyps, stream);
DumpBeamScorerStates(state_cpu, state, stream);
cudaDeviceSynchronize();
DumpBeamScorerStates(state_cpu, state_gpu, stream);
cudaDeviceSynchronize();
BeamSearchScorer_Process<<<1, state_cpu.batch_size_, 0, stream>>>(state_cpu,
state,
state_gpu,
sequences.data(),
sequence_length,
beam_hyps.data(),
@ -520,10 +528,14 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
next_tokens.data(),
next_indices.data());
cudaDeviceSynchronize();
DumpBeamHypotheses(beam_hyps, stream);
DumpBeamScorerStates(state_cpu, state, stream);
cudaDeviceSynchronize();
DumpBeamScorerStates(state_cpu, state_gpu, stream);
cudaDeviceSynchronize();
printf("\n <<< LaunchBeamSearchScorer_Process \n");
cudaDeviceSynchronize();
}
__global__ void BeamSearchScorer_AppendNextTokenToSequences1(BeamScorerState& state,
@ -642,6 +654,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
gsl::span<T> sequence_scores,
cudaStream_t stream) {
printf("\n >>> LaunchBeamSearchScorer_Finalize \n");
cudaDeviceSynchronize();
BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state,
sequences.data(),
sequence_length,
@ -649,6 +662,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
final_beam_scores.data(),
output.data(),
sequence_scores.data());
cudaDeviceSynchronize();
printf("\n <<< LaunchBeamSearchScorer_Finalize \n");
}

View file

@ -103,7 +103,7 @@ struct BeamScorerState {
void LaunchInitializeBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, float length_penalty, gsl::span<HypothesisScore> beams, int num_beams, cudaStream_t stream);
void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
BeamScorerState& state,
BeamScorerState* state,
gsl::span<const int32_t> sequences,
int sequence_length,
gsl::span<BeamHypotheses> beam_hyps_,

View file

@ -726,7 +726,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
state_cpu_->Print();
cuda::LaunchBeamSearchScorer_Process(*state_cpu_,
*state_gpu_,
state_gpu_.get(),
sequences.GetCurrentDeviceSequences(),
sequences.GetSequenceLength(),
beam_hyps_,