diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 039dc6f4f9..07a8896210 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -392,7 +392,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, if (beam_hyp.beams_used_ == state.num_beams_) { if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) { beam_hyp.done_ = true; - if (atomicAdd(&state.not_done_count_, -1) == 0) + if (atomicAdd(&state.not_done_count_, -1) == 1) state_cpu.not_done_count_ = 0; // Update the CPU side } }