From a433f22f17e59671ff01acf0d270b7e3476a952a Mon Sep 17 00:00:00 2001 From: Joseph Groenenboom Date: Mon, 12 Sep 2022 15:02:31 -0500 Subject: [PATCH] Softmax interface update (#12469) * Template datatype for SoftmaxWithRawMaskSmallKernel in ROCm EP * Remove valid_items usage from SoftmaxWithRawMaskSmallKernel for ROCm EP The kernel already masks off invalid items and this gives a much faster implementation in hipCUB. * Update accumulator type in ROCm EP for SoftmaxWithRawMaskSmallKernel Hard code accumulator to fp32 for hipCUB in indicated kernel. * Reset casting to old behavior * Document steps to optimize SoftMax kernel on ROCm EP Usage of the hipCUB valid_items interface on reduction operations has a significant performance impact. Masking all thread data to avoid need to use the valid_items interface to hipCUB. --- onnxruntime/contrib_ops/rocm/bert/attention_softmax.h | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h index 4e2bd238bb..7229c7b522 100644 --- a/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h +++ b/onnxruntime/contrib_ops/rocm/bert/attention_softmax.h @@ -196,6 +196,9 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, // Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S; int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x; + // Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. float thread_data = -ROCMRT_INF_F; if (threadIdx.x < all_sequence_length) { if (add_before_softmax == nullptr) { @@ -242,7 +245,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, return; } - const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length); + const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max()); // Store max value if (threadIdx.x == 0) { @@ -250,8 +253,11 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length, } __syncthreads(); + // Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp + // members with all invalid members set to a value that does not impact the final result. This is necessary + // to avoid the performance impact from using the valid_items interface. float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f; - const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length); + const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum()); // Store value of 1.0/sum if (threadIdx.x == 0) {