From 2ac07549cfd85812af1fb09e6ee41b331ef5f1a3 Mon Sep 17 00:00:00 2001 From: Tianlei Wu Date: Wed, 5 Feb 2025 13:04:32 -0800 Subject: [PATCH] flush --- .../cuda/transformers/generation_cuda_impl.cu | 26 ++++++++++--------- .../transformers/generation_device_helper.cc | 5 ++++ .../cuda/utils/dump_cuda_tensor.cc | 3 +++ 3 files changed, 22 insertions(+), 12 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 989c5e264c..5d96be94b8 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -375,15 +375,17 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, printf("\n >>> BeamSearchScorer_Process \n"); #endif - int batch = threadIdx.x; - int batch_start = batch * state.num_beams_; + const int batch = threadIdx.x; + const int num_beams = state.num_beams_; + const bool early_stopping = state.early_stopping_; + const int batch_start = batch * num_beams; cuda::BeamHypotheses& beam_hyp = beam_hyps_[batch]; if (!beam_hyp.done_) { // Next tokens for this sentence. size_t beam_idx = 0; - size_t top_k = 2 * state.num_beams_; + size_t top_k = 2 * num_beams; for (size_t j = 0; j < top_k; j++) { int32_t next_token = next_tokens[batch * top_k + j]; float next_score = next_scores[batch * top_k + j]; @@ -397,7 +399,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, #endif // Add to generated hypotheses if end of sentence. if ((state.eos_token_id_ >= 0) && (next_token == state.eos_token_id_)) { - bool is_beam_token_worse_than_top_num_beams = (j >= state.num_beams_); + bool is_beam_token_worse_than_top_num_beams = (j >= num_beams); if (is_beam_token_worse_than_top_num_beams) { continue; } @@ -418,22 +420,22 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, } // Once the beam for next step is full, don't add more tokens to it. - if (beam_idx == state.num_beams_) + if (beam_idx == num_beams) break; } // Check if we are done so that we can save a pad step if all(done) #ifdef DEBUG_GENERATION - printf("\n beam_hyp.beams_used_ == state.num_beams_ is %d\n", beam_hyp.beams_used_ == state.num_beams_ ? 1 : 0); + printf("\n beam_hyp.beams_used_ == num_beams is %d\n", beam_hyp.beams_used_ == num_beams ? 1 : 0); #endif - if (beam_hyp.beams_used_ == state.num_beams_) { - bool is_done = state.early_stopping_; + if (beam_hyp.beams_used_ == num_beams) { + bool is_done = early_stopping; #ifdef DEBUG_GENERATION - printf("\n --- state.early_stopping_=%d (batch %d)\n", static_cast(state.early_stopping_), batch); + printf("\n --- early_stopping=%d is_done=%d (batch %d)\n", static_cast(early_stopping), static_cast(is_done), batch); #endif if (is_done) { #ifdef DEBUG_GENERATION - printf("\n --- state.early_stopping_ (batch %d)\n", batch); + printf("\n --- is_done=%d is true due to early_stopping=%d (batch %d)\n", static_cast(is_done), static_cast(early_stopping), batch); #endif } @@ -451,13 +453,13 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, state_cpu.not_done_count_ = 0; // Update the CPU side #ifdef DEBUG_GENERATION printf("\n --- BeamSearchScorer_Process updated cpu state for batch %d\n", batch); - printf("\n --- is_done (state.early_stopping_=%d state_cpu.not_done_count_=%d) for batch %d\n", state.early_stopping_ ? 1 : 0, state_cpu.not_done_count_, batch); + printf("\n --- is_done (early_stopping=%d state_cpu.not_done_count_=%d) for batch %d\n", static_cast(early_stopping), state_cpu.not_done_count_, batch); #endif } } } else { // Pad the batch. - for (size_t beam_idx = 0; beam_idx < state.num_beams_; beam_idx++) { + for (size_t beam_idx = 0; beam_idx < num_beams; beam_idx++) { next_beam_scores_[batch_start + beam_idx] = 0.0f; next_beam_tokens_[batch_start + beam_idx] = state.pad_token_id_; next_beam_indices_[batch_start + beam_idx] = 0; diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc index dbf11b683a..85e46d97c1 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_device_helper.cc @@ -30,6 +30,7 @@ #include "sampling_cuda_helper.h" #ifdef DEBUG_GENERATION +#include #include #endif @@ -726,8 +727,10 @@ void CudaBeamSearchScorer::Process(transformers::ISequences& sequences, gsl::span& next_tokens, gsl::span& next_indices) { #ifdef DEBUG_GENERATION + cudaDeviceSynchronize(); printf("\n---Process ---\n"); state_cpu_->Print(true); + fflush(stdout); #endif cuda::LaunchBeamSearchScorer_Process(*state_cpu_, @@ -760,8 +763,10 @@ bool CudaBeamSearchScorer::IsDoneLater() const { CUDA_CALL_THROW(cudaEventSynchronize(event_process_complete_.Get())); #ifdef DEBUG_GENERATION + cudaDeviceSynchronize(); printf("\n---IsDoneLater ---\n"); state_cpu_->Print(true); + fflush(stdout); #endif return state_cpu_->not_done_count_ == 0; diff --git a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc index 5c39cf56df..d8cfd70324 100644 --- a/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc +++ b/onnxruntime/contrib_ops/cuda/utils/dump_cuda_tensor.cc @@ -95,6 +95,7 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, bool i } else { onnxruntime::utils::PrintCpuTensorFull(*data, dim0, dim1); } + std::cout << std::flush; } template @@ -116,6 +117,7 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di } else { onnxruntime::utils::PrintCpuTensorFull(*data, dim0, dim1, dim2); } + std::cout << std::flush; } template @@ -144,6 +146,7 @@ void DumpGpuTensor(const char* name, const T* tensor, int dim0, int dim1, int di onnxruntime::utils::PrintCpuTensorFull((*data) + i * dim1 * dim2 * dim3, dim1, dim2, dim3); } } + std::cout << std::flush; } void DumpGpuTensor(const char* name, const Tensor& tensor, int dim0, int dim1, int dim2) {