Support Smooth Softmax in fmha (#21885)

### Description
<!-- Describe your changes. -->
refer to https://github.com/microsoft/onnxruntime/pull/21867


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
Ye Wang 2024-08-28 09:29:33 -07:00 committed by GitHub
parent ef073fd8f4
commit bf8855ba3c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 66 additions and 9 deletions

View file

@ -1,13 +1,64 @@
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
index 4c80f549..34327633 100644
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
@@ -221,6 +221,8 @@ struct AttentionKernel {
int32_t num_batches = 0;
int32_t num_heads = 0;
+ bool use_smooth_softmax = false;
+
// dropout
bool use_dropout = false;
unsigned long long dropout_batch_head_rng_offset = 0;
@@ -897,7 +899,8 @@ struct AttentionKernel {
p.num_keys - iter_key_start,
iter_key_start == 0,
iteratorC_tile_offset,
- kSupportsBias ? 1.0f : p.scale);
+ kSupportsBias ? 1.0f : p.scale,
+ p.use_smooth_softmax);
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
int max_col,
bool is_first,
typename WarpIteratorC::TensorCoord const& tile_offset,
- float scaling) {
+ float scaling,
+ bool use_smooth_softmax) {
/* Iterates on the accumulator and corresponding position on result matrix
(1) Update `mi[r]` to the max value of the row `r`
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
accum_t mi_row, total_row;
LambdaIterator::iterateRows(
lane_offset,
- [&](int accum_m) { mi_row = mi[accum_m]; },
+ [&](int accum_m) { mi_row = mi[accum_m];},
[&](int accum_m, int accum_n, int idx) {
frag[idx] =
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
total_row += addition_storage[id + kQueriesPerBlock * i];
}
- s_prime[id] = total_row;
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
}
}
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 964d2ff3..b366bc14 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -39,6 +39,7 @@
#include "cutlass/numeric_types.h"
#include <cuda_runtime.h>
+#include <cuda_fp16.h>
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
#include <mma.h>
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
return reinterpret_cast<half_t const &>(result);
+#else
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
+#endif
+#endif
#else
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
#endif
#endif

View file

@ -415,6 +415,7 @@ Status EfficientAttention(
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.scale = scale;
p.use_smooth_softmax = false;
if (nullptr == data.mask_index) {
p.seqlen_k_ptr = nullptr;

View file

@ -220,6 +220,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.bias_strideM = 0;
p.bias_strideB = 0;
}
p.use_smooth_softmax = params.use_smooth_softmax;
}
auto kernel_fn = attention_kernel_batched_impl<Attention>;

View file

@ -25,6 +25,7 @@ struct MemoryEfficientAttentionParams {
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
bool use_smooth_softmax;
float scale;

View file

@ -153,7 +153,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
#if USE_MEMORY_EFFICIENT_ATTENTION
int sm = (device_prop.major * 10) + device_prop.minor;
bool use_memory_efficient_attention =
!use_smooth_softmax_ &&
!use_flash_attention &&
!disable_memory_efficient_attention_ &&
local_window_size_ == -1 &&

View file

@ -678,8 +678,8 @@ Status FlashAttention(
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));
// if (parameters.left_padding && parameters.is_prompt) {
@ -843,6 +843,7 @@ Status EfficientAttention(
: nullptr;
p.stream = stream;
p.has_custom_right_padding = true;
p.use_smooth_softmax = parameters.use_smooth_softmax;
run_memory_efficient_attention(p);
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);

View file

@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass(
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
p.use_smooth_softmax = false;
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;

View file

@ -693,6 +693,7 @@ Status FusedAttentionCutlass(
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = false;
p.use_smooth_softmax = false;
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
p.seqlen_k_ptr = nullptr;

View file

@ -2219,7 +2219,7 @@ class TestGQA(unittest.TestCase):
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
@ -2263,7 +2263,7 @@ class TestGQA(unittest.TestCase):
rotary=rotary,
rotary_interleaved=rotary_interleaved,
packed=packed,
use_smooth_softmax=False,
use_smooth_softmax=True,
)
parity_check_gqa_past_no_buff(
config,