Extend memory efficient attention coverage in Attention/MHA cuda op (#15064)

### Description
<!-- Describe your changes. -->

1. upgrade cutlass to 3.0 that containing attn_bias support.
2. extend Attention/MHA to use memory efficient attention when
rel_pos_bias with [1, num_head, s, s*] and 1d mask with [2 * batch_size
+ 1] are present.

new mask format introduction:
MASK_1D_KEY_SEQ_LEN_START,  
[3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1],
query_start[0], ..., query_start[batch_size - 1], query_end[batch_size -
1], key_start[0], ..., key_start[batch_size - 1], key_end[batch_size -
1]]

e.g
2D mask with [[1, 1, 1, 0, 0, 0], [1, 1, 1, 1, 1, 0]] converts to this
1D mask is [3, 5, 0, 6, 12, 0, 6, 12]


### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->

It potentially benefits tnlrv6 and t5(encoder)

---------

Co-authored-by: Ubuntu <wy@v100-2.0cdb2e52twzevn1i4fi45bylyg.jx.internal.cloudapp.net>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com>
Co-authored-by: Kunal Vaishnavi <kvaishnavi@microsoft.com@orttrainingdev7.d32nl1ml4oruzj4qz3bqlggovf.px.internal.cloudapp.net>
This commit is contained in:
Ye Wang 2023-03-23 11:05:17 -07:00 committed by GitHub
parent 7033346605
commit 2ee822d483
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
19 changed files with 382 additions and 999 deletions

View file

@ -392,7 +392,7 @@
"component": {
"type": "git",
"git": {
"commitHash": "66d9cddc832c1cdc2b30a8755274f7f74640cfe6",
"commitHash": "c4f6b8c6bc94ff69048492fb34df0dfaf1983933",
"repositoryUrl": "https://github.com/NVIDIA/cutlass.git"
},
"comments": "cutlass"

View file

@ -34,7 +34,7 @@ pytorch_cpuinfo;https://github.com/pytorch/cpuinfo/archive/5916273f79a21551890fd
re2;https://github.com/google/re2/archive/refs/tags/2022-06-01.zip;aa77313b76e91b531ee7f3e45f004c6a502a5374
safeint;https://github.com/dcleblanc/SafeInt/archive/ff15c6ada150a5018c5ef2172401cb4529eac9c0.zip;913a4046e5274d329af2806cb53194f617d8c0ab
tensorboard;https://github.com/tensorflow/tensorboard/archive/373eb09e4c5d2b3cc2493f0949dc4be6b6a45e81.zip;67b833913605a4f3f499894ab11528a702c2b381
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v2.11.0.zip;be70c559f07251ba7f33c789dba98872b444c10f
cutlass;https://github.com/NVIDIA/cutlass/archive/refs/tags/v3.0.0.zip;0f95b3c1fc1bd1175c4a90b2c9e39074d1bccefd
# below are deps introduced by triton client, might remove after 1.14 release
openssl;https://github.com/openssl/openssl/archive/refs/tags/openssl-3.0.7.zip;dda8fc81308555410505eb4a9eab3e1da0436a1d
rapidjson;https://github.com/Tencent/rapidjson/archive/refs/tags/v1.1.0.zip;0fe7b4f7b83df4b3d517f4a202f3a383af7a0818

View file

@ -4,6 +4,7 @@ if (onnxruntime_USE_FLASH_ATTENTION)
cutlass
URL ${DEP_URL_cutlass}
URL_HASH SHA1=${DEP_SHA1_cutlass}
PATCH_COMMAND ${Patch_EXECUTABLE} --binary --ignore-whitespace -p1 < ${PROJECT_SOURCE_DIR}/patches/cutlass/cutlass.patch
)
FetchContent_GetProperties(cutlass)

View file

@ -0,0 +1,92 @@
diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp
index 3790ebd3..cf727d09 100644
--- a/include/cute/numeric/complex.hpp
+++ b/include/cute/numeric/complex.hpp
@@ -41,10 +41,14 @@
// With CUDA 11.4, builds show spurious "-Wconversion" warnings
// on line 656 of thrust/detail/type_traits.h.
// These pragmas suppress the warnings.
+#ifdef __GNUC__
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wconversion"
+#endif
#include <thrust/complex.h>
+#ifdef __GNUC__
#pragma GCC diagnostic pop
+#endif
#include <cute/config.hpp>
diff --git a/include/cutlass/functional.h b/include/cutlass/functional.h
index 59aec46a..8f2a913a 100644
--- a/include/cutlass/functional.h
+++ b/include/cutlass/functional.h
@@ -89,7 +89,7 @@ struct multiplies {
}
};
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530)
/// Partial specializations needed when __CUDA_NO_HALF2_OPERATORS__ is set
template<>
struct plus<__half2> {
@@ -143,12 +143,12 @@ struct multiplies<__half> {
// Maximum with nan propogation
-// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN
+// To propgate the NANs, the "max" of a two element that contains NaNs should also return a NaN
template <typename T>
struct maximum_with_nan_propogation {
CUTLASS_HOST_DEVICE
T operator()(T const &lhs, T const &rhs) const {
- return lhs > rhs or std::isnan(lhs) ? lhs : rhs;
+ return lhs > rhs or isnan(lhs) ? lhs : rhs;
}
};
@@ -160,7 +160,7 @@ struct maximum_with_nan_propogation<float> {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
asm volatile("max.NaN.f32 %0, %1, %2;\n" : "=f"(res) : "f"(lhs), "f"(rhs));
#else
- res = lhs > rhs or std::isnan(lhs) ? lhs : rhs;
+ res = lhs > rhs or isnan(lhs) ? lhs : rhs;
#endif
return res;
}
@@ -233,7 +233,7 @@ struct negate {
}
};
-/// Greater equal
+/// Greater equal
template <typename T>
struct greater_equal {
CUTLASS_HOST_DEVICE
@@ -242,7 +242,7 @@ struct greater_equal {
}
};
-/// Greater
+/// Greater
template <typename T>
struct greater {
CUTLASS_HOST_DEVICE
@@ -251,7 +251,7 @@ struct greater {
}
};
-/// Less equal
+/// Less equal
template <typename T>
struct less_equal {
CUTLASS_HOST_DEVICE
@@ -260,7 +260,7 @@ struct less_equal {
}
};
-/// Less
+/// Less
template <typename T>
struct less {
CUTLASS_HOST_DEVICE

4
docs/ContribOperators.md Normal file → Executable file
View file

@ -155,7 +155,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) for input projection</dd>
<dt><tt>mask_index</tt> (optional) : M</dt>
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size)</dd>
<dd>Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), (batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)</dd>
<dt><tt>past</tt> (optional) : T</dt>
<dd>past state for key and value with shape (2, batch_size, num_heads, past_sequence_length, head_size)When past_present_share_buffer is set, its shape is (2, batch_size, num_heads, max_sequence_length, head_size)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
@ -2404,7 +2404,7 @@ This version of the operator has been available since version 1 of the 'com.micr
<dt><tt>bias</tt> (optional) : T</dt>
<dd>Bias tensor with shape (hidden_size + hidden_size + v_hidden_size) from input projection</dd>
<dt><tt>key_padding_mask</tt> (optional) : M</dt>
<dd>Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)</dd>
<dd>Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)</dd>
<dt><tt>relative_position_bias</tt> (optional) : T</dt>
<dd>relative position bias: addition to QxK' with shape (batch_size, num_heads, sequence_length, total_sequence_length) or (1, num_heads, sequence_length, total_sequence_length)</dd>
<dt><tt>past_key</tt> (optional) : T</dt>

View file

@ -41,7 +41,7 @@ Status AttentionBase::CheckInputs(const TensorShape& input_shape,
// For mask_index, the following shapes are supported:
// NULL, (B, 1), (1, 1)
// (B), (2 * B),
// (B), (2 * B), (3 * B + 2)
// (B, T)
// (B, S, T)
// (B, 1, M, M)
@ -274,11 +274,13 @@ Status AttentionBase::CheckMask(const Tensor* mask_index,
int64_t total_sequence_length) const {
const auto& mask_dims = mask_index->Shape().GetDims();
if (mask_dims.size() == 1) {
if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size) {
if (mask_dims[0] != batch_size && mask_dims[0] != 2 * batch_size && mask_dims[0] != 3 * batch_size + 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
"Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size");
"Inputs 'mask_index' with 1D data shall have length of batch_size or 2 * batch_size or 3 * batch_size + 2");
}
mask_type = (mask_dims[0] == batch_size ? AttentionMaskType::MASK_1D_KEY_SEQ_LEN : AttentionMaskType::MASK_1D_END_START);
mask_type = (mask_dims[0] == batch_size ?
AttentionMaskType::MASK_1D_KEY_SEQ_LEN :
mask_dims[0] == 2 * batch_size ? AttentionMaskType::MASK_1D_END_START : AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START);
} else if (mask_dims.size() == 2) {
if (mask_dims[0] == batch_size && mask_dims[1] == total_sequence_length) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;

View file

@ -7,13 +7,16 @@ namespace onnxruntime {
namespace contrib {
enum AttentionMaskType {
MASK_NONE, // No mask
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
MASK_1D_END_START, // [2 * batch_size] with end positions and start positions
MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length]
MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length]
MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
MASK_NONE, // No mask
MASK_1D_KEY_SEQ_LEN, // [batch_size], key sequence length
MASK_1D_END_START, // [2 * batch_size] with end positions and start positions
MASK_1D_KEY_SEQ_LEN_START, // [3 * batch_size + 2] with [key_len[0], ..., key_len[batch_size - 1], query_start[0],
// ..., query_start[batch_size - 1], query_end[batch_size - 1], key_start[0], ...,
// key_start[batch_size - 1], key_end[batch_size - 1]]
MASK_2D_DUMMY, // dummy mask with shape [1, 1] or [batch_size, 1]. It has same effect as no mask.
MASK_2D_KEY_PADDING, // [batch_size, total_sequence_length]
MASK_3D_ATTENTION, // [batch_size, sequence_length, total_sequence_length]
MASK_4D_MEGATRON, // Megatron causal mask with shape [batch_size, 1, max_sequence_length, max_sequence_length]
MASK_UNKNOWN
};

View file

@ -25,7 +25,7 @@ Status CheckInputs(const T* query,
float mask_filter_value,
float scale,
int max_threads_per_block) {
// key_padding_mask (K/V) : (B) or (B, L) or None
// key_padding_mask (K/V) : (B) or (2*B + 1) or (B, L) or None
// relative_position_bias : (B, 1, S, L)
// past_key : (B, N, S*, H)
// past_value : (B, N, S*, H)
@ -188,8 +188,12 @@ Status CheckInputs(const T* query,
if (key_padding_mask != nullptr) {
mask_type = AttentionMaskType::MASK_UNKNOWN;
const auto& mask_dims = key_padding_mask->Shape().GetDims();
if (mask_dims.size() == 1 && mask_dims[0] == static_cast<int64_t>(batch_size)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
if (mask_dims.size() == 1) {
if (mask_dims[0] == static_cast<int64_t>(batch_size)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
} else if (mask_dims[0] == static_cast<int64_t>(3 * batch_size + 2)) {
mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
}
} else if (mask_dims.size() == 2 && mask_dims[0] == static_cast<int64_t>(batch_size) && mask_dims[1] == static_cast<int64_t>(kv_sequence_length)) {
mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
}

View file

@ -102,6 +102,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Check whether we can use fused kernel
int sm = device_prop.major * 10 + device_prop.minor;
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
if (is_unidirectional_ && enable_fused_causal_attention_) { // GPT
// GPT fused kernels requires left side padding. mask can be:
@ -151,12 +152,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
#if USE_FLASH_ATTENTION
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
bool use_memory_efficient_attention = fused_runner == nullptr &&
!disable_memory_efficient_attention_ &&
nullptr == mask_index && // TODO: support 1D mask
(nullptr == mask_index || is_mask_1d_key_seq_len_start) &&
nullptr == past &&
nullptr == present &&
nullptr == relative_position_bias &&
(nullptr == relative_position_bias || is_good_for_rpb) &&
(sizeof(T) == 2 || // sequence length threshold is 0 in FP16
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);

View file

@ -445,6 +445,14 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
DUMP_TENSOR_D("value", data.value, batch_size * kv_sequence_length, num_heads, v_head_size);
DUMP_TENSOR_D("value_bias", data.bias + 2 * num_heads * qk_head_size, num_heads, v_head_size);
if (data.relative_position_bias != nullptr && parameters.broadcast_res_pos_bias) {
DUMP_TENSOR_D("relative_position_bias", data.relative_position_bias, num_heads, sequence_length, kv_sequence_length);
}
if (data.mask_index != nullptr && parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START) {
DUMP_TENSOR_D("mask_index", data.mask_index, 3 * batch_size + 2, 1);
}
if (data.fused_cross_attention_kernel != nullptr) {
assert(qk_head_size == v_head_size);
@ -735,11 +743,14 @@ Status QkvToContext(
return Status::OK();
}
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
#if USE_FLASH_ATTENTION
if (data.use_memory_efficient_attention) {
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
assert(data.mask_index == nullptr);
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
const void* query = q;
@ -754,23 +765,26 @@ Status QkvToContext(
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
p.is_half = sizeof(T) == 2;
p.batch_size = data.mask_index == nullptr ? parameters.batch_size : 2 * parameters.batch_size;
p.batch_size = parameters.batch_size;
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.total_sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.cu_seqlens_q = nullptr;
p.cu_seqlens_k = nullptr;
p.scale = scale;
p.seqlen_k_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index));
p.seqstart_q_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + batch_size));
p.seqstart_k_ptr = nullptr == data.mask_index ? nullptr : const_cast<int32_t*>(reinterpret_cast<const int32_t*>(data.mask_index + 2 * batch_size + 1));
p.query = query;
p.key = key;
p.value = value;
p.attn_bias = nullptr == data.relative_position_bias ? nullptr : data.relative_position_bias;
p.is_attn_bias_batched = !parameters.broadcast_res_pos_bias;
p.output = data.output;
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
p.stream = stream;
run_memory_efficient_attention(p);
DUMP_TENSOR("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
return Status::OK();
}
@ -789,9 +803,6 @@ Status QkvToContext(
float one = 1.0f;
float zero = 0.f;
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
const float scale = parameters.scale == 0.0f ? 1.f / sqrt(static_cast<float>(qk_head_size))
: parameters.scale;
float alpha = use_raw_attention_mask ? one : scale;
cublasSetStream(cublas, stream);

View file

@ -10,7 +10,7 @@
#endif
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"
#include "41_fused_multi_head_attention/kernel_forward.h"
namespace onnxruntime {
namespace contrib {
@ -24,8 +24,10 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key));
p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value));
p.cu_seqlens_q_ptr = params.cu_seqlens_q;
p.cu_seqlens_k_ptr = params.cu_seqlens_k;
p.attn_bias_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.attn_bias));
p.seqstart_q_ptr = params.seqstart_q_ptr;
p.seqstart_k_ptr = params.seqstart_k_ptr;
p.seqlen_k_ptr = params.seqlen_k_ptr;
p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward
p.output_ptr = reinterpret_cast<T*>(params.output);
@ -42,28 +44,32 @@ void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
p.head_dim = params.qk_head_size;
p.head_dim_value = params.v_head_size;
p.scale = params.scale;
// When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
p.num_queries = params.sequence_length;
p.num_keys = params.kv_sequence_length;
p.causal = params.causal;
if (params.causal) {
p.custom_mask_type = Attention::CausalFromTopLeft;
}
// Input format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.o_strideH = params.v_head_size;
p.bias_strideH = nullptr == params.attn_bias ? 0 : p.num_queries * p.num_keys;
p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.o_strideM = params.num_heads * params.v_head_size;
p.bias_strideM = nullptr == params.attn_bias ? 0 : p.num_keys;
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
p.o_strideB = static_cast<int64_t>(params.num_heads) * params.v_head_size * params.sequence_length;
p.causal = params.causal;
p.bias_strideB = params.is_attn_bias_batched ? static_cast<int64_t>(p.bias_strideH) * params.num_heads : 0;
}
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;

View file

@ -1,947 +0,0 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holdvr nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if USE_FLASH_ATTENTION
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "41_fused_multi_head_attention/attention_scaling_coefs_updater.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "41_fused_multi_head_attention/debug_utils.h"
#include "41_fused_multi_head_attention/epilogue_pipelined.h"
#include "41_fused_multi_head_attention/epilogue_rescale_output.h"
#include "41_fused_multi_head_attention/find_default_mma.h"
#include "41_fused_multi_head_attention/gemm_kernel_utils.h"
#include "41_fused_multi_head_attention/mma_from_smem.h"
#include <inttypes.h>
using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
return (
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
? 16
: 12);
}
} // namespace
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
>
struct AttentionKernel {
using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
static constexpr bool kIsAligned = isAligned_;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
static constexpr int kNumWarpsPerBlock =
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
static constexpr int kWarpSize = 32;
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
int32_t* cu_seqlens_q_ptr = nullptr;
int32_t* cu_seqlens_k_ptr = nullptr;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
bool causal;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t o_strideH;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t o_strideB;
int32_t num_batches;
int32_t num_heads;
// https://github.com/NVIDIA/cutlass/issues/771
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value * num_heads;
}
// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
int64_t q_start, k_start;
// Advance to current batch - in case of different sequence lengths
if (cu_seqlens_q_ptr != nullptr) {
assert(cu_seqlens_k_ptr != nullptr);
cu_seqlens_q_ptr += batch_id;
cu_seqlens_k_ptr += batch_id;
q_start = cu_seqlens_q_ptr[0];
k_start = cu_seqlens_k_ptr[0];
int64_t q_next_start = cu_seqlens_q_ptr[1];
int64_t k_next_start = cu_seqlens_k_ptr[1];
num_queries = q_next_start - q_start;
num_keys = k_next_start - k_start;
if (query_start >= num_queries) {
return false;
}
} else {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += batch_id * o_strideB;
if (output_accum_ptr != nullptr) {
output_accum_ptr += batch_id * o_strideB;
}
q_start = 0;
k_start = 0;
}
// Advance to the current batch / head / query_start
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
}
if (logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
logsumexp_ptr +=
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
}
num_queries -= query_start;
if (causal) {
num_keys = cutlass::fast_min(
int32_t(query_start + kQueriesPerBlock), num_keys);
}
num_batches = 0; // no longer used after
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
output_ptr = warp_uniform(output_ptr);
output_accum_ptr = warp_uniform(output_accum_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
return true;
}
__host__ dim3 getBlocksGrid() const {
return dim3(
ceil_div(num_queries, (int32_t)kQueriesPerBlock),
num_heads,
num_batches);
}
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
};
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
into a shared-memory ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1)
*/
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
scalar_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA =
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::ColumnMajor, // LayoutB,
kAlignmentB,
accum_t,
cutlass::layout::RowMajor, // LayoutC,
OpClass,
ArchTag, // ArchTag
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Updater;
static_assert(
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
MmaCore::WarpCount::kK ==
kNumWarpsPerBlock,
"");
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/**
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
output_accum_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using LayoutB = cutlass::layout::RowMajor;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
LayoutB, // LayoutB,
kAlignmentB,
output_accum_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage>;
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(
WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
"");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
static constexpr int64_t kAlignmentV = 1;
// Shared storage - depends on kernel params
struct ScalingCoefs {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return epilogue;
}
};
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return after_mm0.epilogue;
}
};
using SharedStorage = typename cutlass::platform::conditional<
kSingleValueIteration || kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
static bool __host__ check_supported(Params const& p) {
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
return true;
}
static void CUTLASS_DEVICE attention_kernel(Params& p) {
// In this block, we will only ever:
// - read query[query_start:query_end, :]
// - write to output[query_start:query_end, :]
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
p.output_ptr,
typename OutputTileIterator::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) ->
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n = cutlass::fast_min(
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
int32_t const& problem_size_0_k = p.head_dim;
int32_t const& problem_size_1_n = p.head_dim_value;
int32_t const& problem_size_1_k = problem_size_0_n;
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
iterator_V,
thread_id(),
problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
// Compute threadblock location
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
cutlass::MatrixCoord tb_offset_A{
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
cutlass::MatrixCoord tb_offset_B{
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(
typename MM0::MmaCore::LayoutA(p.q_strideM)),
p.query_ptr,
{problem_size_0_m, problem_size_0_k},
thread_id(),
tb_offset_A);
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(p.k_strideM)),
p.key_ptr + iter_key_start * p.k_strideM,
{problem_size_0_k, problem_size_0_n},
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
iteratorC_tile_offset = {
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
(my_warp_id % MM0::Mma::WarpCount::kM),
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// Mask out last if causal
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::ScalingCoefsUpdater::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = query_start + accum_m - iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<
kQueriesPerBlock,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
}));
}));
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords = cutlass::MatrixCoord{
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
__syncthreads();
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage_si`)
// `V` is read from global memory (with iterator_B)
//
const int64_t nBlockN = kSingleValueIteration
? 1
: ceil_div(
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
}
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
if (!kKeepOutputInRF) {
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys,
kIsLast,
([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp =
typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = call_conditional<
kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kSingleValueIteration) {
__syncthreads();
}
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
constexpr bool kIsFirst = true;
constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o);
}
// 7. Calculate logsumexp
// To make the backward easier, we pad logsumexp with `inf`
// this avoids a few bound checks, and is not more expensive during fwd
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}
static CUTLASS_DEVICE int8_t lane_id() {
return threadIdx.x;
}
static CUTLASS_DEVICE int8_t warp_id() {
return threadIdx.y;
}
static CUTLASS_DEVICE int16_t thread_id() {
return threadIdx.x + threadIdx.y * blockDim.x;
}
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl(typename AK::Params p) {
if (!p.advance_to_block()) {
return;
}
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);
#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \
template <> \
__global__ void __launch_bounds__( \
__VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \
attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
using Kernel = __VA_ARGS__;
#define _ATTENTION_KERNEL_FORWARD_END() }
#ifdef __CUDA_ARCH__
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
if (!p.advance_to_block()) { \
return; \
} \
Kernel::attention_kernel(p); \
_ATTENTION_KERNEL_FORWARD_END();
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
printf( \
"FATAL: this function is for sm%d, but was built for sm%d\n", \
int(ARCH), \
int(__CUDA_ARCH_OR_ZERO__)); \
_ATTENTION_KERNEL_FORWARD_END();
// All kernels are disabled by default
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)
// Enable the right one based on __CUDA_ARCH__
#ifndef __CUDA_ARCH__
#elif __CUDA_ARCH__ < 500
//#error "Need cuda arch at least 5.0"
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#endif
#endif

View file

@ -21,15 +21,21 @@ struct MemoryEfficientAttentionParams {
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
// The default shape of attn_bias is [1, N, S, S*]. Sometimes we need to use [B, N, S, S*] in custom models.
bool is_attn_bias_batched;
int32_t* cu_seqlens_q;
int32_t* cu_seqlens_k;
float scale;
const void* query; // [B, S, N, H]
const void* key; // [B, L, N, H], where L is kv_sequence_length
const void* value; // [B, L, N, H_v]
void* output; // [B, S, N, H_v]
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
int32_t* seqstart_q_ptr;
int32_t* seqstart_k_ptr;
int32_t* seqlen_k_ptr;
const void* query; // [B, S, N, H]
const void* key; // [B, L, N, H], where L is kv_sequence_length
const void* value; // [B, L, N, H_v]
const void* attn_bias; // [N, S, S*] or null
void* output; // [B, S, N, H_v]
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
cudaStream_t stream;
static bool need_workspace(size_t v_head_size, bool is_float) {

View file

@ -116,6 +116,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
int sm = device_prop.major * 10 + device_prop.minor;
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
bool is_mask_1d_key_seq_len_start = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
@ -168,12 +169,14 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32 ||
parameters.kv_sequence_length >= attention::kMinSequenceLengthForMemoryEfficientAttentionFp32;
bool is_good_for_rpb = relative_position_bias != nullptr && parameters.sequence_length % (4 * sizeof(T)) == 0;
bool use_memory_efficient_attention = fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
is_long_sequence &&
nullptr == key_padding_mask && // TODO: support 1D mask
nullptr == relative_position_bias &&
(relative_position_bias == nullptr || is_good_for_rpb) &&
(nullptr == key_padding_mask || is_mask_1d_key_seq_len_start) &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;

View file

@ -272,7 +272,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
"mask_index",
"Attention mask with shape (batch_size, 1, max_sequence_length, max_sequence_length), "
"(batch_size, total_sequence_length) or (batch_size, sequence_length, total_sequence_length), "
"or index with shape (batch_size) or (2 * batch_size)",
"or index with shape (batch_size) or (2 * batch_size) or (3 * batch_size + 2)",
"M",
OpSchema::Optional)
.Input(4,
@ -590,7 +590,7 @@ ONNX_MS_OPERATOR_SET_SCHEMA(
OpSchema::Optional)
.Input(4,
"key_padding_mask",
"Key padding mask with shape (batch_size) or (batch_size, kv_sequence_length)",
"Key padding mask with shape (batch_size) or (3 * batch_size + 2) or (batch_size, kv_sequence_length)",
"M",
OpSchema::Optional)
.Input(5,

View file

@ -3090,6 +3090,195 @@ void GetSelfAttentionDataWithPast(AttentionTestData& data) {
data.is_static_kv = false;
}
void GetAttentionDataCutlassRelPosBias(AttentionTestData& data) {
data.hidden_size = 8;
data.v_hidden_size = 8;
data.num_heads = 2;
data.batch_size = 1;
data.sequence_length = 8;
data.kv_sequence_length = 0;
data.mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN_START;
data.key_padding_mask_data = {8, 0, 8, 0, 8};
data.skip_kernel_types = {
AttentionKernelType::AttentionKernel_TrtFlashAttention,
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention};
{
data.query_data = {
-0.029273793f, 0.079709493f, 0.064531095f, 0.24270254f,
-0.28326464f, 0.20984903f, -0.10173888f, 0.18373983f,
0.089472905f, -0.0063416883f, -0.049477674f, 0.36512995f,
-0.23620239f, 0.1464397f, 0.068258412f, 0.31627196f,
0.12436871f, -0.0075563118f, -0.11576633f, 0.41008925f,
-0.19456652f, 0.20145792f, 0.11790096f, 0.39789933f,
0.002485469f, 0.029660821f, -0.043821491f, 0.3892332f,
-0.26994205f, 0.14530671f, 0.12950704f, 0.36185294f,
-0.029273793f, 0.079709493f, 0.064531095f, 0.24270254f,
-0.28326464f, 0.20984903f, -0.10173888f, 0.18373983f,
0.089472905f, -0.0063416883f, -0.049477674f, 0.36512995f,
-0.23620239f, 0.1464397f, 0.068258412f, 0.31627196f,
0.12436871f, -0.0075563118f, -0.11576633f, 0.41008925f,
-0.19456652f, 0.20145792f, 0.11790096f, 0.39789933f,
0.002485469f, 0.029660821f, -0.043821491f, 0.3892332f,
-0.26994205f, 0.14530671f, 0.12950704f, 0.36185294f,
};
}
{
data.key_data = {
-0.32538497f, 0.34121913f, -0.18170178f, -0.015152611f,
0.20429322f, 0.25979176f, 0.21269324f, 0.0025638193f,
-0.24246037f, 0.21112341f, -0.36959589f, -0.16091451f,
0.24183474f, 0.18856162f, 0.094487116f, -0.3053959f,
-0.35736683f, 0.29276621f, -0.4217523f, -0.20031664f,
0.33148992f, 0.26928401f, 0.19360018f, -0.39494509f,
-0.28043351f, 0.24279942f, -0.29154932f, -0.13657911f,
0.31932494f, 0.3500579f, 0.027172565f, -0.19327414f,
-0.32538497f, 0.34121913f, -0.18170178f, -0.015152611f,
0.20429322f, 0.25979176f, 0.21269324f, 0.0025638193f,
-0.24246037f, 0.21112341f, -0.36959589f, -0.16091451f,
0.24183474f, 0.18856162f, 0.094487116f, -0.3053959f,
-0.35736683f, 0.29276621f, -0.4217523f, -0.20031664f,
0.33148992f, 0.26928401f, 0.19360018f, -0.39494509f,
-0.28043351f, 0.24279942f, -0.29154932f, -0.13657911f,
0.31932494f, 0.3500579f, 0.027172565f, -0.19327414f,
};
}
{
data.value_data = {
0.56916672f, -0.2443777f, 0.47111356f, -0.52134115f,
0.010381341f, 0.0696759f, -0.071910433f, -0.35201436f,
0.70809275f, -0.24479815f, 0.41633749f, -0.34744334f,
-0.0044222325f, 0.25929695f, -0.087832771f, -0.281232f,
0.90039468f, -0.28931504f, 0.56394172f, -0.43948689f,
-0.05856207f, 0.33713666f, -0.10320446f, -0.38833332f,
0.76054728f, -0.29080144f, 0.50414616f, -0.42371163f,
-0.047198489f, 0.31959397f, -0.22683662f, -0.30321664f,
0.56916672f, -0.2443777f, 0.47111356f, -0.52134115f,
0.010381341f, 0.0696759f, -0.071910433f, -0.35201436f,
0.70809275f, -0.24479815f, 0.41633749f, -0.34744334f,
-0.0044222325f, 0.25929695f, -0.087832771f, -0.281232f,
0.90039468f, -0.28931504f, 0.56394172f, -0.43948689f,
-0.05856207f, 0.33713666f, -0.10320446f, -0.38833332f,
0.76054728f, -0.29080144f, 0.50414616f, -0.42371163f,
-0.047198489f, 0.31959397f, -0.22683662f, -0.30321664f,
};
}
{
data.bias_data = {
-0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f,
0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f,
-0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f,
-0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f,
0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f,
-0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f
};
}
{
data.rel_pos_bias_data = {
-10.808288f, -10.887209f, 7.8799553f, -4.6565766f,
-1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
8.640254f, -18.862164f, -3.1202927f, -6.3049207f,
3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
-2.0828252f, -13.241742f, 2.9868939f, 1.4455698f,
-15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
10.212368f, -0.28622282f, 12.087646f, 6.5218501f,
8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,
-10.808288f, -10.887209f, 7.8799553f, -4.6565766f,
-1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
8.640254f, -18.862164f, -3.1202927f, -6.3049207f,
3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
-2.0828252f, -13.241742f, 2.9868939f, 1.4455698f,
-15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
10.212368f, -0.28622282f, 12.087646f, 6.5218501f,
8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,
-10.808288f, -10.887209f, 7.8799553f, -4.6565766f,
-1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
8.640254f, -18.862164f, -3.1202927f, -6.3049207f,
3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
-2.0828252f, -13.241742f, 2.9868939f, 1.4455698f,
-15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
10.212368f, -0.28622282f, 12.087646f, 6.5218501f,
8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,
-10.808288f, -10.887209f, 7.8799553f, -4.6565766f,
-1.6700006f, -0.033962168f, 7.4929152f, 10.944146f,
8.640254f, -18.862164f, -3.1202927f, -6.3049207f,
3.4508536f, 11.722519f, 3.3550568f, -5.4888172f,
-2.0828252f, -13.241742f, 2.9868939f, 1.4455698f,
-15.262972f, -10.457437f, -8.4519463f, -4.4281874f,
10.212368f, -0.28622282f, 12.087646f, 6.5218501f,
8.1785011f, 13.985523f, -8.2068987f, 5.4260745f,
};
}
{
data.fp16_output_data = {
1.0419922f, 0.13000488f, 0.10528564f, -0.86230469f,
-0.45336914f, 0.39013672f, -0.048858643f, 0.10571289f,
0.97265625f, 0.17590332f, 0.015625f, -0.79248047f,
-0.40917969f, 0.31933594f, 0.082763672f, 0.12976074f,
1.1455078f, 0.13134766f, 0.15014648f, -0.87451172f,
-0.46142578f, 0.40161133f, 0.04309082f, 0.042663574f,
1.0009766f, 0.17004395f, 0.033752441f, -0.80078125f,
-0.41625977f, 0.33349609f, 0.080383301f, 0.11846924f,
1.0419922f, 0.13000488f, 0.10528564f, -0.86230469f,
-0.45336914f, 0.39013672f, -0.048858643f, 0.10571289f,
0.97265625f, 0.17590332f, 0.015625f, -0.79248047f,
-0.40917969f, 0.31933594f, 0.082763672f, 0.12976074f,
1.1455078f, 0.13134766f, 0.15014648f, -0.87451172f,
-0.46142578f, 0.40161133f, 0.04309082f, 0.042663574f,
1.0009766f, 0.17004395f, 0.033752441f, -0.80078125f,
-0.41625977f, 0.33349609f, 0.080383301f, 0.11846924f,
};
}
{
data.fp32_output_data = {};
}
data.is_static_kv = false;
}
bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type) {
return std::find(data.skip_kernel_types.begin(), data.skip_kernel_types.end(), kernel_type) != data.skip_kernel_types.end();
}

View file

@ -62,6 +62,8 @@ void GetCrossAttentionData_HeadSize16(AttentionTestData& data);
void GetCrossAttentionDataWithPast(AttentionTestData& data);
void GetSelfAttentionDataWithPast(AttentionTestData& data);
void GetAttentionDataCutlassRelPosBias(AttentionTestData& data);
bool SkipAttentionKernel(AttentionTestData& data, AttentionKernelType kernel_type);
} // namespace test
} // namespace onnxruntime

View file

@ -55,7 +55,9 @@ static void RunMultiHeadAttentionTest(
std::vector<int64_t> key_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, hidden_size};
std::vector<int64_t> value_dims = {batch_size, is_static_kv ? kv_sequence_length : sequence_length, v_hidden_size};
std::vector<int64_t> bias_dims = {hidden_size + hidden_size + v_hidden_size};
std::vector<int64_t> rel_pos_bias_dims = {1, num_heads, sequence_length, sequence_length + kv_sequence_length};
// TODO(wy): Introduce past sequence length to avoid using kv_sequence_length.
std::vector<int64_t> rel_pos_bias_dims =
{1, num_heads, sequence_length, past_key_data.size() ? sequence_length + kv_sequence_length : sequence_length};
std::vector<int64_t> past_key_dims = {batch_size, num_heads, kv_sequence_length, hidden_size / num_heads};
std::vector<int64_t> past_value_dims = past_key_dims;
std::vector<int64_t> output_dims = {batch_size, sequence_length, v_hidden_size};
@ -82,9 +84,10 @@ static void RunMultiHeadAttentionTest(
std::vector<int64_t> mask_dims_1 = {batch_size};
std::vector<int64_t> mask_dims_2 = {batch_size, kv_sequence_length};
std::vector<int64_t> mask_dims_3 = {3 * batch_size + 2};
std::vector<int64_t>& key_padding_mask_dims = (mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN)
? mask_dims_1
: mask_dims_2;
? mask_dims_1
: (mask_type == AttentionMaskType::MASK_2D_KEY_PADDING ? mask_dims_2 : mask_dims_3);
if (use_float16) {
tester.AddInput<MLFloat16>("query", query_dims, ToFloat16(query));
@ -487,5 +490,11 @@ TEST(MultiHeadAttentionTest, SelfAttentionWithPast) {
RunMultiHeadAttentionTests(data);
}
TEST(MultiHeadAttentionTest, AttentionCutlassRelPosBias) {
AttentionTestData data;
GetAttentionDataCutlassRelPosBias(data);
RunMultiHeadAttentionTests(data);
}
} // namespace test
} // namespace onnxruntime

View file

@ -11,7 +11,7 @@ steps:
packageType: upack
feed: '/7424c8e4-5c62-490e-95c4-79446f31017c'
definition: '517c4f6f-5437-4392-a70d-4f15ec5be2f0'
version: 1.0.32
version: 1.0.36
downloadPath: $(Build.BinariesDirectory)/deps
# The private ADO project
@ -22,7 +22,7 @@ steps:
packageType: upack
feed: '/4c7631f5-24c0-4307-8822-1aa8f180c325'
definition: 'fd9dd5ad-b73e-4678-890e-edcf680dbc1a'
version: 1.0.32
version: 1.0.36
downloadPath: $(Build.BinariesDirectory)/deps
# You can add more ADO accounts at here.