Softmax interface update (#12469)

* Template datatype for SoftmaxWithRawMaskSmallKernel in ROCm EP

* Remove valid_items usage from SoftmaxWithRawMaskSmallKernel for ROCm EP

The kernel already masks off invalid items and this gives a much
faster implementation in hipCUB.

* Update accumulator type in ROCm EP for SoftmaxWithRawMaskSmallKernel

Hard code accumulator to fp32 for hipCUB in indicated kernel.

* Reset casting to old behavior

* Document steps to optimize SoftMax kernel on ROCm EP

Usage of the hipCUB valid_items interface on reduction operations
has a significant performance impact. Masking all thread data to
avoid need to use the valid_items interface to hipCUB.
This commit is contained in:
Joseph Groenenboom 2022-09-12 15:02:31 -05:00 committed by GitHub
parent 30ebc9e00a
commit a433f22f17
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -196,6 +196,9 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
// Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data
// members with all invalid members set to a value that does not impact the final result. This is necessary
// to avoid the performance impact from using the valid_items interface.
float thread_data = -ROCMRT_INF_F;
if (threadIdx.x < all_sequence_length) {
if (add_before_softmax == nullptr) {
@ -242,7 +245,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
return;
}
const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length);
const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max());
// Store max value
if (threadIdx.x == 0) {
@ -250,8 +253,11 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
}
__syncthreads();
// Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp
// members with all invalid members set to a value that does not impact the final result. This is necessary
// to avoid the performance impact from using the valid_items interface.
float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length);
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum());
// Store value of 1.0/sum
if (threadIdx.x == 0) {