mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-05-19 21:32:23 +00:00
Softmax interface update (#12469)
* Template datatype for SoftmaxWithRawMaskSmallKernel in ROCm EP * Remove valid_items usage from SoftmaxWithRawMaskSmallKernel for ROCm EP The kernel already masks off invalid items and this gives a much faster implementation in hipCUB. * Update accumulator type in ROCm EP for SoftmaxWithRawMaskSmallKernel Hard code accumulator to fp32 for hipCUB in indicated kernel. * Reset casting to old behavior * Document steps to optimize SoftMax kernel on ROCm EP Usage of the hipCUB valid_items interface on reduction operations has a significant performance impact. Masking all thread data to avoid need to use the valid_items interface to hipCUB.
This commit is contained in:
parent
30ebc9e00a
commit
a433f22f17
1 changed files with 8 additions and 2 deletions
|
|
@ -196,6 +196,9 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
// Input dimension is BxNxSxS*; blockIdx.y is batch index b; gridDim.x=N*S; blockIdx.x is index within N*S;
|
||||
int index = (blockIdx.y * gridDim.x + blockIdx.x) * all_sequence_length + threadIdx.x;
|
||||
|
||||
// Mask all thread_data values to negative infinity to allow BlockReduce Max operation over all thread_data
|
||||
// members with all invalid members set to a value that does not impact the final result. This is necessary
|
||||
// to avoid the performance impact from using the valid_items interface.
|
||||
float thread_data = -ROCMRT_INF_F;
|
||||
if (threadIdx.x < all_sequence_length) {
|
||||
if (add_before_softmax == nullptr) {
|
||||
|
|
@ -242,7 +245,7 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
return;
|
||||
}
|
||||
|
||||
const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max(), all_sequence_length);
|
||||
const float max = BlockReduce(tmp_storage).Reduce(thread_data, hipcub::Max());
|
||||
|
||||
// Store max value
|
||||
if (threadIdx.x == 0) {
|
||||
|
|
@ -250,8 +253,11 @@ __device__ inline void SoftmaxWithRawMaskSmall(const int all_sequence_length,
|
|||
}
|
||||
__syncthreads();
|
||||
|
||||
// Mask all thread_data_exp values to zero to allow BlockReduce Sum operation over all thread_data_exp
|
||||
// members with all invalid members set to a value that does not impact the final result. This is necessary
|
||||
// to avoid the performance impact from using the valid_items interface.
|
||||
float thread_data_exp = threadIdx.x < all_sequence_length ? expf(thread_data - max_block) : 0.0f;
|
||||
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum(), all_sequence_length);
|
||||
const auto sum = BlockReduce(tmp_storage).Reduce(thread_data_exp, hipcub::Sum());
|
||||
|
||||
// Store value of 1.0/sum
|
||||
if (threadIdx.x == 0) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue