atomicAdd returns previous value, not current value. (#16690)

### Description
Mistake in beam scorer processing, atomicAdd result should be compared
with '1' vs '0' as it returns the original value, not the latest value.

This error just results in slow perf, nothing fails.

### Motivation and Context
Fixes #16642
This commit is contained in:
Ryan Hill 2023-07-14 15:46:57 -07:00 committed by GitHub
parent 44fd98ebfe
commit 2ae041f390
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -392,7 +392,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
if (beam_hyp.beams_used_ == state.num_beams_) {
if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) {
beam_hyp.done_ = true;
if (atomicAdd(&state.not_done_count_, -1) == 0)
if (atomicAdd(&state.not_done_count_, -1) == 1)
state_cpu.not_done_count_ = 0; // Update the CPU side
}
}