diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h index 4e2bd238bb..7229c7b522 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -196,6 +196,9 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. float thread_data = -ROCMRT_INF_F; if (threadIdx.x < all_sequence_length) { if (add_before_softmax == nullptr) { @@ -242,7 +245,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, return; } - const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length); + const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); // Store max value if (threadIdx.x == 0) { @@ -250,8 +253,11 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } __syncthreads(); + // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length); + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); // Store value of 1.0/sum if (threadIdx.x == 0) {