mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-25 02:50:42 +00:00
Extend memory efficient attention coverage in Attention/MHA cuda op (#15064)
### Description <!-- Describe your changes. --> 1. Upgrade cutlass to 3.0, which contains attn_bias support. 2. Extend Attention/MHA to use memory efficient attention when rel_pos_bias with shape [1, num_head, s, s*] and a 1D mask with shape [3 * batch_size + 2] are present. New mask format: MASK_1D_KEY_SEQ_LEN_START, [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0], ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ..., key_start[batch_size - 1], key_end[batch_size - 1]]. E.g., the 2D mask [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]] converts to the 1D mask [3, 5, 0, 6, 12, 0, 6, 12]. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> It potentially benefits tnlrv6 and t5 (encoder). --------- Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net> Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com> Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
parent
7033346605
commit
2ee822d483
19 changed files with 382 additions and 999 deletions
|
|
@ -392,7 +392,7 @@
|
|||
"component": {
|
||||
"type": "git",
|
||||
"git": {
|
||||
"commitHash": "66d9cddc832c1cdc2b30a8755274f7f74640cfe6",
|
||||
"commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933",
|
||||
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
|
||||
},
|
||||
"comments": "cutlass"
|
||||
|
|
|
|||
|
|
@ -34,7 +34,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd
|
|||
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374
|
||||
safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab
|
||||
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
|
||||
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f
|
||||
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd
|
||||
# below are deps introduced by triton client, might remove after 1.14 release
|
||||
openssl;https://github.com/openssl/openssl/archive/refs/tags/openssl-3.0.7.zip;dda8fc81308555410505eb4a9eab3e1da0436a1d
|
||||
rapidjson;https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip;0fe7b4f7b83df4b3d517f4a202f3a383af7a0818
|
||||
|
|
|
|||
1
cmake/external/cutlass.cmake
vendored
1
cmake/external/cutlass.cmake
vendored
|
|
@ -4,6 +4,7 @@ if (onnxruntime_USE_FLASH_ATTENTION)
|
|||
cutlass
|
||||
URL ${DEP_URL_cutlass}
|
||||
URL_HASH SHA1=${DEP_SHA1_cutlass}
|
||||
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch
|
||||
)
|
||||
|
||||
FetchContent_GetProperties(cutlass)
|
||||
|
|
|
|||
92
cmake/patches/cutlass/cutlass.patch
Normal file
92
cmake/patches/cutlass/cutlass.patch
Normal file
|
|
@ -0,0 +1,92 @@
|
|||
diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp
|
||||
index 3790ebd3..cf727d09 100644
|
||||
--- a/include/cute/numeric/complex.hpp
|
||||
+++ b/include/cute/numeric/complex.hpp
|
||||
@@ -41,10 +41,14 @@
|
||||
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
|
||||
// on line 656 of thrust/detail/type_traits.h.
|
||||
// These pragmas suppress the warnings.
|
||||
+#ifdef __GNUC__
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wconversion"
|
||||
+#endif
|
||||
#include <thrust/complex.h>
|
||||
+#ifdef __GNUC__
|
||||
#pragma GCC diagnostic pop
|
||||
+#endif
|
||||
|
||||
#include <cute/config.hpp>
|
||||
|
||||
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
|
||||
index 59aec46a..8f2a913a 100644
|
||||
--- a/include/cutlass/functional.h
|
||||
+++ b/include/cutlass/functional.h
|
||||
@@ -89,7 +89,7 @@ struct multiplies {
|
||||
}
|
||||
};
|
||||
|
||||
-#if defined(__CUDA_ARCH__)
|
||||
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
|
||||
/// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set
|
||||
template<>
|
||||
struct plus<__half2> {
|
||||
@@ -143,12 +143,12 @@ struct multiplies<__half> {
|
||||
|
||||
|
||||
// Maximum with nan propogation
|
||||
-// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN
|
||||
+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN
|
||||
template <typename T>
|
||||
struct maximum_with_nan_propogation {
|
||||
CUTLASS_HOST_DEVICE
|
||||
T operator()(T const &lhs, T const &rhs) const {
|
||||
- return lhs > rhs or std::isnan(lhs) ? lhs : rhs;
|
||||
+ return lhs > rhs or isnan(lhs) ? lhs : rhs;
|
||||
}
|
||||
};
|
||||
|
||||
@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation<float> {
|
||||
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
|
||||
asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs));
|
||||
#else
|
||||
- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs;
|
||||
+ res = lhs > rhs or isnan(lhs) ? lhs : rhs;
|
||||
#endif
|
||||
return res;
|
||||
}
|
||||
@@ -233,7 +233,7 @@ struct negate {
|
||||
}
|
||||
};
|
||||
|
||||
-/// Greater equal
|
||||
+/// Greater equal
|
||||
template <typename T>
|
||||
struct greater_equal {
|
||||
CUTLASS_HOST_DEVICE
|
||||
@@ -242,7 +242,7 @@ struct greater_equal {
|
||||
}
|
||||
};
|
||||
|
||||
-/// Greater
|
||||
+/// Greater
|
||||
template <typename T>
|
||||
struct greater {
|
||||
CUTLASS_HOST_DEVICE
|
||||
@@ -251,7 +251,7 @@ struct greater {
|
||||
}
|
||||
};
|
||||
|
||||
-/// Less equal
|
||||
+/// Less equal
|
||||
template <typename T>
|
||||
struct less_equal {
|
||||
CUTLASS_HOST_DEVICE
|
||||
@@ -260,7 +260,7 @@ struct less_equal {
|
||||
}
|
||||
};
|
||||
|
||||
-/// Less
|
||||
+/// Less
|
||||
template <typename T>
|
||||
struct less {
|
||||
CUTLASS_HOST_DEVICE
|
||||
4
docs/ContribOperators.md
Normal file → Executable file
4
docs/ContribOperators.md
Normal file → Executable file
|
|
@ -155,7 +155,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dt><tt>bias</tt> (optional) : T</dt>
|
||||
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection</dd>
|
||||
<dt><tt>mask_index</tt> (optional) : M</dt>
|
||||
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)</dd>
|
||||
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)</dd>
|
||||
<dt><tt>past</tt> (optional) : T</dt>
|
||||
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size). When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
|
||||
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
|
||||
|
|
@ -2404,7 +2404,7 @@ This version of the operator has been available since version 1 of the 'com.micr
|
|||
<dt><tt>bias</tt> (optional) : T</dt>
|
||||
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
|
||||
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
|
||||
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
|
||||
<dd>Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)</dd>
|
||||
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
|
||||
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
|
||||
<dt><tt>past_key</tt> (optional) : T</dt>
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
|
|||
|
||||
// For mask_index, the following shapes are supported:
|
||||
// NULL, (B, 1), (1, 1)
|
||||
// (B), (2 * B),
|
||||
// (B), (2 * B), (3 * B + 2)
|
||||
// (B, T)
|
||||
// (B, S, T)
|
||||
// (B, 1, M, M)
|
||||
|
|
@ -274,11 +274,13 @@ Status AttentionBase::CheckMask(const Tensor* mask_index,
|
|||
int64_t total_sequence_length) const {
|
||||
const auto& mask_dims = mask_index->Shape().GetDims();
|
||||
if (mask_dims.size() == 1) {
|
||||
if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size) {
|
||||
if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size && mask_dims[0] != 3 * batch_size + 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
|
||||
"Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
|
||||
"Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2");
|
||||
}
|
||||
mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : AttentionMaskType::MASK_1D_END_START);
|
||||
mask_type = (mask_dims[0] == batch_size ?
|
||||
AttentionMaskType::MASK_1D_KEY_SEQ_LEN :
|
||||
mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
|
||||
} else if (mask_dims.size() == 2) {
|
||||
if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) {
|
||||
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
|
||||
|
|
|
|||
|
|
@ -7,13 +7,16 @@ namespace onnxruntime {
|
|||
namespace contrib {
|
||||
|
||||
enum AttentionMaskType {
|
||||
MASK_NONE, // No mask
|
||||
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
|
||||
MASK_1D_END_START, // [2 * batch_size] with end positions and start positions
|
||||
MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
|
||||
MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length]
|
||||
MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length]
|
||||
MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
|
||||
MASK_NONE, // No mask
|
||||
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
|
||||
MASK_1D_END_START, // [2 * batch_size] with end positions and start positions
|
||||
MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
|
||||
// ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
|
||||
// key_start[batch_size - 1], key_end[batch_size - 1]]
|
||||
MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
|
||||
MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length]
|
||||
MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length]
|
||||
MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
|
||||
MASK_UNKNOWN
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ Status CheckInputs(const T* query,
|
|||
float mask_filter_value,
|
||||
float scale,
|
||||
int max_threads_per_block) {
|
||||
// key_padding_mask (K/V) : (B) or (B, L) or None
|
||||
// key_padding_mask (K/V) : (B) or (3 * B + 2) or (B, L) or None
|
||||
// relative_position_bias : (B, 1, S, L)
|
||||
// past_key : (B, N, S*, H)
|
||||
// past_value : (B, N, S*, H)
|
||||
|
|
@ -188,8 +188,12 @@ Status CheckInputs(const T* query,
|
|||
if (key_padding_mask != nullptr) {
|
||||
mask_type = AttentionMaskType::MASK_UNKNOWN;
|
||||
const auto& mask_dims = key_padding_mask->Shape().GetDims();
|
||||
if (mask_dims.size() == 1 && mask_dims[0] == static_cast<int64_t>(batch_size)) {
|
||||
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
if (mask_dims.size() == 1) {
|
||||
if (mask_dims[0] == static_cast<int64_t>(batch_size)) {
|
||||
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
} else if (mask_dims[0] == static_cast<int64_t>(3 * batch_size + 2)) {
|
||||
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
|
||||
}
|
||||
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
|
||||
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
// Check whether we can use fused kernel
|
||||
int sm = device_prop.major * 10 + device_prop.minor;
|
||||
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
|
||||
|
||||
if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT
|
||||
// GPT fused kernels requires left side padding. mask can be:
|
||||
|
|
@ -151,12 +152,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
|
||||
bool use_memory_efficient_attention = fused_runner == nullptr &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
nullptr == mask_index && // TODO: support 1D mask
|
||||
(nullptr == mask_index || is_mask_1d_key_seq_len_start) &&
|
||||
nullptr == past &&
|
||||
nullptr == present &&
|
||||
nullptr == relative_position_bias &&
|
||||
(nullptr == relative_position_bias || is_good_for_rpb) &&
|
||||
(sizeof(T) == 2 || // sequence length threshold is 0 in FP16
|
||||
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
|
|
|
|||
|
|
@ -445,6 +445,14 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
|
|||
DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size);
|
||||
DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
|
||||
|
||||
if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
|
||||
DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, num_heads, sequence_length, kv_sequence_length);
|
||||
}
|
||||
|
||||
if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
|
||||
DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
|
||||
}
|
||||
|
||||
if (data.fused_cross_attention_kernel != nullptr) {
|
||||
assert(qk_head_size == v_head_size);
|
||||
|
||||
|
|
@ -735,11 +743,14 @@ Status QkvToContext(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
|
||||
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
|
||||
: parameters.scale;
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
if (data.use_memory_efficient_attention) {
|
||||
// We only enable fused cross attention when there is no key padding mask.
|
||||
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
|
||||
assert(data.mask_index == nullptr);
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
|
||||
const void* query = q;
|
||||
|
|
@ -754,23 +765,26 @@ Status QkvToContext(
|
|||
MemoryEfficientAttentionParams p;
|
||||
p.sm = device_prop.major * 10 + device_prop.minor;
|
||||
p.is_half = sizeof(T) == 2;
|
||||
p.batch_size = data.mask_index == nullptr ? parameters.batch_size : 2 * parameters.batch_size;
|
||||
p.batch_size = parameters.batch_size;
|
||||
p.num_heads = parameters.num_heads;
|
||||
p.sequence_length = parameters.sequence_length;
|
||||
p.kv_sequence_length = parameters.total_sequence_length;
|
||||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = parameters.is_unidirectional;
|
||||
p.cu_seqlens_q = nullptr;
|
||||
p.cu_seqlens_k = nullptr;
|
||||
p.scale = scale;
|
||||
p.seqlen_k_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
|
||||
p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + batch_size));
|
||||
p.seqstart_k_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + 2 * batch_size + 1));
|
||||
p.query = query;
|
||||
p.key = key;
|
||||
p.value = value;
|
||||
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
|
||||
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
|
||||
p.output = data.output;
|
||||
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
|
||||
p.stream = stream;
|
||||
run_memory_efficient_attention(p);
|
||||
|
||||
DUMP_TENSOR("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
@ -789,9 +803,6 @@ Status QkvToContext(
|
|||
float one = 1.0f;
|
||||
float zero = 0.f;
|
||||
|
||||
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
|
||||
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
|
||||
: parameters.scale;
|
||||
float alpha = use_raw_attention_mask ? one : scale;
|
||||
|
||||
cublasSetStream(cublas, stream);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@
|
|||
#endif
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"
|
||||
#include "41_fused_multi_head_attention/kernel_forward.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
|
|
@ -24,8 +24,10 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
|
|||
p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
|
||||
p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key));
|
||||
p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value));
|
||||
p.cu_seqlens_q_ptr = params.cu_seqlens_q;
|
||||
p.cu_seqlens_k_ptr = params.cu_seqlens_k;
|
||||
p.attn_bias_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.attn_bias));
|
||||
p.seqstart_q_ptr = params.seqstart_q_ptr;
|
||||
p.seqstart_k_ptr = params.seqstart_k_ptr;
|
||||
p.seqlen_k_ptr = params.seqlen_k_ptr;
|
||||
|
||||
p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward
|
||||
p.output_ptr = reinterpret_cast<T*>(params.output);
|
||||
|
|
@ -42,28 +44,32 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
|
|||
p.head_dim = params.qk_head_size;
|
||||
p.head_dim_value = params.v_head_size;
|
||||
|
||||
p.scale = params.scale;
|
||||
|
||||
// When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
|
||||
p.num_queries = params.sequence_length;
|
||||
p.num_keys = params.kv_sequence_length;
|
||||
|
||||
p.causal = params.causal;
|
||||
if (params.causal) {
|
||||
p.custom_mask_type = Attention::CausalFromTopLeft;
|
||||
}
|
||||
|
||||
// Input format is BxSxNxH, output is BxSxNxH
|
||||
p.q_strideH = params.qk_head_size;
|
||||
p.k_strideH = params.qk_head_size;
|
||||
p.v_strideH = params.v_head_size;
|
||||
p.o_strideH = params.v_head_size;
|
||||
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
|
||||
|
||||
p.q_strideM = params.num_heads * params.qk_head_size;
|
||||
p.k_strideM = params.num_heads * params.qk_head_size;
|
||||
p.v_strideM = params.num_heads * params.v_head_size;
|
||||
p.o_strideM = params.num_heads * params.v_head_size;
|
||||
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
|
||||
|
||||
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
|
||||
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
|
||||
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
|
||||
p.o_strideB = static_cast<int64_t>(params.num_heads) * params.v_head_size * params.sequence_length;
|
||||
|
||||
p.causal = params.causal;
|
||||
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
|
||||
}
|
||||
|
||||
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
|
||||
|
|
|
|||
|
|
@ -1,947 +0,0 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "41_fused_multi_head_attention/attention_scaling_coefs_updater.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "41_fused_multi_head_attention/debug_utils.h"
|
||||
#include "41_fused_multi_head_attention/epilogue_pipelined.h"
|
||||
#include "41_fused_multi_head_attention/epilogue_rescale_output.h"
|
||||
#include "41_fused_multi_head_attention/find_default_mma.h"
|
||||
#include "41_fused_multi_head_attention/gemm_kernel_utils.h"
|
||||
#include "41_fused_multi_head_attention/mma_from_smem.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t, typename Arch>
|
||||
constexpr int getWarpsPerSm() {
|
||||
return (
|
||||
Arch::kMinComputeCapability >= 80 &&
|
||||
!cutlass::platform::is_same<scalar_t, float>::value
|
||||
? 16
|
||||
: 12);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <
|
||||
// The datatype of Q/K/V
|
||||
typename scalar_t_,
|
||||
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
|
||||
typename ArchTag,
|
||||
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
|
||||
bool isAligned_,
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
|
||||
>
|
||||
struct AttentionKernel {
|
||||
using scalar_t = scalar_t_;
|
||||
using accum_t = float;
|
||||
using lse_scalar_t = float;
|
||||
using output_t = scalar_t;
|
||||
// Accumulator between 2 iterations
|
||||
// Using `accum_t` improves perf on f16 at the cost of
|
||||
// numerical errors
|
||||
using output_accum_t = accum_t;
|
||||
static constexpr bool kIsAligned = isAligned_;
|
||||
static constexpr int32_t kAlignLSE = 32; // block size of backward
|
||||
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
|
||||
cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
|
||||
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
|
||||
!cutlass::platform::is_same<output_accum_t, output_t>::value;
|
||||
|
||||
static_assert(kQueriesPerBlock % 32 == 0, "");
|
||||
static_assert(kKeysPerBlock % 32 == 0, "");
|
||||
static constexpr int kNumWarpsPerBlock =
|
||||
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
|
||||
static constexpr int kWarpSize = 32;
|
||||
|
||||
// Launch bounds
|
||||
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
|
||||
static constexpr int kMinBlocksPerSm =
|
||||
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
|
||||
|
||||
struct Params {
|
||||
// Input tensors
|
||||
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
|
||||
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
|
||||
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
|
||||
int32_t* cu_seqlens_q_ptr = nullptr;
|
||||
int32_t* cu_seqlens_k_ptr = nullptr;
|
||||
|
||||
// Output tensors
|
||||
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
output_accum_t*
|
||||
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
|
||||
|
||||
// Dimensions/strides
|
||||
int32_t head_dim;
|
||||
int32_t head_dim_value;
|
||||
int32_t num_queries;
|
||||
int32_t num_keys;
|
||||
|
||||
bool causal;
|
||||
|
||||
int32_t q_strideM;
|
||||
int32_t k_strideM;
|
||||
int32_t v_strideM;
|
||||
|
||||
// Everything below is only used in `advance_to_block`
|
||||
// and shouldn't use registers
|
||||
int32_t q_strideH;
|
||||
int32_t k_strideH;
|
||||
int32_t v_strideH;
|
||||
int32_t o_strideH;
|
||||
int64_t q_strideB;
|
||||
int64_t k_strideB;
|
||||
int64_t v_strideB;
|
||||
int64_t o_strideB;
|
||||
int32_t num_batches;
|
||||
int32_t num_heads;
|
||||
|
||||
// https://github.com/NVIDIA/cutlass/issues/771
|
||||
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
|
||||
return head_dim_value * num_heads;
|
||||
}
|
||||
|
||||
// Moves pointers to what we should process
|
||||
// Returns "false" if there is no work to do
|
||||
CUTLASS_DEVICE bool advance_to_block() {
|
||||
auto batch_id = blockIdx.z;
|
||||
auto head_id = blockIdx.y;
|
||||
auto query_start = blockIdx.x * kQueriesPerBlock;
|
||||
|
||||
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
|
||||
|
||||
int64_t q_start, k_start;
|
||||
// Advance to current batch - in case of different sequence lengths
|
||||
if (cu_seqlens_q_ptr != nullptr) {
|
||||
assert(cu_seqlens_k_ptr != nullptr);
|
||||
cu_seqlens_q_ptr += batch_id;
|
||||
cu_seqlens_k_ptr += batch_id;
|
||||
q_start = cu_seqlens_q_ptr[0];
|
||||
k_start = cu_seqlens_k_ptr[0];
|
||||
int64_t q_next_start = cu_seqlens_q_ptr[1];
|
||||
int64_t k_next_start = cu_seqlens_k_ptr[1];
|
||||
num_queries = q_next_start - q_start;
|
||||
num_keys = k_next_start - k_start;
|
||||
|
||||
if (query_start >= num_queries) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
query_ptr += batch_id * q_strideB;
|
||||
key_ptr += batch_id * k_strideB;
|
||||
value_ptr += batch_id * v_strideB;
|
||||
output_ptr += batch_id * o_strideB;
|
||||
if (output_accum_ptr != nullptr) {
|
||||
output_accum_ptr += batch_id * o_strideB;
|
||||
}
|
||||
q_start = 0;
|
||||
k_start = 0;
|
||||
}
|
||||
|
||||
// Advance to the current batch / head / query_start
|
||||
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
|
||||
key_ptr += k_start * k_strideM + head_id * k_strideH;
|
||||
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
||||
output_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||
head_id * o_strideH;
|
||||
|
||||
if (output_accum_ptr != nullptr) {
|
||||
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||
head_id * o_strideH;
|
||||
} else {
|
||||
// Accumulate directly in the destination buffer (eg for f32)
|
||||
output_accum_ptr = (accum_t*)output_ptr;
|
||||
}
|
||||
if (logsumexp_ptr != nullptr) {
|
||||
// lse[batch_id, head_id, query_start]
|
||||
logsumexp_ptr +=
|
||||
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
|
||||
}
|
||||
|
||||
num_queries -= query_start;
|
||||
if (causal) {
|
||||
num_keys = cutlass::fast_min(
|
||||
int32_t(query_start + kQueriesPerBlock), num_keys);
|
||||
}
|
||||
num_batches = 0; // no longer used after
|
||||
|
||||
// Make sure the compiler knows these variables are the same on all
|
||||
// the threads of the warp.
|
||||
query_ptr = warp_uniform(query_ptr);
|
||||
key_ptr = warp_uniform(key_ptr);
|
||||
value_ptr = warp_uniform(value_ptr);
|
||||
output_ptr = warp_uniform(output_ptr);
|
||||
output_accum_ptr = warp_uniform(output_accum_ptr);
|
||||
logsumexp_ptr = warp_uniform(logsumexp_ptr);
|
||||
num_queries = warp_uniform(num_queries);
|
||||
num_keys = warp_uniform(num_keys);
|
||||
head_dim = warp_uniform(head_dim);
|
||||
head_dim_value = warp_uniform(head_dim_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ dim3 getBlocksGrid() const {
|
||||
return dim3(
|
||||
ceil_div(num_queries, (int32_t)kQueriesPerBlock),
|
||||
num_heads,
|
||||
num_batches);
|
||||
}
|
||||
__host__ dim3 getThreadsGrid() const {
|
||||
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
|
||||
}
|
||||
};
|
||||
|
||||
struct MM0 {
|
||||
/*
|
||||
In this first matmul, we compute a block of `Q @ K.T`.
|
||||
While the calculation result is still hot in registers, we update
|
||||
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
|
||||
into a shared-memory ("AccumulatorSharedStorage") that is used later as
|
||||
operand A for the second matmul (see MM1)
|
||||
*/
|
||||
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
|
||||
|
||||
using OpClass = typename GemmType::OpClass;
|
||||
using DefaultConfig =
|
||||
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
OpClass,
|
||||
ArchTag,
|
||||
scalar_t,
|
||||
scalar_t,
|
||||
scalar_t, // ElementC
|
||||
accum_t // ElementAccumulator
|
||||
>;
|
||||
static constexpr int kAlignmentA =
|
||||
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
|
||||
static constexpr int kAlignmentB =
|
||||
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
|
||||
using ThreadblockShape = cutlass::gemm::
|
||||
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
|
||||
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
|
||||
scalar_t, // ElementA,
|
||||
cutlass::layout::RowMajor, // LayoutA,
|
||||
kAlignmentA,
|
||||
scalar_t, // ElementB,
|
||||
cutlass::layout::ColumnMajor, // LayoutB,
|
||||
kAlignmentB,
|
||||
accum_t,
|
||||
cutlass::layout::RowMajor, // LayoutC,
|
||||
OpClass,
|
||||
ArchTag, // ArchTag
|
||||
ThreadblockShape, // ThreadblockShape
|
||||
WarpShape, // WarpShape
|
||||
typename GemmType::InstructionShape, // InstructionShape
|
||||
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
|
||||
// uses too much smem
|
||||
typename GemmType::Operator // Operator
|
||||
>::DefaultMma;
|
||||
using MmaCore = typename DefaultMma::MmaCore;
|
||||
using IteratorA = typename DefaultMma::IteratorA;
|
||||
using IteratorB = typename DefaultMma::IteratorB;
|
||||
using Mma = typename DefaultMma::ThreadblockMma;
|
||||
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
|
||||
typename Mma::Operator::IteratorC,
|
||||
accum_t,
|
||||
kWarpSize>::Updater;
|
||||
static_assert(
|
||||
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
|
||||
MmaCore::WarpCount::kK ==
|
||||
kNumWarpsPerBlock,
|
||||
"");
|
||||
|
||||
// Epilogue to store to shared-memory in a format that we can use later for
|
||||
// the second matmul
|
||||
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
|
||||
typename Mma::Operator::IteratorC,
|
||||
typename Mma::Operator,
|
||||
scalar_t,
|
||||
WarpShape,
|
||||
ThreadblockShape>;
|
||||
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
|
||||
};
|
||||
|
||||
struct MM1 {
|
||||
/**
|
||||
Second matmul: perform `attn @ V` where `attn` is the attention (not
|
||||
normalized) and stored in shared memory
|
||||
*/
|
||||
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
|
||||
|
||||
using OpClass = typename GemmType::OpClass;
|
||||
using DefaultConfig =
|
||||
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
OpClass,
|
||||
ArchTag,
|
||||
scalar_t,
|
||||
scalar_t,
|
||||
output_accum_t, // ElementC
|
||||
accum_t // ElementAccumulator
|
||||
>;
|
||||
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
|
||||
static constexpr int kAlignmentB =
|
||||
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
|
||||
using ThreadblockShape = cutlass::gemm::
|
||||
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
|
||||
using InstructionShape = typename GemmType::InstructionShape;
|
||||
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
|
||||
scalar_t, // ElementA,
|
||||
cutlass::layout::RowMajor, // LayoutA,
|
||||
kAlignmentA,
|
||||
scalar_t, // ElementB,
|
||||
LayoutB, // LayoutB,
|
||||
kAlignmentB,
|
||||
output_accum_t,
|
||||
cutlass::layout::RowMajor, // LayoutC,
|
||||
accum_t,
|
||||
OpClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
typename GemmType::InstructionShape,
|
||||
typename DefaultConfig::EpilogueOutputOp,
|
||||
void, // ThreadblockSwizzle - not used
|
||||
DefaultConfig::kStages,
|
||||
false, // SplitKSerial
|
||||
typename GemmType::Operator>;
|
||||
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage>;
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static_assert(
|
||||
WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
|
||||
"");
|
||||
|
||||
using DefaultEpilogue = typename DefaultGemm::Epilogue;
|
||||
using OutputTileIterator =
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_t>;
|
||||
using OutputTileIteratorAccum =
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_accum_t>;
|
||||
|
||||
struct SharedStorageMM1 {
|
||||
typename Mma::SharedStorage mm;
|
||||
};
|
||||
};
|
||||
|
||||
// Element alignments checked by check_supported(): Q and K follow the A and B
// operand alignments of the first matmul (MM0); V only requires alignment 1.
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
|
||||
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
|
||||
static constexpr int64_t kAlignmentV = 1;
|
||||
|
||||
// Shared storage - depends on kernel params
|
||||
// Per-query-row softmax bookkeeping, one entry per row of the query tile
// (kQueriesPerBlock rows). attention_kernel() initialises `s_prime` to 0 and
// `m_prime` / `mi` to -infinity, and derives the final logsumexp as
// mi + log(s_prime).
struct ScalingCoefs {
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> mi;
|
||||
};
|
||||
|
||||
// Shared-memory layout selected by `SharedStorage` when
// kSingleValueIteration || kKeepOutputInRF.
// The scaling coefficients (base class) sit outside the union; the MM0
// buffers, the post-MM0 buffers (`si` + MM1) and the epilogue buffer all
// alias the same storage through the anonymous union.
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
};
|
||||
|
||||
union {
|
||||
typename MM0::Mma::SharedStorage mm0;
|
||||
SharedStorageAfterMM0 after_mm0;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
// Uniform accessor so attention_kernel() does not need to know which layout
// is in use; here the epilogue buffer is a direct union member.
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
||||
epilogue_shared_storage() {
|
||||
return epilogue;
|
||||
}
|
||||
};
|
||||
|
||||
// Shared-memory layout selected by `SharedStorage` when neither
// kSingleValueIteration nor kKeepOutputInRF holds.
// Unlike SharedStorageEpilogueAtEnd, the epilogue buffer is a member of
// SharedStorageAfterMM0, so it has its own space next to `si` and `mm1`
// instead of aliasing them; the whole group still aliases the MM0 buffers.
struct SharedStorageEpilogueInLoop : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
union {
|
||||
typename MM0::Mma::SharedStorage mm0;
|
||||
SharedStorageAfterMM0 after_mm0;
|
||||
};
|
||||
|
||||
// Uniform accessor so attention_kernel() does not need to know which layout
// is in use; here the epilogue buffer lives inside `after_mm0`.
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
||||
epilogue_shared_storage() {
|
||||
return after_mm0.epilogue;
|
||||
}
|
||||
};
|
||||
|
||||
// Compile-time choice of shared-memory layout: SharedStorageEpilogueAtEnd when
// kSingleValueIteration or kKeepOutputInRF is set, SharedStorageEpilogueInLoop
// otherwise.
using SharedStorage = typename cutlass::platform::conditional<
|
||||
kSingleValueIteration || kKeepOutputInRF,
|
||||
SharedStorageEpilogueAtEnd,
|
||||
SharedStorageEpilogueInLoop>::type;
|
||||
|
||||
// Host-side validation of the kernel parameters against the alignment
// requirements kAlignmentQ / kAlignmentK / kAlignmentV.
// Checks, in order: the three base pointers, the per-row (strideM) strides,
// then the per-head (strideH) strides. Returns true when control reaches the
// end; what happens on a failed check is decided by the CHECK_ALIGNED_PTR /
// XFORMERS_CHECK macros, which are defined elsewhere.
static bool __host__ check_supported(Params const& p) {
  // Base pointers.
  CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
  CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
  CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);

  // Row strides.
  XFORMERS_CHECK(
      p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
  XFORMERS_CHECK(
      p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
  XFORMERS_CHECK(
      p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");

  // Head strides.
  XFORMERS_CHECK(
      p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
  XFORMERS_CHECK(
      p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
  XFORMERS_CHECK(
      p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");

  return true;
}
|
||||
|
||||
// Device-side body of the attention forward pass for one thread block.
//
// `p` is expected to have been positioned already: the __global__ wrappers in
// this file call p.advance_to_block() first and only call this function when
// it returns true.
//
// Outline, per key block of kKeysPerBlock keys:
//   1. MM0: accum = Q_tile @ K_block^T.
//   2. If p.causal and this is the last key block, mask entries above the
//      diagonal with -infinity.
//   3. ScalingCoefsUpdater::update refreshes mi / m_prime / s_prime and
//      rescales `accum` in registers.
//   4. `accum` is stored to shared memory (after_mm0.si).
//   5. MM1: accum_o (+)= attn @ V_block, one pass per output column block.
//   6. Unless kKeepOutputInRF, an epilogue rescales and writes the partial
//      result (to the accumulation buffer, or to the output on the last block).
// After the loop, if kKeepOutputInRF, one epilogue writes the output, and
// finally logsumexp is written when p.logsumexp_ptr is non-null.
static void CUTLASS_DEVICE attention_kernel(Params& p) {
|
||||
// In this block, we will only ever:
|
||||
// - read query[query_start:query_end, :]
|
||||
// - write to output[query_start:query_end, :]
|
||||
|
||||
extern __shared__ char smem_buffer[];
|
||||
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
auto& mi = shared_storage.mi;
|
||||
|
||||
// One thread per query row resets the softmax state: running sum to 0,
// both running maxima to -infinity. The static_assert guarantees the block
// has enough threads for that.
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = accum_t(0);
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
typename MM1::Mma::FragmentC accum_o;
|
||||
accum_o.clear();
|
||||
|
||||
// Factory for a tile iterator over the final output, starting at column
// `col`.
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
|
||||
using OutputTileIterator = typename MM1::OutputTileIterator;
|
||||
return OutputTileIterator(
|
||||
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
|
||||
p.output_ptr,
|
||||
typename OutputTileIterator::TensorCoord{
|
||||
p.num_queries, p.head_dim_value},
|
||||
thread_id(),
|
||||
{0, col});
|
||||
};
|
||||
|
||||
// Same as createOutputIter, but over the accumulation buffer
// (p.output_accum_ptr).
auto createOutputAccumIter = [&](int col) ->
|
||||
typename MM1::OutputTileIteratorAccum {
|
||||
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
|
||||
return OutputTileIteratorAccum(
|
||||
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
|
||||
p.output_accum_ptr,
|
||||
typename OutputTileIteratorAccum::TensorCoord{
|
||||
p.num_queries, p.head_dim_value},
|
||||
thread_id(),
|
||||
{0, col});
|
||||
};
|
||||
|
||||
// Iterate through keys
|
||||
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
|
||||
iter_key_start += kKeysPerBlock) {
|
||||
// Problem sizes for this iteration; the last key block may be partial.
int32_t problem_size_0_m =
|
||||
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
|
||||
int32_t problem_size_0_n = cutlass::fast_min(
|
||||
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
|
||||
int32_t const& problem_size_0_k = p.head_dim;
|
||||
int32_t const& problem_size_1_n = p.head_dim_value;
|
||||
int32_t const& problem_size_1_k = problem_size_0_n;
|
||||
|
||||
// Runs the MM1 prologue for output column block `blockN` of V; used only
// when kPreloadV is set.
auto prologueV = [&](int blockN) {
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
|
||||
p.value_ptr + iter_key_start * p.v_strideM,
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
MM1::Mma::prologue(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
iterator_V,
|
||||
thread_id(),
|
||||
problem_size_1_k);
|
||||
};
|
||||
|
||||
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
|
||||
// updated from end of prev iter
|
||||
//
|
||||
// MATMUL: Q.K_t
|
||||
//
|
||||
// Computes the block-matrix product of:
|
||||
// (a) query[query_start:query_end, :]
|
||||
// with
|
||||
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
|
||||
// and stores that into `shared_storage.si`
|
||||
//
|
||||
|
||||
// Compute threadblock location
|
||||
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{
|
||||
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename MM0::IteratorA iterator_A(
|
||||
typename MM0::IteratorA::Params(
|
||||
typename MM0::MmaCore::LayoutA(p.q_strideM)),
|
||||
p.query_ptr,
|
||||
{problem_size_0_m, problem_size_0_k},
|
||||
thread_id(),
|
||||
tb_offset_A);
|
||||
|
||||
typename MM0::IteratorB iterator_B(
|
||||
typename MM0::IteratorB::Params(
|
||||
typename MM0::MmaCore::LayoutB(p.k_strideM)),
|
||||
p.key_ptr + iter_key_start * p.k_strideM,
|
||||
{problem_size_0_k, problem_size_0_n},
|
||||
thread_id(),
|
||||
tb_offset_B);
|
||||
|
||||
auto my_warp_id = warp_id();
|
||||
auto my_lane_id = lane_id();
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
typename MM0::Mma mma(
|
||||
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
|
||||
|
||||
typename MM0::Mma::FragmentC accum;
|
||||
|
||||
accum.clear();
|
||||
|
||||
auto gemm_k_iterations =
|
||||
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
|
||||
__syncthreads();
|
||||
|
||||
if (kPreloadV) {
|
||||
prologueV(0);
|
||||
}
|
||||
|
||||
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
||||
iteratorC_tile_offset = {
|
||||
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
|
||||
(my_warp_id % MM0::Mma::WarpCount::kM),
|
||||
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
|
||||
(my_warp_id / MM0::Mma::WarpCount::kM)};
|
||||
|
||||
// Mask out last if causal
|
||||
// Only the final key block is masked. For each accumulator row,
// `last_col` is the last column (relative to this key block) that stays
// visible; entries in later columns are set to -infinity.
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
|
||||
auto query_start = blockIdx.x * kQueriesPerBlock;
|
||||
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
int32_t last_col;
|
||||
MM0::ScalingCoefsUpdater::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
last_col = query_start + accum_m - iter_key_start;
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (accum_n > last_col) {
|
||||
accum[idx] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
// Turn the two runtime conditions (first key block? full key block?) into
// compile-time template arguments of the updater.
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
p.num_keys - iter_key_start >= kKeysPerBlock,
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also updates `accum` with accum[i] <-
|
||||
// exp(accum[i] * scale
|
||||
// - mi)
|
||||
MM0::ScalingCoefsUpdater::update<
|
||||
kQueriesPerBlock,
|
||||
kFullColumns,
|
||||
kIsFirst,
|
||||
kKeepOutputInRF>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
p.num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
// The scale is fixed to 1 / sqrt(head_dim).
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
|
||||
}));
|
||||
}));
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = my_warp_id %
|
||||
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
|
||||
auto output_tile_coords = cutlass::MatrixCoord{
|
||||
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
|
||||
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
|
||||
|
||||
MM0::B2bGemm::accumToSmem(
|
||||
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// MATMUL: Attn . V
|
||||
// Run the matmul `attn @ V` for a block of attn and V.
|
||||
// `attn` is read from shared memory (in `shared_storage_si`)
|
||||
// `V` is read from global memory (with iterator_B)
|
||||
//
|
||||
|
||||
// Number of output column blocks; a single pass when kSingleValueIteration.
const int64_t nBlockN = kSingleValueIteration
|
||||
? 1
|
||||
: ceil_div(
|
||||
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
|
||||
for (int blockN = 0; blockN < nBlockN; ++blockN) {
|
||||
int gemm_k_iterations =
|
||||
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add and store it in accum
|
||||
// (in registers)
|
||||
if (!kPreloadV) {
|
||||
__syncthreads(); // we share shmem between mma and epilogue
|
||||
}
|
||||
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
|
||||
p.value_ptr + iter_key_start * p.v_strideM,
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
typename MM1::Mma mma_pv(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.si,
|
||||
(int)thread_id(),
|
||||
(int)warp_id(),
|
||||
(int)lane_id(),
|
||||
(int)problem_size_1_k);
|
||||
mma_pv.set_prologue_done(kPreloadV);
|
||||
// When the output is not kept in registers, each pass starts from a zeroed
// accumulator; the running result is carried by the epilogue below.
if (!kKeepOutputInRF) {
|
||||
accum_o.clear();
|
||||
}
|
||||
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
|
||||
__syncthreads();
|
||||
|
||||
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
|
||||
prologueV(blockN + 1);
|
||||
}
|
||||
|
||||
if (!kKeepOutputInRF) {
|
||||
// kIsFirst / kIsLast select the epilogue variant at compile time. On the
// last key block the destination is the final output (output_t); otherwise
// it is the accumulation buffer (output_accum_t).
DISPATCH_BOOL(
|
||||
iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
(iter_key_start + kKeysPerBlock) >= p.num_keys,
|
||||
kIsLast,
|
||||
([&] {
|
||||
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
||||
using DefaultOp =
|
||||
typename MM1::DefaultConfig::EpilogueOutputOp;
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
output_accum_t,
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator,
|
||||
ElementCompute,
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpiloguePipelined<
|
||||
typename DefaultEpilogue::Shape,
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
typename DefaultEpilogue::
|
||||
AccumulatorFragmentIterator,
|
||||
typename DefaultEpilogue::WarpTileIterator,
|
||||
typename DefaultEpilogue::SharedLoadIterator,
|
||||
EpilogueOutputOp,
|
||||
typename DefaultEpilogue::Padding,
|
||||
DefaultEpilogue::kFragmentsPerIteration,
|
||||
true, // IterationsUnroll
|
||||
typename MM1::OutputTileIteratorAccum // Read
|
||||
// iterator
|
||||
>;
|
||||
|
||||
int col = blockN * MM1::Mma::Shape::kN;
|
||||
auto source_iter = createOutputAccumIter(col);
|
||||
auto dest_iter = call_conditional<
|
||||
kIsLast,
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
epilogue(rescale, dest_iter, accum_o, source_iter);
|
||||
}));
|
||||
}));
|
||||
if (!kSingleValueIteration) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // we modify `m_prime` after
|
||||
}
|
||||
|
||||
// The output was accumulated in registers across all key blocks: normalise
// and write it once, with no source tile to read back.
if (kKeepOutputInRF) {
|
||||
constexpr bool kIsFirst = true;
|
||||
constexpr bool kIsLast = true;
|
||||
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
||||
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
using EpilogueOutputOp =
|
||||
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
|
||||
output_t, // output
|
||||
output_accum_t, // source
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator, // accum
|
||||
output_accum_t, // compute
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::EpiloguePipelined<
|
||||
typename DefaultEpilogue::Shape,
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename MM1::OutputTileIterator, // destination
|
||||
typename DefaultEpilogue::AccumulatorFragmentIterator,
|
||||
typename DefaultEpilogue::WarpTileIterator,
|
||||
typename DefaultEpilogue::SharedLoadIterator,
|
||||
EpilogueOutputOp,
|
||||
typename DefaultEpilogue::Padding,
|
||||
DefaultEpilogue::kFragmentsPerIteration,
|
||||
true, // IterationsUnroll
|
||||
typename MM1::OutputTileIteratorAccum // source tile
|
||||
>;
|
||||
auto dest_iter = createOutputIter(0);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
epilogue(rescale, dest_iter, accum_o);
|
||||
}
|
||||
|
||||
// Calculate logsumexp
|
||||
// To make the backward easier, we pad logsumexp with `inf`
|
||||
// this avoids a few bound checks, and is not more expensive during fwd
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
|
||||
// Valid rows get mi + log(s_prime); rows between num_queries and
// num_queries rounded up to a multiple of kAlignLSE get +infinity.
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
|
||||
if (thread_id() < p.num_queries) {
|
||||
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
|
||||
cutlass::fast_log(accum_t(s_prime[thread_id()]));
|
||||
} else if (thread_id() < lse_dim) {
|
||||
p.logsumexp_ptr[thread_id()] =
|
||||
cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lane index of the calling thread: the x component of the thread index.
static CUTLASS_DEVICE int8_t lane_id() {
  return static_cast<int8_t>(threadIdx.x);
}
|
||||
// Warp index of the calling thread: the y component of the thread index.
static CUTLASS_DEVICE int8_t warp_id() {
  return static_cast<int8_t>(threadIdx.y);
}
|
||||
// Linear index of the calling thread within its block: threads are numbered
// along x first, then along y.
static CUTLASS_DEVICE int16_t thread_id() {
  const auto linear_index = threadIdx.x + threadIdx.y * blockDim.x;
  return static_cast<int16_t>(linear_index);
}
|
||||
};
|
||||
|
||||
template <typename AK>
|
||||
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
||||
attention_kernel_batched_impl(typename AK::Params p) {
|
||||
if (!p.advance_to_block()) {
|
||||
return;
|
||||
}
|
||||
AK::attention_kernel(p);
|
||||
}
|
||||
|
||||
// Declaration only: the primary template has no generic definition here.
// Explicit specializations are generated per kernel type by the
// _ATTENTION_KERNEL_FORWARD_BEGIN / INSTANTIATE_ATTENTION_KERNEL_FORWARD*
// macros.
template <typename AK>
|
||||
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
||||
attention_kernel_batched(typename AK::Params params);
|
||||
|
||||
// Opens an explicit specialization of attention_kernel_batched for the kernel
// type given as the macro arguments, and aliases that type as `Kernel` inside
// the body. The macro is variadic because the kernel type is a template-id
// that contains commas. Must be closed with _ATTENTION_KERNEL_FORWARD_END().
#define _ATTENTION_KERNEL_FORWARD_BEGIN(...)                                  \
  template <>                                                                 \
  __global__ void __launch_bounds__(                                          \
      __VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm)                 \
      attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
    using Kernel = __VA_ARGS__;
// Closes the function body opened by _ATTENTION_KERNEL_FORWARD_BEGIN.
#define _ATTENTION_KERNEL_FORWARD_END() }
|
||||
|
||||
// __CUDA_ARCH__ when it is defined, 0 otherwise, so the value can be used in
// expressions regardless of whether __CUDA_ARCH__ exists in this pass.
#if defined(__CUDA_ARCH__)
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif
|
||||
|
||||
// Emits the working specialization of attention_kernel_batched for
// AttentionKernel<SCALAR_T, cutlass::arch::Sm<ARCH>, IS_ALIGNED,
// QUERIES_PER_BLOCK, KEYS_PER_BLOCK, SINGLE_VALUE_ITER>: it advances the
// parameters to this block and runs the kernel body when there is work.
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD(        \
    ARCH,                                            \
    SCALAR_T,                                        \
    IS_ALIGNED,                                      \
    QUERIES_PER_BLOCK,                               \
    KEYS_PER_BLOCK,                                  \
    SINGLE_VALUE_ITER)                               \
  _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel<   \
                                  SCALAR_T,          \
                                  cutlass::arch::Sm##ARCH, \
                                  IS_ALIGNED,        \
                                  QUERIES_PER_BLOCK, \
                                  KEYS_PER_BLOCK,    \
                                  SINGLE_VALUE_ITER>) \
  if (!p.advance_to_block()) {                       \
    return;                                          \
  }                                                  \
  Kernel::attention_kernel(p);                       \
  _ATTENTION_KERNEL_FORWARD_END();
|
||||
|
||||
// Emits a stub specialization of attention_kernel_batched for the same kernel
// type as INSTANTIATE_ATTENTION_KERNEL_FORWARD. The stub does no work: it
// only prints which architecture the kernel targets (ARCH) and which one this
// pass was compiled for (__CUDA_ARCH_OR_ZERO__).
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(               \
    ARCH,                                                            \
    SCALAR_T,                                                        \
    IS_ALIGNED,                                                      \
    QUERIES_PER_BLOCK,                                               \
    KEYS_PER_BLOCK,                                                  \
    SINGLE_VALUE_ITER)                                               \
  _ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel<                   \
                                  SCALAR_T,                          \
                                  cutlass::arch::Sm##ARCH,           \
                                  IS_ALIGNED,                        \
                                  QUERIES_PER_BLOCK,                 \
                                  KEYS_PER_BLOCK,                    \
                                  SINGLE_VALUE_ITER>)                \
  printf(                                                            \
      "FATAL: this function is for sm%d, but was built for sm%d\n",  \
      int(ARCH),                                                     \
      int(__CUDA_ARCH_OR_ZERO__));                                   \
  _ATTENTION_KERNEL_FORWARD_END();
|
||||
|
||||
// All kernels are disabled by default
|
||||
// Per-architecture instantiation macros. Each one starts out mapped to the
// disabled stub ...
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)

// ... and exactly one of them is re-pointed at the real instantiation,
// according to the __CUDA_ARCH__ range of the current compilation pass:
//   [500, 700) -> SM50, [700, 750) -> SM70, [750, 800) -> SM75, >= 800 -> SM80.
// When __CUDA_ARCH__ is undefined, or below 500, all of them stay disabled.
#ifndef __CUDA_ARCH__
#elif __CUDA_ARCH__ < 500
//#error "Need cuda arch at least 5.0"
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#endif
|
||||
|
||||
#endif
|
||||
|
|
@ -21,15 +21,21 @@ struct MemoryEfficientAttentionParams {
|
|||
int32_t qk_head_size;
|
||||
int32_t v_head_size;
|
||||
bool causal;
|
||||
// The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models.
|
||||
bool is_attn_bias_batched;
|
||||
|
||||
int32_t* cu_seqlens_q;
|
||||
int32_t* cu_seqlens_k;
|
||||
float scale;
|
||||
|
||||
const void* query; // [B, S, N, H]
|
||||
const void* key; // [B, L, N, H], where L is kv_sequence_length
|
||||
const void* value; // [B, L, N, H_v]
|
||||
void* output; // [B, S, N, H_v]
|
||||
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
|
||||
int32_t* seqstart_q_ptr;
|
||||
int32_t* seqstart_k_ptr;
|
||||
int32_t* seqlen_k_ptr;
|
||||
|
||||
const void* query; // [B, S, N, H]
|
||||
const void* key; // [B, L, N, H], where L is kv_sequence_length
|
||||
const void* value; // [B, L, N, H_v]
|
||||
const void* attn_bias; // [N, S, S*] or null
|
||||
void* output; // [B, S, N, H_v]
|
||||
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
|
||||
cudaStream_t stream;
|
||||
|
||||
static bool need_workspace(size_t v_head_size, bool is_float) {
|
||||
|
|
|
|||
|
|
@ -116,6 +116,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
int sm = device_prop.major * 10 + device_prop.minor;
|
||||
|
||||
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
|
||||
|
||||
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
|
||||
nullptr == key_padding_mask &&
|
||||
|
|
@ -168,12 +169,14 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
|
||||
parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32;
|
||||
|
||||
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
|
||||
|
||||
bool use_memory_efficient_attention = fused_runner == nullptr &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
is_long_sequence &&
|
||||
nullptr == key_padding_mask && // TODO: support 1D mask
|
||||
nullptr == relative_position_bias &&
|
||||
(relative_position_bias == nullptr || is_good_for_rpb) &&
|
||||
(nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
#else
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
|
|
|
|||
|
|
@ -272,7 +272,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
"mask_index",
|
||||
"Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), "
|
||||
"(batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), "
|
||||
"or index with shape (batch_size) or (2 * batch_size)",
|
||||
"or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)",
|
||||
"M",
|
||||
OpSchema::Optional)
|
||||
.Input(4,
|
||||
|
|
@ -590,7 +590,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
|
|||
OpSchema::Optional)
|
||||
.Input(4,
|
||||
"key_padding_mask",
|
||||
"Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)",
|
||||
"Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)",
|
||||
"M",
|
||||
OpSchema::Optional)
|
||||
.Input(5,
|
||||
|
|
|
|||
|
|
@ -3090,6 +3090,195 @@ void GetSelfAttentionDataWithPast(AttentionTestData& data) {
|
|||
data.is_static_kv = false;
|
||||
}
|
||||
|
||||
// Populates `data` with a self-attention test case that combines a relative
// position bias with a MASK_1D_KEY_SEQ_LEN_START key padding mask.
//
// Configuration: batch_size 1, 2 heads, sequence_length 8, hidden_size and
// v_hidden_size 8. The three TRT fused/flash kernel types are listed in
// skip_kernel_types. Only fp16 reference output is provided; the fp32
// reference output is left empty.
void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) {
  // Problem dimensions.
  data.hidden_size = 8;
  data.v_hidden_size = 8;
  data.num_heads = 2;
  data.batch_size = 1;
  data.sequence_length = 8;
  data.kv_sequence_length = 0;
  data.mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;

  // 3 * batch_size + 2 entries. Presumably laid out as
  // {key_len, query_start, query_end, key_start, key_end} for the single
  // batch entry -- confirm against the mask format definition.
  data.key_padding_mask_data = {8, 0, 8, 0, 8};

  // Kernel types this test case asks the runner to skip.
  data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFlashAttention,
                            AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
                            AttentionKernelType::AttentionKernel_TrtFusedAttention};

  // Query: 64 values, written as 8 rows of 8 (rows 4-7 repeat rows 0-3).
  data.query_data = {
      -0.029273793f, 0.079709493f, 0.064531095f, 0.24270254f, -0.28326464f, 0.20984903f, -0.10173888f, 0.18373983f,
      0.089472905f, -0.0063416883f, -0.049477674f, 0.36512995f, -0.23620239f, 0.1464397f, 0.068258412f, 0.31627196f,
      0.12436871f, -0.0075563118f, -0.11576633f, 0.41008925f, -0.19456652f, 0.20145792f, 0.11790096f, 0.39789933f,
      0.002485469f, 0.029660821f, -0.043821491f, 0.3892332f, -0.26994205f, 0.14530671f, 0.12950704f, 0.36185294f,
      -0.029273793f, 0.079709493f, 0.064531095f, 0.24270254f, -0.28326464f, 0.20984903f, -0.10173888f, 0.18373983f,
      0.089472905f, -0.0063416883f, -0.049477674f, 0.36512995f, -0.23620239f, 0.1464397f, 0.068258412f, 0.31627196f,
      0.12436871f, -0.0075563118f, -0.11576633f, 0.41008925f, -0.19456652f, 0.20145792f, 0.11790096f, 0.39789933f,
      0.002485469f, 0.029660821f, -0.043821491f, 0.3892332f, -0.26994205f, 0.14530671f, 0.12950704f, 0.36185294f};

  // Key: 64 values, written as 8 rows of 8 (rows 4-7 repeat rows 0-3).
  data.key_data = {
      -0.32538497f, 0.34121913f, -0.18170178f, -0.015152611f, 0.20429322f, 0.25979176f, 0.21269324f, 0.0025638193f,
      -0.24246037f, 0.21112341f, -0.36959589f, -0.16091451f, 0.24183474f, 0.18856162f, 0.094487116f, -0.3053959f,
      -0.35736683f, 0.29276621f, -0.4217523f, -0.20031664f, 0.33148992f, 0.26928401f, 0.19360018f, -0.39494509f,
      -0.28043351f, 0.24279942f, -0.29154932f, -0.13657911f, 0.31932494f, 0.3500579f, 0.027172565f, -0.19327414f,
      -0.32538497f, 0.34121913f, -0.18170178f, -0.015152611f, 0.20429322f, 0.25979176f, 0.21269324f, 0.0025638193f,
      -0.24246037f, 0.21112341f, -0.36959589f, -0.16091451f, 0.24183474f, 0.18856162f, 0.094487116f, -0.3053959f,
      -0.35736683f, 0.29276621f, -0.4217523f, -0.20031664f, 0.33148992f, 0.26928401f, 0.19360018f, -0.39494509f,
      -0.28043351f, 0.24279942f, -0.29154932f, -0.13657911f, 0.31932494f, 0.3500579f, 0.027172565f, -0.19327414f};

  // Value: 64 values, written as 8 rows of 8 (rows 4-7 repeat rows 0-3).
  data.value_data = {
      0.56916672f, -0.2443777f, 0.47111356f, -0.52134115f, 0.010381341f, 0.0696759f, -0.071910433f, -0.35201436f,
      0.70809275f, -0.24479815f, 0.41633749f, -0.34744334f, -0.0044222325f, 0.25929695f, -0.087832771f, -0.281232f,
      0.90039468f, -0.28931504f, 0.56394172f, -0.43948689f, -0.05856207f, 0.33713666f, -0.10320446f, -0.38833332f,
      0.76054728f, -0.29080144f, 0.50414616f, -0.42371163f, -0.047198489f, 0.31959397f, -0.22683662f, -0.30321664f,
      0.56916672f, -0.2443777f, 0.47111356f, -0.52134115f, 0.010381341f, 0.0696759f, -0.071910433f, -0.35201436f,
      0.70809275f, -0.24479815f, 0.41633749f, -0.34744334f, -0.0044222325f, 0.25929695f, -0.087832771f, -0.281232f,
      0.90039468f, -0.28931504f, 0.56394172f, -0.43948689f, -0.05856207f, 0.33713666f, -0.10320446f, -0.38833332f,
      0.76054728f, -0.29080144f, 0.50414616f, -0.42371163f, -0.047198489f, 0.31959397f, -0.22683662f, -0.30321664f};

  // Bias: 24 values (hidden_size + hidden_size + v_hidden_size).
  data.bias_data = {
      -0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, 0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f,
      -0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, -0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f,
      0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, -0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f};

  // Relative position bias: 128 values; the same 32-value pattern appears
  // four times in a row.
  data.rel_pos_bias_data = {
      -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
      8.640254f, -18.862164f, -3.1202927f, -6.3049207f, 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
      -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, -15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
      10.212368f, -0.28622282f, 12.087646f, 6.5218501f, 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,

      -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
      8.640254f, -18.862164f, -3.1202927f, -6.3049207f, 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
      -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, -15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
      10.212368f, -0.28622282f, 12.087646f, 6.5218501f, 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,

      -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
      8.640254f, -18.862164f, -3.1202927f, -6.3049207f, 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
      -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, -15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
      10.212368f, -0.28622282f, 12.087646f, 6.5218501f, 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,

      -10.808288f, -10.887209f, 7.8799553f, -4.6565766f, -1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
      8.640254f, -18.862164f, -3.1202927f, -6.3049207f, 3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
      -2.0828252f, -13.241742f, 2.9868939f, 1.4455698f, -15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
      10.212368f, -0.28622282f, 12.087646f, 6.5218501f, 8.1785011f, 13.985523f, -8.2068987f, 5.4260745f};

  // Expected fp16 output: 64 values, written as 8 rows of 8 (rows 4-7 repeat
  // rows 0-3).
  data.fp16_output_data = {
      1.0419922f, 0.13000488f, 0.10528564f, -0.86230469f, -0.45336914f, 0.39013672f, -0.048858643f, 0.10571289f,
      0.97265625f, 0.17590332f, 0.015625f, -0.79248047f, -0.40917969f, 0.31933594f, 0.082763672f, 0.12976074f,
      1.1455078f, 0.13134766f, 0.15014648f, -0.87451172f, -0.46142578f, 0.40161133f, 0.04309082f, 0.042663574f,
      1.0009766f, 0.17004395f, 0.033752441f, -0.80078125f, -0.41625977f, 0.33349609f, 0.080383301f, 0.11846924f,
      1.0419922f, 0.13000488f, 0.10528564f, -0.86230469f, -0.45336914f, 0.39013672f, -0.048858643f, 0.10571289f,
      0.97265625f, 0.17590332f, 0.015625f, -0.79248047f, -0.40917969f, 0.31933594f, 0.082763672f, 0.12976074f,
      1.1455078f, 0.13134766f, 0.15014648f, -0.87451172f, -0.46142578f, 0.40161133f, 0.04309082f, 0.042663574f,
      1.0009766f, 0.17004395f, 0.033752441f, -0.80078125f, -0.41625977f, 0.33349609f, 0.080383301f, 0.11846924f};

  // No fp32 reference output is supplied for this case.
  data.fp32_output_data = {};

  data.is_static_kv = false;
}
|
||||
|
||||
// Returns true when `kernel_type` is listed in `data.skip_kernel_types`,
// false otherwise (including when the list is empty).
bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type) {
  for (const auto& skipped_type : data.skip_kernel_types) {
    if (skipped_type == kernel_type) {
      return true;
    }
  }
  return false;
}
|
||||
|
|
|
|||
|
|
@ -62,6 +62,8 @@ void GetCrossAttentionData_HeadSize16(AttentionTestData& data);
|
|||
void GetCrossAttentionDataWithPast(AttentionTestData& data);
|
||||
void GetSelfAttentionDataWithPast(AttentionTestData& data);
|
||||
|
||||
void GetAttentionDataCutlassRelPosBias(AttentionTestData& data);
|
||||
|
||||
bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type);
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -55,7 +55,9 @@ static void RunMultiHeadAttentionTest(
|
|||
std::vector<int64_t> key_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, hidden_size};
|
||||
std::vector<int64_t> value_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, v_hidden_size};
|
||||
std::vector<int64_t> bias_dims = {hidden_size + hidden_size + v_hidden_size};
|
||||
std::vector<int64_t> rel_pos_bias_dims = {1, num_heads, sequence_length, sequence_length + kv_sequence_length};
|
||||
// TODO(wy): Introduce past sequence length to avoid using kv_sequence_length.
|
||||
std::vector<int64_t> rel_pos_bias_dims =
|
||||
{1, num_heads, sequence_length, past_key_data.size() ? sequence_length + kv_sequence_length : sequence_length};
|
||||
std::vector<int64_t> past_key_dims = {batch_size, num_heads, kv_sequence_length, hidden_size / num_heads};
|
||||
std::vector<int64_t> past_value_dims = past_key_dims;
|
||||
std::vector<int64_t> output_dims = {batch_size, sequence_length, v_hidden_size};
|
||||
|
|
@ -82,9 +84,10 @@ static void RunMultiHeadAttentionTest(
|
|||
|
||||
std::vector<int64_t> mask_dims_1 = {batch_size};
|
||||
std::vector<int64_t> mask_dims_2 = {batch_size, kv_sequence_length};
|
||||
std::vector<int64_t> mask_dims_3 = {3 * batch_size + 2};
|
||||
std::vector<int64_t>& key_padding_mask_dims = (mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN)
|
||||
? mask_dims_1
|
||||
: mask_dims_2;
|
||||
? mask_dims_1
|
||||
: (mask_type == AttentionMaskType::MASK_2D_KEY_PADDING ? mask_dims_2 : mask_dims_3);
|
||||
|
||||
if (use_float16) {
|
||||
tester.AddInput<MLFloat16>("query", query_dims, ToFloat16(query));
|
||||
|
|
@ -487,5 +490,11 @@ TEST(MultiHeadAttentionTest, SelfAttentionWithPast) {
|
|||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
// Runs the MultiHeadAttention test driver on the data set built by
// GetAttentionDataCutlassRelPosBias, which supplies a relative position bias
// together with a MASK_1D_KEY_SEQ_LEN_START key padding mask.
TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) {
|
||||
AttentionTestData data;
|
||||
GetAttentionDataCutlassRelPosBias(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ steps:
|
|||
packageType: upack
|
||||
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
|
||||
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
|
||||
version: 1.0.32
|
||||
version: 1.0.36
|
||||
downloadPath: $(Build.BinariesDirectory)/deps
|
||||
|
||||
# The private ADO project
|
||||
|
|
@ -22,7 +22,7 @@ steps:
|
|||
packageType: upack
|
||||
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
|
||||
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
|
||||
version: 1.0.32
|
||||
version: 1.0.36
|
||||
downloadPath: $(Build.BinariesDirectory)/deps
|
||||
|
||||
# You can add more ADO accounts here.
|
||||
|
|
|
|||
Loading…
Reference in a new issue