mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-23 22:13:38 +00:00
sync
This commit is contained in:
parent
164495fbf0
commit
7fe7ea2f0e
3 changed files with 26 additions and 12 deletions
|
|
@ -383,7 +383,7 @@ __device__ void BeamHypotheses::Output(
|
|||
}
|
||||
|
||||
__global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
|
||||
BeamScorerState& state,
|
||||
BeamScorerState* state_gpu,
|
||||
const int32_t* sequences_buffer,
|
||||
int sequence_length,
|
||||
BeamHypotheses* beam_hyps_,
|
||||
|
|
@ -397,6 +397,8 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
|
|||
// Sequences shape is (batch_size * num_beams, total_sequence_length)
|
||||
// It contains word ID of whole sequence generated so far.
|
||||
// It is different from subgraph input_ids, which only need one word when past state is not empty.
|
||||
BeamScorerState& state = *state_gpu;
|
||||
|
||||
printf("\n >>> BeamSearchScorer_Process \n");
|
||||
|
||||
int batch = threadIdx.x;
|
||||
|
|
@ -463,11 +465,11 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
|
|||
printf("\n <<< BeamSearchScorer_Process \n");
|
||||
}
|
||||
|
||||
__global__ void DumpBeamScorerState(BeamScorerState& state) {
|
||||
state.Print();
|
||||
__global__ void DumpBeamScorerState(BeamScorerState* state) {
|
||||
state->Print();
|
||||
}
|
||||
|
||||
void DumpBeamScorerStates(BeamScorerState& state_cpu, BeamScorerState& state, cudaStream_t stream){
|
||||
void DumpBeamScorerStates(BeamScorerState& state_cpu, BeamScorerState* state, cudaStream_t stream){
|
||||
printf("\n state_cpu: \n");
|
||||
state_cpu.Print();
|
||||
|
||||
|
|
@ -490,7 +492,7 @@ void DumpBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, cudaStream_t stream
|
|||
}
|
||||
|
||||
void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
|
||||
BeamScorerState& state,
|
||||
BeamScorerState* state_gpu,
|
||||
gsl::span<const int32_t> sequences,
|
||||
int sequence_length,
|
||||
gsl::span<BeamHypotheses> beam_hyps,
|
||||
|
|
@ -502,13 +504,19 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
|
|||
gsl::span<const int32_t> next_tokens,
|
||||
gsl::span<const int32_t> next_indices,
|
||||
cudaStream_t stream) {
|
||||
printf("\n >>> LaunchBeamSearchScorer_Process \n");
|
||||
size_t printfBufferSize = 10 * 1024 * 1024;
|
||||
cudaDeviceSetLimit(cudaLimitPrintfFifoSize, printfBufferSize);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
printf("\n >>> LaunchBeamSearchScorer_Process \n");
|
||||
cudaDeviceSynchronize();
|
||||
DumpBeamHypotheses(beam_hyps, stream);
|
||||
DumpBeamScorerStates(state_cpu, state, stream);
|
||||
cudaDeviceSynchronize();
|
||||
DumpBeamScorerStates(state_cpu, state_gpu, stream);
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
BeamSearchScorer_Process<<<1, state_cpu.batch_size_, 0, stream>>>(state_cpu,
|
||||
state,
|
||||
state_gpu,
|
||||
sequences.data(),
|
||||
sequence_length,
|
||||
beam_hyps.data(),
|
||||
|
|
@ -520,10 +528,14 @@ void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
|
|||
next_tokens.data(),
|
||||
next_indices.data());
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
DumpBeamHypotheses(beam_hyps, stream);
|
||||
DumpBeamScorerStates(state_cpu, state, stream);
|
||||
|
||||
cudaDeviceSynchronize();
|
||||
DumpBeamScorerStates(state_cpu, state_gpu, stream);
|
||||
cudaDeviceSynchronize();
|
||||
printf("\n <<< LaunchBeamSearchScorer_Process \n");
|
||||
cudaDeviceSynchronize();
|
||||
|
||||
}
|
||||
|
||||
__global__ void BeamSearchScorer_AppendNextTokenToSequences1(BeamScorerState& state,
|
||||
|
|
@ -642,6 +654,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
|
|||
gsl::span<T> sequence_scores,
|
||||
cudaStream_t stream) {
|
||||
printf("\n >>> LaunchBeamSearchScorer_Finalize \n");
|
||||
cudaDeviceSynchronize();
|
||||
BeamSearchScorer_Finalize<<<1, batch_size, 0, stream>>>(state,
|
||||
sequences.data(),
|
||||
sequence_length,
|
||||
|
|
@ -649,6 +662,7 @@ void LaunchBeamSearchScorer_Finalize(int batch_size,
|
|||
final_beam_scores.data(),
|
||||
output.data(),
|
||||
sequence_scores.data());
|
||||
cudaDeviceSynchronize();
|
||||
printf("\n <<< LaunchBeamSearchScorer_Finalize \n");
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -103,7 +103,7 @@ struct BeamScorerState {
|
|||
void LaunchInitializeBeamHypotheses(gsl::span<BeamHypotheses> beam_hyps, float length_penalty, gsl::span<HypothesisScore> beams, int num_beams, cudaStream_t stream);
|
||||
|
||||
void LaunchBeamSearchScorer_Process(BeamScorerState& state_cpu,
|
||||
BeamScorerState& state,
|
||||
BeamScorerState* state,
|
||||
gsl::span<const int32_t> sequences,
|
||||
int sequence_length,
|
||||
gsl::span<BeamHypotheses> beam_hyps_,
|
||||
|
|
|
|||
|
|
@ -726,7 +726,7 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences,
|
|||
state_cpu_->Print();
|
||||
|
||||
cuda::LaunchBeamSearchScorer_Process(*state_cpu_,
|
||||
*state_gpu_,
|
||||
state_gpu_.get(),
|
||||
sequences.GetCurrentDeviceSequences(),
|
||||
sequences.GetSequenceLength(),
|
||||
beam_hyps_,
|
||||
|
|
|
|||
Loading…
Reference in a new issue