mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-14 20:48:00 +00:00
atomicAdd returns previous value, not current value. (#16690)
### Description Mistake in beam scorer processing, atomicAdd result should be compared with '1' vs '0' as it returns the original value, not the latest value. This error just results in slow perf, nothing fails. ### Motivation and Context Fixes #16642
This commit is contained in:
parent
44fd98ebfe
commit
2ae041f390
1 changed files with 1 additions and 1 deletions
|
|
@ -392,7 +392,7 @@ __global__ void BeamSearchScorer_Process(BeamScorerState& state_cpu,
|
|||
if (beam_hyp.beams_used_ == state.num_beams_) {
|
||||
if (state.early_stopping_ || !beam_hyp.CanImprove(*std::max_element(next_scores + batch_start, next_scores + batch_start + top_k), sequence_length)) {
|
||||
beam_hyp.done_ = true;
|
||||
if (atomicAdd(&state.not_done_count_, -1) == 0)
|
||||
if (atomicAdd(&state.not_done_count_, -1) == 1)
|
||||
state_cpu.not_done_count_ = 0; // Update the CPU side
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in a new issue