From 2ae041f390917f776b4bb746fe9570f047c274d9 Mon Sep 17 00:00:00 2001 From: Ryan Hill <38674843+RyanUnderhill@users.noreply.github.com> Date: Fri, 14 Jul 2023 15:46:57 -0700 Subject: [PATCH] atomicAdd returns previous value, not current value. (#16690) ### Description Fixes a mistake in beam scorer processing: the atomicAdd result should be compared with 1 rather than 0, because atomicAdd returns the original value, not the updated value. This error only results in slow performance; nothing fails. ### Motivation and Context Fixes #16642 --- .../contrib_ops/cuda/transformers/generation_cuda_impl.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu index 039dc6f4f9..07a8896210 100644 --- a/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu +++ b/onnxruntime/contrib_ops/cuda/transformers/generation_cuda_impl.cu @@ -392,7 +392,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu, if (beam_hyp.beams_used_ == state.num_beams_) { if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) { beam_hyp.done_ = true; - if (atomicAdd(&state.not_done_count_, -1) == 0) + if (atomicAdd(&state.not_done_count_, -1) == 1) state_cpu.not_done_count_ = 0; // Update the CPU side } }