mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-07-03 03:58:54 +00:00
Support Smooth Softmax in fmha (#21885)
### Description
<!-- Describe your changes. -->
Refer to https://github.com/microsoft/onnxruntime/pull/21867.

### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

---------

Co-authored-by: Your Name <you@example.com>
This commit is contained in:
parent
ef073fd8f4
commit
bf8855ba3c
9 changed files with 66 additions and 9 deletions
|
|
@ -1,13 +1,64 @@
|
|||
diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h
|
||||
index 4c80f549..34327633 100644
|
||||
--- a/examples/41_fused_multi_head_attention/kernel_forward.h
|
||||
+++ b/examples/41_fused_multi_head_attention/kernel_forward.h
|
||||
@@ -221,6 +221,8 @@ struct AttentionKernel {
|
||||
int32_t num_batches = 0;
|
||||
int32_t num_heads = 0;
|
||||
|
||||
+ bool use_smooth_softmax = false;
|
||||
+
|
||||
// dropout
|
||||
bool use_dropout = false;
|
||||
unsigned long long dropout_batch_head_rng_offset = 0;
|
||||
@@ -897,7 +899,8 @@ struct AttentionKernel {
|
||||
p.num_keys - iter_key_start,
|
||||
iter_key_start == 0,
|
||||
iteratorC_tile_offset,
|
||||
- kSupportsBias ? 1.0f : p.scale);
|
||||
+ kSupportsBias ? 1.0f : p.scale,
|
||||
+ p.use_smooth_softmax);
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = my_warp_id %
|
||||
@@ -1166,7 +1169,8 @@ struct AttentionKernel {
|
||||
int max_col,
|
||||
bool is_first,
|
||||
typename WarpIteratorC::TensorCoord const& tile_offset,
|
||||
- float scaling) {
|
||||
+ float scaling,
|
||||
+ bool use_smooth_softmax) {
|
||||
/* Iterates on the accumulator and corresponding position on result matrix
|
||||
|
||||
(1) Update `mi[r]` to the max value of the row `r`
|
||||
@@ -1257,7 +1261,7 @@ struct AttentionKernel {
|
||||
accum_t mi_row, total_row;
|
||||
LambdaIterator::iterateRows(
|
||||
lane_offset,
|
||||
- [&](int accum_m) { mi_row = mi[accum_m]; },
|
||||
+ [&](int accum_m) { mi_row = mi[accum_m];},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
frag[idx] =
|
||||
(accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0);
|
||||
@@ -1294,7 +1298,7 @@ struct AttentionKernel {
|
||||
for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) {
|
||||
total_row += addition_storage[id + kQueriesPerBlock * i];
|
||||
}
|
||||
- s_prime[id] = total_row;
|
||||
+ s_prime[id] = (use_smooth_softmax && (max_col <= kKeysPerBlock)) ? total_row + exp2f(-mi[id]) : total_row;
|
||||
}
|
||||
}
|
||||
|
||||
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
|
||||
index 964d2ff3..b366bc14 100644
|
||||
--- a/include/cutlass/functional.h
|
||||
+++ b/include/cutlass/functional.h
|
||||
@@ -39,6 +39,7 @@
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
|
||||
#include <cuda_runtime.h>
|
||||
+#include <cuda_fp16.h>
|
||||
|
||||
|
||||
#if defined(CUTLASS_ARCH_WMMA_ENABLED)
|
||||
#include <mma.h>
|
||||
@@ -230,8 +231,12 @@ struct inverse_square_root<half_t> {
|
||||
|
|
@ -19,7 +70,7 @@ index 964d2ff3..b366bc14 100644
|
|||
return reinterpret_cast<half_t const &>(result);
|
||||
+#else
|
||||
+ return half_t::convert((rsqrtf(half_t::convert(lhs))));
|
||||
+#endif
|
||||
+#endif
|
||||
#else
|
||||
return half_t(1.f / std::sqrt(half_t::convert(lhs)));
|
||||
#endif
|
||||
#endif
|
||||
|
|
@ -415,6 +415,7 @@ Status EfficientAttention(
|
|||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = parameters.is_unidirectional;
|
||||
p.scale = scale;
|
||||
p.use_smooth_softmax = false;
|
||||
|
||||
if (nullptr == data.mask_index) {
|
||||
p.seqlen_k_ptr = nullptr;
|
||||
|
|
|
|||
|
|
@ -220,6 +220,8 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
|
|||
p.bias_strideM = 0;
|
||||
p.bias_strideB = 0;
|
||||
}
|
||||
|
||||
p.use_smooth_softmax = params.use_smooth_softmax;
|
||||
}
|
||||
|
||||
auto kernel_fn = attention_kernel_batched_impl<Attention>;
|
||||
|
|
|
|||
|
|
@ -25,6 +25,7 @@ struct MemoryEfficientAttentionParams {
|
|||
int32_t qk_head_size;
|
||||
int32_t v_head_size;
|
||||
bool causal;
|
||||
bool use_smooth_softmax;
|
||||
|
||||
float scale;
|
||||
|
||||
|
|
|
|||
|
|
@ -153,7 +153,6 @@ Status GroupQueryAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
#if USE_MEMORY_EFFICIENT_ATTENTION
|
||||
int sm = (device_prop.major * 10) + device_prop.minor;
|
||||
bool use_memory_efficient_attention =
|
||||
!use_smooth_softmax_ &&
|
||||
!use_flash_attention &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
local_window_size_ == -1 &&
|
||||
|
|
|
|||
|
|
@ -678,8 +678,8 @@ Status FlashAttention(
|
|||
reinterpret_cast<void*>(data.softmax_lse), seqlens_k, cos_cache, sin_cache, /*block_table*/ nullptr,
|
||||
batch_size, num_heads, kv_num_heads, head_size, sequence_length,
|
||||
parameters.seqlen_present_kv_cache, kv_sequence_length, parameters.rotary_dim,
|
||||
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
|
||||
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
|
||||
scale, is_causal, is_bf16, parameters.use_smooth_softmax, past_bsnh, parameters.num_splits,
|
||||
reinterpret_cast<void*>(data.softmax_lse_accum), reinterpret_cast<void*>(data.out_accum),
|
||||
parameters.local_window_size, parameters.rotary_interleaved, parameters.is_packed_qkv));
|
||||
|
||||
// if (parameters.left_padding && parameters.is_prompt) {
|
||||
|
|
@ -843,6 +843,7 @@ Status EfficientAttention(
|
|||
: nullptr;
|
||||
p.stream = stream;
|
||||
p.has_custom_right_padding = true;
|
||||
p.use_smooth_softmax = parameters.use_smooth_softmax;
|
||||
run_memory_efficient_attention(p);
|
||||
|
||||
DUMP_TENSOR("efficient attention output", data.output, batch_size, sequence_length, num_heads, head_size);
|
||||
|
|
|
|||
|
|
@ -515,6 +515,7 @@ Status FusedScaledDotProductAttentionCutlass(
|
|||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = false;
|
||||
p.use_smooth_softmax = false;
|
||||
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
|
||||
: parameters.scale;
|
||||
p.seqlen_k_ptr = nullptr;
|
||||
|
|
|
|||
|
|
@ -693,6 +693,7 @@ Status FusedAttentionCutlass(
|
|||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = false;
|
||||
p.use_smooth_softmax = false;
|
||||
p.scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
|
||||
: parameters.scale;
|
||||
p.seqlen_k_ptr = nullptr;
|
||||
|
|
|
|||
|
|
@ -2219,7 +2219,7 @@ class TestGQA(unittest.TestCase):
|
|||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
use_smooth_softmax=False,
|
||||
use_smooth_softmax=True,
|
||||
)
|
||||
|
||||
@parameterized.expand(gqa_no_past_flash_attention_test_cases())
|
||||
|
|
@ -2263,7 +2263,7 @@ class TestGQA(unittest.TestCase):
|
|||
rotary=rotary,
|
||||
rotary_interleaved=rotary_interleaved,
|
||||
packed=packed,
|
||||
use_smooth_softmax=False,
|
||||
use_smooth_softmax=True,
|
||||
)
|
||||
parity_check_gqa_past_no_buff(
|
||||
config,
|
||||
|
|
|
|||
Loading…
Reference in a new issue