Add memory efficient attention from CUTLASS (#14343)

### Description
Add memory efficient attention from CUTLASS.

TODO (in next pull request): 
(1) Need performance tests on different GPUs, then add a sequence length
threshold (only activate it for long sequence length).
(2) Merge changes from https://github.com/NVIDIA/cutlass/pull/773 when
it is in cutlass master.
This commit is contained in:
Tianlei Wu 2023-01-20 12:33:01 -08:00 committed by GitHub
parent e64f357ad4
commit 414b012f42
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
29 changed files with 2143 additions and 78 deletions

View file

@ -2753,6 +2753,37 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
_____
nvidia/cutlass
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this
list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice,
this list of conditions and the following disclaimer in the documentation
and/or other materials provided with the distribution.
3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from
this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
_____
Boost

View file

@ -64,6 +64,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
option(onnxruntime_USE_AVX2 "Use AVX2 instructions" OFF)
@ -595,10 +597,31 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu)
set(ORT_PROVIDER_FLAGS)
set(ORT_PROVIDER_CMAKE_FLAGS)
if (onnxruntime_USE_CUDA)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
if (onnxruntime_DISABLE_CONTRIB_OPS)
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()
if (CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
message( STATUS "Turn off flash attention since CUDA compiler version < 11.6")
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()
else()
set(onnxruntime_USE_FLASH_ATTENTION OFF)
endif()
if (onnxruntime_USE_CUDA)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUDA=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUDA=1)
list(APPEND ONNXRUNTIME_PROVIDER_NAMES cuda)
if (onnxruntime_USE_FLASH_ATTENTION)
message( STATUS "Enable flash attention for CUDA EP")
list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
endif()
endif()
if (onnxruntime_USE_VITISAI)
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
@ -1234,8 +1257,6 @@ endif()
if (onnxruntime_USE_CUDA)
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
enable_language(CUDA)
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
set(CMAKE_CUDA_STANDARD 17)
if(onnxruntime_CUDNN_HOME)
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)

12
cmake/external/cutlass.cmake vendored Normal file
View file

@ -0,0 +1,12 @@
if (onnxruntime_USE_FLASH_ATTENTION)
include(FetchContent)
FetchContent_Declare(cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
GIT_TAG 8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6
)
FetchContent_GetProperties(cutlass)
if(NOT cutlass_POPULATED)
FetchContent_Populate(cutlass)
endif()
endif()

View file

@ -484,6 +484,12 @@ if (onnxruntime_USE_CUDA)
if(onnxruntime_CUDNN_HOME)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${onnxruntime_CUDNN_HOME}/include)
endif()
if (onnxruntime_USE_FLASH_ATTENTION)
include(cutlass)
target_include_directories(onnxruntime_providers_cuda PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
endif()
target_include_directories(onnxruntime_providers_cuda PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
set_target_properties(onnxruntime_providers_cuda PROPERTIES LINKER_LANGUAGE CUDA)

View file

@ -23,6 +23,7 @@ set(contrib_ops_excluded_files
"bert/skip_layer_norm.h"
"bert/skip_layer_norm_impl.cu"
"bert/skip_layer_norm_impl.h"
"bert/cutlass_fmha/*"
"bert/tensorrt_fused_multihead_attention/*"
"bert/transformer_common.h"
"bert/transformer_common.cc"

View file

@ -19,7 +19,7 @@ enum AttentionMaskType {
enum AttentionQkvFormat {
Q_K_V_BNSH, // for unfused attention
Q_K_V_BSNH, // input format of query, key and value for MultiHeadAttention
Q_K_V_BSNH, // for memory efficient attention, or format of query, key and value for MultiHeadAttention
QKV_BSN3H, // for TRT fused attention, qkv are packed
Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
@ -30,6 +30,7 @@ enum AttentionKernelType{
AttentionKernel_TrtFusedAttention,
AttentionKernel_TrtFlashAttention,
AttentionKernel_TrtFusedCrossAttention,
AttentionKernel_CutlassMemoryEfficientAttention,
AttentionKernel_Default
};
@ -61,8 +62,11 @@ constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION";
// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled).
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";
} // namespace attention

View file

@ -320,6 +320,110 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co
}
}
template <typename T>
__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) {
// Format 3 for cutlass memory efficient attention
// Input: BxSx(NxH + NxH + NxH_v) (Packed QKV where K and V has different hidden sizes)
// Output: BxNxSxH + BxNxSxH + BxNxSxH_v
// B is batch_size, S is sequence_length, N is num_heads, H is qk_head_size, H_v is v_head_size
int n = threadIdx.y; // head_num_id
int s = blockIdx.x; // sequence_id
int b = blockIdx.y; // batch_id
int m = blockIdx.z; // matrix id (Q=0, K=1, V=2)
const int h = threadIdx.x; // head_element_id
const int qk_head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int head_size = (m == 2 ? v_head_size : qk_head_size);
const int total_head_size = num_heads * (qk_head_size + qk_head_size + v_head_size);
int in_offset;
int out_offset;
int bias_offset;
in_offset = b * (total_head_size * sequence_length) + // B
s * (total_head_size) + // S
m * (qk_head_size * num_heads) + // M
n * head_size + // N
h; // H
out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M
b * (num_heads * head_size * sequence_length) + // B
s * (num_heads * head_size) + // S
n * (head_size) + // N
h; // H
bias_offset = m * (num_heads * qk_head_size) + // M
n * (head_size) + // N
h; // H
if (h < head_size) {
output[out_offset] = input[in_offset] + biases[bias_offset];
}
}
template <typename T>
__global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) {
// Format 3 for cutlass memory efficient attention
// Input: BxSxMxNxH
// Output: MxBxSxNxH
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int head_size = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * head_size;
const int NHS = NH * sequence_length;
int in_offset = n * head_size + (m + s * M) * NH + b * NHS * M;
const int out_offset = n * head_size + s * NH + b * NHS + m * NHS * batch_size;
const int h = threadIdx.x;
if (h < head_size) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
}
}
template <typename T>
__global__ void AddBiasTransposeCutlassLarge(const int head_size, const T* input, const T* biases, T* output,
const int M) {
// Format 3 for cutlass memory efficient attention
// Input: BxSxMxNxH (Packed QKV)
// Output: MxBxSxNxH
int n = threadIdx.y;
int s = blockIdx.x;
int b = blockIdx.y;
int m = blockIdx.z; // matrix id
const int stride = blockDim.x;
const int num_heads = blockDim.y;
const int sequence_length = gridDim.x;
const int batch_size = gridDim.y;
const int H = head_size;
const int NH = num_heads * H;
const int NHS = NH * sequence_length;
int in_offset = n * H + (m + s * M) * NH + b * NHS * M;
const int out_offset = n * H + s * NH + b * NHS + m * NHS * batch_size;
int h = threadIdx.x;
while (h < H) {
output[out_offset + h] = input[in_offset + h] + biases[m * NH + n * H + h];
h += stride;
}
}
template <typename T>
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
// Format 0 for Separated Q, K, V (N*H <= 1024)
@ -395,6 +499,13 @@ void InvokeAddBiasTranspose(
ORT_ENFORCE(total_matrix_count == 3);
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
}
} else if (format == 3) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
AddBiasTransposeCutlass<T><<<grid, block, 0, stream>>>(total_matrix_count, input, biases, output);
} else {
ORT_ENFORCE(total_matrix_count == 3);
AddBiasTransposeCutlass<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
}
} else { // format == 0
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
}
@ -410,6 +521,13 @@ void InvokeAddBiasTranspose(
// It is rare for hidden size > 4096 (for half precision) and qk_head_size != v_head_size.
ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
}
} else if (format == 3) {
if (v_head_size == -1 || qk_head_size == v_head_size) {
AddBiasTransposeCutlassLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output,
total_matrix_count);
} else {
ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
}
} else { // format 0
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
}

View file

@ -21,6 +21,9 @@ namespace cuda {
// format 2:
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
// format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3)
// input: (batch_size, sequence_length, num_matrices, num_heads, head_size)
// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
template <typename T>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,

View file

@ -7,6 +7,7 @@
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/attention.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@ -41,8 +42,14 @@ Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
enable_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
#if USE_FLASH_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
}
template <typename T>
@ -102,12 +109,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_flash_attention_, true);
enable_trt_flash_attention_, true);
if (use_causal_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
enable_flash_attention_, parameters.scale));
enable_trt_flash_attention_, parameters.scale));
}
// Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
@ -122,13 +129,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_flash_attention_, false);
enable_trt_flash_attention_, false);
if (use_fused_runner) {
// Here we assume that num_heads, head_size and is_unidirectional does not change for an Attention node.
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
enable_flash_attention_, parameters.scale));
enable_trt_flash_attention_, parameters.scale));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
@ -139,6 +146,18 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
}
}
#if USE_FLASH_ATTENTION
bool use_memory_efficient_attention = fused_runner == nullptr &&
!disable_memory_efficient_attention_ &&
nullptr == mask_index && // TODO: support 1D mask
nullptr == past &&
nullptr == present &&
nullptr == extra_add_qk &&
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;
#endif
cublasHandle_t cublas = GetCublasHandle(context);
typedef typename ToCudaType<T>::MappedType CudaT;
@ -169,7 +188,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length,
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner);
fused_runner,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
typedef typename ToCudaType<T>::MappedType CudaT;
@ -188,6 +208,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;
return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
}

View file

@ -22,7 +22,8 @@ class Attention final : public CudaKernel, public AttentionBase {
protected:
bool disable_fused_runner_;
bool enable_flash_attention_;
bool enable_trt_flash_attention_;
bool disable_memory_efficient_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
};

View file

@ -40,6 +40,7 @@ limitations under the License.
#include "contrib_ops/cpu/bert/attention_base.h"
#include "contrib_ops/cuda/bert/bert_padding.h"
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
using namespace onnxruntime::cuda;
using namespace cub;
@ -101,11 +102,25 @@ size_t GetAttentionWorkspaceSize(
size_t sequence_length,
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner) {
void* fused_runner,
bool use_memory_efficient_attention) {
// Note that q, k and v might need alignment for fused attention kernels.
const size_t qkv_bytes = element_size * batch_size * num_heads *
((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size);
#if USE_FLASH_ATTENTION
if (use_memory_efficient_attention) {
size_t fmha_buffer_bytes = 0;
if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) {
fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float);
}
return qkv_bytes + fmha_buffer_bytes;
}
#else
ORT_UNUSED_PARAMETER(use_memory_efficient_attention);
#endif
if (fused_runner != nullptr) {
size_t sequence_offset_bytes = GetSequenceOffsetSize(static_cast<int>(batch_size), true);
return qkv_bytes + sequence_offset_bytes;
@ -259,12 +274,15 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
const int v_head_size = parameters.v_head_size;
const bool past_present_share_buffer = parameters.past_present_share_buffer;
void* fused_runner = data.fused_runner;
bool use_memory_efficient_attention = data.use_memory_efficient_attention;
T* qkv = data.workspace;
bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional);
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
// Default format for memory efficient attention.
// When there is past state, the format shal be BxNxSxH, so we disable memory efficient attention when there is past.
DUMP_ATTENTION_INIT();
if (nullptr != data.gemm_buffer) {
if (data.bias == nullptr) {
@ -277,13 +295,16 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
} else {
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
// For memory efficient attention, transpose to 3xBxSxNxH (format 3)
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
// For fused causal kernel, use format 1 since we need have K and V in BNSH format to update present state,
// we also update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : 1);
// For fused causal kernel, use format 1 since we need have K and V to update present state,
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
const int format = (use_fused_kernel ? 2 : (use_memory_efficient_attention ? 3 : 1));
qkv_format = use_fused_kernel
? AttentionQkvFormat::QKV_BSN3H
: (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH);
: (use_memory_efficient_attention
? AttentionQkvFormat::Q_K_V_BSNH
: (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH));
// For fused causal, we will update gemm_buffer with bias directly.
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
@ -318,7 +339,21 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
} else if (use_fused_kernel) {
}
#if USE_FLASH_ATTENTION
else if (use_memory_efficient_attention) {
LaunchAddBias(stream, max_threads_per_block,
batch_size, sequence_length, kv_sequence_length,
num_heads, qk_head_size, v_head_size,
data.bias, data.query, data.key, data.value, q, k, v);
DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
}
#endif
else if (use_fused_kernel) {
assert(qk_head_size == v_head_size);
// Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
@ -382,9 +417,10 @@ Status QkvToContext(
const bool past_present_share_buffer = parameters.past_present_share_buffer;
const float mask_filter_value = parameters.mask_filter_value;
void* fused_runner = data.fused_runner;
bool use_memory_efficient_attention = data.use_memory_efficient_attention;
// At most one fused kernel is enabled.
assert(int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
assert(int(use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
const int batches = batch_size * num_heads;
const int size_per_batch_q = sequence_length * qk_head_size;
@ -433,6 +469,7 @@ Status QkvToContext(
assert(data.fused_cross_attention_kernel == nullptr);
assert(!use_fused_kernel);
assert(data.gemm_buffer != nullptr);
assert(!use_memory_efficient_attention);
if (data.present != data.past) {
// For easy testing. Production should better avoid this path.
@ -526,6 +563,38 @@ Status QkvToContext(
return Status::OK();
}
#if USE_FLASH_ATTENTION
if (use_memory_efficient_attention) {
// We only enable fused cross attention when there is no key padding mask.
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
assert(data.mask_index == nullptr);
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
MemoryEfficientAttentionParams p;
p.sm = device_prop.major * 10 + device_prop.minor;
p.is_half = sizeof(T) == 2;
p.batch_size = data.mask_index == nullptr ? parameters.batch_size : 2 * parameters.batch_size;
p.num_heads = parameters.num_heads;
p.sequence_length = parameters.sequence_length;
p.kv_sequence_length = parameters.total_sequence_length;
p.qk_head_size = parameters.head_size;
p.v_head_size = parameters.v_head_size;
p.causal = parameters.is_unidirectional;
p.cu_seqlens_q = nullptr;
p.cu_seqlens_k = nullptr;
p.query = q;
p.key = k;
p.value = v;
p.output = data.output;
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
p.stream = stream;
run_memory_efficient_attention(p);
DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
return Status::OK();
}
#endif
// The following are unfused attention.
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
const int* mask_index = data.mask_index;

View file

@ -27,7 +27,8 @@ size_t GetAttentionWorkspaceSize(
size_t sequence_length,
size_t kv_sequence_length,
size_t total_sequence_length,
void* fused_runner);
void* fused_runner,
bool use_memory_efficient_attention = false);
template <typename T>
struct AttentionData {
@ -48,6 +49,8 @@ struct AttentionData {
void* fused_runner;
const void* fused_cross_attention_kernel;
bool use_memory_efficient_attention;
};
template <typename T>

View file

@ -0,0 +1,116 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if USE_FLASH_ATTENTION
#if defined(__GNUC__)
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
#endif
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
typename Attention::Params p;
{ // set parameters
p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key));
p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value));
p.cu_seqlens_q_ptr = params.cu_seqlens_q;
p.cu_seqlens_k_ptr = params.cu_seqlens_k;
p.logsumexp_ptr = nullptr; // [num_heads, num_queries] for backward or nullptr for forward
p.output_ptr = reinterpret_cast<T*>(params.output);
if (Attention::kNeedsOutputAccumulatorBuffer) {
using Acc = typename Attention::accum_t;
// workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float)
ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided");
p.output_accum_ptr = reinterpret_cast<Acc*>(params.workspace);
} else {
p.output_accum_ptr = nullptr;
}
p.num_heads = params.num_heads;
p.num_batches = params.batch_size;
p.head_dim = params.qk_head_size;
p.head_dim_value = params.v_head_size;
// When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
p.num_queries = params.sequence_length;
p.num_keys = params.kv_sequence_length;
p.causal = params.causal;
// Input format is BxSxNxH, output is BxSxNxH
p.q_strideH = params.qk_head_size;
p.k_strideH = params.qk_head_size;
p.v_strideH = params.v_head_size;
p.o_strideH = params.v_head_size;
p.q_strideM = params.num_heads * params.qk_head_size;
p.k_strideM = params.num_heads * params.qk_head_size;
p.v_strideM = params.num_heads * params.v_head_size;
p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
p.o_strideB = static_cast<int64_t>(params.num_heads) * params.v_head_size * params.sequence_length;
p.causal = params.causal;
}
constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
int smem_bytes = sizeof(typename Attention::SharedStorage);
if (smem_bytes > 0xc000) {
ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!");
static bool once = [&]() {
cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes);
return true;
}();
}
ORT_ENFORCE(Attention::check_supported(p));
kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, params.stream>>>(p);
}
template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, bool single_value_iteration>
void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
params.v_head_size % AlignedAK::kAlignmentV == 0;
DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
}));
}
template <typename T, typename ArchTag>
void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
if (params.v_head_size <= 64) {
DispatchIsAligned<T, ArchTag, 64, 64, true>(params);
} else if (params.v_head_size <= 128) {
DispatchIsAligned<T, ArchTag, 32, 128, true>(params);
} else {
DispatchIsAligned<T, ArchTag, 32, 128, false>(params);
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#if defined(__GNUC__)
#pragma GCC diagnostic pop
#endif
#endif // USE_FLASH_ATTENTION

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if USE_FLASH_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) {
if (params.is_half) {
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm50>(params);
} else {
DispatchBlockSize<float, cutlass::arch::Sm50>(params);
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif // USE_FLASH_ATTENTION

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if USE_FLASH_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) {
if (params.is_half) {
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm70>(params);
} else {
DispatchBlockSize<float, cutlass::arch::Sm70>(params);
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif // USE_FLASH_ATTENTION

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if USE_FLASH_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) {
if (params.is_half) {
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm75>(params);
} else {
DispatchBlockSize<float, cutlass::arch::Sm75>(params);
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif // USE_FLASH_ATTENTION

View file

@ -0,0 +1,24 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if USE_FLASH_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) {
if (params.is_half) {
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm80>(params);
} else {
DispatchBlockSize<float, cutlass::arch::Sm80>(params);
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif // USE_FLASH_ATTENTION

View file

@ -0,0 +1,947 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holdvr nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#if USE_FLASH_ATTENTION
#include <cmath>
#include <vector>
#include "cutlass/bfloat16.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/layout/vector.h"
#include "cutlass/numeric_types.h"
#include "41_fused_multi_head_attention/attention_scaling_coefs_updater.h"
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm.h"
#include "cutlass/gemm/threadblock/default_mma.h"
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/matrix_shape.h"
#include "cutlass/platform/platform.h"
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
#include "41_fused_multi_head_attention/debug_utils.h"
#include "41_fused_multi_head_attention/epilogue_pipelined.h"
#include "41_fused_multi_head_attention/epilogue_rescale_output.h"
#include "41_fused_multi_head_attention/find_default_mma.h"
#include "41_fused_multi_head_attention/gemm_kernel_utils.h"
#include "41_fused_multi_head_attention/mma_from_smem.h"
#include <inttypes.h>
using namespace gemm_kernel_utils;
namespace {
template <typename scalar_t, typename Arch>
constexpr int getWarpsPerSm() {
return (
Arch::kMinComputeCapability >= 80 &&
!cutlass::platform::is_same<scalar_t, float>::value
? 16
: 12);
}
} // namespace
template <
// The datatype of Q/K/V
typename scalar_t_,
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
typename ArchTag,
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
bool isAligned_,
int kQueriesPerBlock,
int kKeysPerBlock,
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
>
struct AttentionKernel {
using scalar_t = scalar_t_;
using accum_t = float;
using lse_scalar_t = float;
using output_t = scalar_t;
// Accumulator between 2 iterations
// Using `accum_t` improves perf on f16 at the cost of
// numerical errors
using output_accum_t = accum_t;
static constexpr bool kIsAligned = isAligned_;
static constexpr int32_t kAlignLSE = 32; // block size of backward
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
cutlass::sizeof_bits<scalar_t>::value == 16;
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
!cutlass::platform::is_same<output_accum_t, output_t>::value;
static_assert(kQueriesPerBlock % 32 == 0, "");
static_assert(kKeysPerBlock % 32 == 0, "");
static constexpr int kNumWarpsPerBlock =
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
static constexpr int kWarpSize = 32;
// Launch bounds
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
static constexpr int kMinBlocksPerSm =
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
struct Params {
// Input tensors
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
int32_t* cu_seqlens_q_ptr = nullptr;
int32_t* cu_seqlens_k_ptr = nullptr;
// Output tensors
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
output_accum_t*
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
// Dimensions/strides
int32_t head_dim;
int32_t head_dim_value;
int32_t num_queries;
int32_t num_keys;
bool causal;
int32_t q_strideM;
int32_t k_strideM;
int32_t v_strideM;
// Everything below is only used in `advance_to_block`
// and shouldn't use registers
int32_t q_strideH;
int32_t k_strideH;
int32_t v_strideH;
int32_t o_strideH;
int64_t q_strideB;
int64_t k_strideB;
int64_t v_strideB;
int64_t o_strideB;
int32_t num_batches;
int32_t num_heads;
// https://github.com/NVIDIA/cutlass/issues/771
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
return head_dim_value * num_heads;
}
// Moves pointers to what we should process
// Returns "false" if there is no work to do
CUTLASS_DEVICE bool advance_to_block() {
auto batch_id = blockIdx.z;
auto head_id = blockIdx.y;
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
int64_t q_start, k_start;
// Advance to current batch - in case of different sequence lengths
if (cu_seqlens_q_ptr != nullptr) {
assert(cu_seqlens_k_ptr != nullptr);
cu_seqlens_q_ptr += batch_id;
cu_seqlens_k_ptr += batch_id;
q_start = cu_seqlens_q_ptr[0];
k_start = cu_seqlens_k_ptr[0];
int64_t q_next_start = cu_seqlens_q_ptr[1];
int64_t k_next_start = cu_seqlens_k_ptr[1];
num_queries = q_next_start - q_start;
num_keys = k_next_start - k_start;
if (query_start >= num_queries) {
return false;
}
} else {
query_ptr += batch_id * q_strideB;
key_ptr += batch_id * k_strideB;
value_ptr += batch_id * v_strideB;
output_ptr += batch_id * o_strideB;
if (output_accum_ptr != nullptr) {
output_accum_ptr += batch_id * o_strideB;
}
q_start = 0;
k_start = 0;
}
// Advance to the current batch / head / query_start
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
key_ptr += k_start * k_strideM + head_id * k_strideH;
value_ptr += k_start * v_strideM + head_id * v_strideH;
output_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
if (output_accum_ptr != nullptr) {
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
head_id * o_strideH;
} else {
// Accumulate directly in the destination buffer (eg for f32)
output_accum_ptr = (accum_t*)output_ptr;
}
if (logsumexp_ptr != nullptr) {
// lse[batch_id, head_id, query_start]
logsumexp_ptr +=
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
}
num_queries -= query_start;
if (causal) {
num_keys = cutlass::fast_min(
int32_t(query_start + kQueriesPerBlock), num_keys);
}
num_batches = 0; // no longer used after
// Make sure the compiler knows these variables are the same on all
// the threads of the warp.
query_ptr = warp_uniform(query_ptr);
key_ptr = warp_uniform(key_ptr);
value_ptr = warp_uniform(value_ptr);
output_ptr = warp_uniform(output_ptr);
output_accum_ptr = warp_uniform(output_accum_ptr);
logsumexp_ptr = warp_uniform(logsumexp_ptr);
num_queries = warp_uniform(num_queries);
num_keys = warp_uniform(num_keys);
head_dim = warp_uniform(head_dim);
head_dim_value = warp_uniform(head_dim_value);
return true;
}
__host__ dim3 getBlocksGrid() const {
return dim3(
ceil_div(num_queries, (int32_t)kQueriesPerBlock),
num_heads,
num_batches);
}
__host__ dim3 getThreadsGrid() const {
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
}
};
struct MM0 {
/*
In this first matmul, we compute a block of `Q @ K.T`.
While the calculation result is still hot in registers, we update
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
into a shared-memory ("AccumulatorSharedStorage") that is used later as
operand A for the second matmul (see MM1)
*/
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
scalar_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA =
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
cutlass::layout::ColumnMajor, // LayoutB,
kAlignmentB,
accum_t,
cutlass::layout::RowMajor, // LayoutC,
OpClass,
ArchTag, // ArchTag
ThreadblockShape, // ThreadblockShape
WarpShape, // WarpShape
typename GemmType::InstructionShape, // InstructionShape
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
// uses too much smem
typename GemmType::Operator // Operator
>::DefaultMma;
using MmaCore = typename DefaultMma::MmaCore;
using IteratorA = typename DefaultMma::IteratorA;
using IteratorB = typename DefaultMma::IteratorB;
using Mma = typename DefaultMma::ThreadblockMma;
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
typename Mma::Operator::IteratorC,
accum_t,
kWarpSize>::Updater;
static_assert(
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
MmaCore::WarpCount::kK ==
kNumWarpsPerBlock,
"");
// Epilogue to store to shared-memory in a format that we can use later for
// the second matmul
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
typename Mma::Operator::IteratorC,
typename Mma::Operator,
scalar_t,
WarpShape,
ThreadblockShape>;
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
};
struct MM1 {
/**
Second matmul: perform `attn @ V` where `attn` is the attention (not
normalized) and stored in shared memory
*/
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
using OpClass = typename GemmType::OpClass;
using DefaultConfig =
typename cutlass::gemm::device::DefaultGemmConfiguration<
OpClass,
ArchTag,
scalar_t,
scalar_t,
output_accum_t, // ElementC
accum_t // ElementAccumulator
>;
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
static constexpr int kAlignmentB =
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
using ThreadblockShape = cutlass::gemm::
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
using InstructionShape = typename GemmType::InstructionShape;
using LayoutB = cutlass::layout::RowMajor;
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
scalar_t, // ElementA,
cutlass::layout::RowMajor, // LayoutA,
kAlignmentA,
scalar_t, // ElementB,
LayoutB, // LayoutB,
kAlignmentB,
output_accum_t,
cutlass::layout::RowMajor, // LayoutC,
accum_t,
OpClass,
ArchTag,
ThreadblockShape,
WarpShape,
typename GemmType::InstructionShape,
typename DefaultConfig::EpilogueOutputOp,
void, // ThreadblockSwizzle - not used
DefaultConfig::kStages,
false, // SplitKSerial
typename GemmType::Operator>;
using DefaultMmaFromSmem =
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
typename DefaultGemm::Mma,
typename MM0::AccumulatorSharedStorage>;
using Mma = typename DefaultMmaFromSmem::Mma;
using IteratorB = typename Mma::IteratorB;
using WarpCount = typename Mma::WarpCount;
static_assert(
WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
"");
using DefaultEpilogue = typename DefaultGemm::Epilogue;
using OutputTileIterator =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_t>;
using OutputTileIteratorAccum =
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
output_accum_t>;
struct SharedStorageMM1 {
typename Mma::SharedStorage mm;
};
};
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
static constexpr int64_t kAlignmentV = 1;
// Shared storage - depends on kernel params
struct ScalingCoefs {
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
cutlass::Array<accum_t, kQueriesPerBlock> mi;
};
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return epilogue;
}
};
struct SharedStorageEpilogueInLoop : ScalingCoefs {
struct SharedStorageAfterMM0 {
// Everything here might be overwritten during MM0
typename MM0::AccumulatorSharedStorage si;
typename MM1::SharedStorageMM1 mm1;
typename MM1::DefaultEpilogue::SharedStorage epilogue;
};
union {
typename MM0::Mma::SharedStorage mm0;
SharedStorageAfterMM0 after_mm0;
};
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
epilogue_shared_storage() {
return after_mm0.epilogue;
}
};
using SharedStorage = typename cutlass::platform::conditional<
kSingleValueIteration || kKeepOutputInRF,
SharedStorageEpilogueAtEnd,
SharedStorageEpilogueInLoop>::type;
static bool __host__ check_supported(Params const& p) {
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
XFORMERS_CHECK(
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
XFORMERS_CHECK(
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
XFORMERS_CHECK(
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
XFORMERS_CHECK(
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
return true;
}
static void CUTLASS_DEVICE attention_kernel(Params& p) {
// In this block, we will only ever:
// - read query[query_start:query_end, :]
// - write to output[query_start:query_end, :]
extern __shared__ char smem_buffer[];
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
auto& m_prime = shared_storage.m_prime;
auto& s_prime = shared_storage.s_prime;
auto& mi = shared_storage.mi;
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = accum_t(0);
m_prime[thread_id()] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
}
typename MM1::Mma::FragmentC accum_o;
accum_o.clear();
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
using OutputTileIterator = typename MM1::OutputTileIterator;
return OutputTileIterator(
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
p.output_ptr,
typename OutputTileIterator::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
auto createOutputAccumIter = [&](int col) ->
typename MM1::OutputTileIteratorAccum {
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
return OutputTileIteratorAccum(
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
p.output_accum_ptr,
typename OutputTileIteratorAccum::TensorCoord{
p.num_queries, p.head_dim_value},
thread_id(),
{0, col});
};
// Iterate through keys
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
iter_key_start += kKeysPerBlock) {
int32_t problem_size_0_m =
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
int32_t problem_size_0_n = cutlass::fast_min(
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
int32_t const& problem_size_0_k = p.head_dim;
int32_t const& problem_size_1_n = p.head_dim_value;
int32_t const& problem_size_1_k = problem_size_0_n;
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
MM1::Mma::prologue(
shared_storage.after_mm0.mm1.mm,
iterator_V,
thread_id(),
problem_size_1_k);
};
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
// updated from end of prev iter
//
// MATMUL: Q.K_t
//
// Computes the block-matrix product of:
// (a) query[query_start:query_end, :]
// with
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
// and stores that into `shared_storage.si`
//
// Compute threadblock location
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
cutlass::MatrixCoord tb_offset_A{
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
cutlass::MatrixCoord tb_offset_B{
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
// Construct iterators to A and B operands
typename MM0::IteratorA iterator_A(
typename MM0::IteratorA::Params(
typename MM0::MmaCore::LayoutA(p.q_strideM)),
p.query_ptr,
{problem_size_0_m, problem_size_0_k},
thread_id(),
tb_offset_A);
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(p.k_strideM)),
p.key_ptr + iter_key_start * p.k_strideM,
{problem_size_0_k, problem_size_0_n},
thread_id(),
tb_offset_B);
auto my_warp_id = warp_id();
auto my_lane_id = lane_id();
// Construct thread-scoped matrix multiply
typename MM0::Mma mma(
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
typename MM0::Mma::FragmentC accum;
accum.clear();
auto gemm_k_iterations =
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
__syncthreads();
if (kPreloadV) {
prologueV(0);
}
typename MM0::Mma::Operator::IteratorC::TensorCoord
iteratorC_tile_offset = {
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
(my_warp_id % MM0::Mma::WarpCount::kM),
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
(my_warp_id / MM0::Mma::WarpCount::kM)};
// Mask out last if causal
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
auto query_start = blockIdx.x * kQueriesPerBlock;
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
lane_id(), warp_id(), iteratorC_tile_offset);
int32_t last_col;
MM0::ScalingCoefsUpdater::iterateRows(
lane_offset,
[&](int accum_m) {
last_col = query_start + accum_m - iter_key_start;
},
[&](int accum_m, int accum_n, int idx) {
if (accum_n > last_col) {
accum[idx] =
-cutlass::platform::numeric_limits<accum_t>::infinity();
}
},
[&](int accum_m) {});
}
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
p.num_keys - iter_key_start >= kKeysPerBlock,
kFullColumns,
([&] {
// Update `mi` from accum stored in registers
// Also updates `accum` with accum[i] <-
// exp(accum[i] * scale
// - mi)
MM0::ScalingCoefsUpdater::update<
kQueriesPerBlock,
kFullColumns,
kIsFirst,
kKeepOutputInRF>(
accum_o,
accum,
mi,
m_prime,
s_prime,
lane_id(),
thread_id(),
warp_id(),
p.num_keys - iter_key_start,
iteratorC_tile_offset,
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
}));
}));
// Output results to shared-memory
int warp_idx_mn_0 = my_warp_id %
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
auto output_tile_coords = cutlass::MatrixCoord{
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
MM0::B2bGemm::accumToSmem(
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
__syncthreads();
//
// MATMUL: Attn . V
// Run the matmul `attn @ V` for a block of attn and V.
// `attn` is read from shared memory (in `shared_storage_si`)
// `V` is read from global memory (with iterator_B)
//
const int64_t nBlockN = kSingleValueIteration
? 1
: ceil_div(
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
for (int blockN = 0; blockN < nBlockN; ++blockN) {
int gemm_k_iterations =
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
// Compute threadblock-scoped matrix multiply-add and store it in accum
// (in registers)
if (!kPreloadV) {
__syncthreads(); // we share shmem between mma and epilogue
}
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
p.value_ptr + iter_key_start * p.v_strideM,
{problem_size_1_k, problem_size_1_n},
thread_id(),
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
typename MM1::Mma mma_pv(
shared_storage.after_mm0.mm1.mm,
shared_storage.after_mm0.si,
(int)thread_id(),
(int)warp_id(),
(int)lane_id(),
(int)problem_size_1_k);
mma_pv.set_prologue_done(kPreloadV);
if (!kKeepOutputInRF) {
accum_o.clear();
}
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
__syncthreads();
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
prologueV(blockN + 1);
}
if (!kKeepOutputInRF) {
DISPATCH_BOOL(
iter_key_start == 0, kIsFirst, ([&] {
DISPATCH_BOOL(
(iter_key_start + kKeysPerBlock) >= p.num_keys,
kIsLast,
([&] {
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp =
typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp = typename cutlass::epilogue::
thread::MemoryEfficientAttentionNormalize<
typename cutlass::platform::conditional<
kIsLast,
output_t,
output_accum_t>::type,
output_accum_t,
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator,
ElementCompute,
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue = typename cutlass::epilogue::threadblock::
EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename cutlass::platform::conditional<
kIsLast,
typename MM1::OutputTileIterator,
typename MM1::OutputTileIteratorAccum>::type,
typename DefaultEpilogue::
AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // Read
// iterator
>;
int col = blockN * MM1::Mma::Shape::kN;
auto source_iter = createOutputAccumIter(col);
auto dest_iter = call_conditional<
kIsLast,
decltype(createOutputIter),
decltype(createOutputAccumIter)>::
apply(createOutputIter, createOutputAccumIter, col);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o, source_iter);
}));
}));
if (!kSingleValueIteration) {
__syncthreads();
}
}
}
__syncthreads(); // we modify `m_prime` after
}
if (kKeepOutputInRF) {
constexpr bool kIsFirst = true;
constexpr bool kIsLast = true;
using DefaultEpilogue = typename MM1::DefaultEpilogue;
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
using ElementCompute = typename DefaultOp::ElementCompute;
using EpilogueOutputOp =
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
output_t, // output
output_accum_t, // source
DefaultOp::kCount,
typename DefaultOp::ElementAccumulator, // accum
output_accum_t, // compute
kIsFirst,
kIsLast,
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
using Epilogue =
typename cutlass::epilogue::threadblock::EpiloguePipelined<
typename DefaultEpilogue::Shape,
typename MM1::Mma::Operator,
DefaultEpilogue::kPartitionsK,
typename MM1::OutputTileIterator, // destination
typename DefaultEpilogue::AccumulatorFragmentIterator,
typename DefaultEpilogue::WarpTileIterator,
typename DefaultEpilogue::SharedLoadIterator,
EpilogueOutputOp,
typename DefaultEpilogue::Padding,
DefaultEpilogue::kFragmentsPerIteration,
true, // IterationsUnroll
typename MM1::OutputTileIteratorAccum // source tile
>;
auto dest_iter = createOutputIter(0);
EpilogueOutputOp rescale(s_prime, m_prime);
Epilogue epilogue(
shared_storage.epilogue_shared_storage(),
thread_id(),
warp_id(),
lane_id());
epilogue(rescale, dest_iter, accum_o);
}
// 7. Calculate logsumexp
// To make the backward easier, we pad logsumexp with `inf`
// this avoids a few bound checks, and is not more expensive during fwd
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
if (thread_id() < p.num_queries) {
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
cutlass::fast_log(accum_t(s_prime[thread_id()]));
} else if (thread_id() < lse_dim) {
p.logsumexp_ptr[thread_id()] =
cutlass::platform::numeric_limits<accum_t>::infinity();
}
}
}
static CUTLASS_DEVICE int8_t lane_id() {
return threadIdx.x;
}
static CUTLASS_DEVICE int8_t warp_id() {
return threadIdx.y;
}
static CUTLASS_DEVICE int16_t thread_id() {
return threadIdx.x + threadIdx.y * blockDim.x;
}
};
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched_impl(typename AK::Params p) {
if (!p.advance_to_block()) {
return;
}
AK::attention_kernel(p);
}
template <typename AK>
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
attention_kernel_batched(typename AK::Params params);
#define _ATTENTION_KERNEL_FORWARD_BEGIN(...) \
template <> \
__global__ void __launch_bounds__( \
__VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm) \
attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) { \
using Kernel = __VA_ARGS__;
#define _ATTENTION_KERNEL_FORWARD_END() }
#ifdef __CUDA_ARCH__
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
if (!p.advance_to_block()) { \
return; \
} \
Kernel::attention_kernel(p); \
_ATTENTION_KERNEL_FORWARD_END();
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED( \
ARCH, \
SCALAR_T, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER) \
_ATTENTION_KERNEL_FORWARD_BEGIN(AttentionKernel< \
SCALAR_T, \
cutlass::arch::Sm##ARCH, \
IS_ALIGNED, \
QUERIES_PER_BLOCK, \
KEYS_PER_BLOCK, \
SINGLE_VALUE_ITER>) \
printf( \
"FATAL: this function is for sm%d, but was built for sm%d\n", \
int(ARCH), \
int(__CUDA_ARCH_OR_ZERO__)); \
_ATTENTION_KERNEL_FORWARD_END();
// All kernels are disabled by default
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)
// Enable the right one based on __CUDA_ARCH__
#ifndef __CUDA_ARCH__
#elif __CUDA_ARCH__ < 500
//#error "Need cuda arch at least 5.0"
#elif __CUDA_ARCH__ < 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#elif __CUDA_ARCH__ < 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ < 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#endif
#endif

View file

@ -0,0 +1,30 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#if USE_FLASH_ATTENTION
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) {
const int32_t& sm = params.sm;
if (sm >= 80) {
run_memory_efficient_attention_sm80(params);
} else if (sm >= 75) {
run_memory_efficient_attention_sm75(params);
} else if (sm >= 70) {
run_memory_efficient_attention_sm70(params);
} else if (sm >= 50) {
run_memory_efficient_attention_sm50(params);
} else {
assert(false); // shall not reach here.
}
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif // USE_FLASH_ATTENTION

View file

@ -0,0 +1,55 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#if USE_FLASH_ATTENTION
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention_common.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
struct MemoryEfficientAttentionParams {
int32_t sm;
bool is_half;
int32_t batch_size;
int32_t num_heads;
int32_t sequence_length;
int32_t kv_sequence_length;
int32_t qk_head_size;
int32_t v_head_size;
bool causal;
int32_t* cu_seqlens_q;
int32_t* cu_seqlens_k;
const void* query; // [B, S, N, H]
const void* key; // [B, L, N, H], where L is kv_sequence_length
const void* value; // [B, L, N, H_v]
void* output; // [B, S, N, H_v]
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
cudaStream_t stream;
static bool need_workspace(size_t v_head_size, bool is_float) {
return (v_head_size > 128 && !is_float);
}
};
void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
inline bool has_memory_efficient_attention(int32_t sm, bool is_half) {
return sm >= (is_half ? 53 : 50);
}
void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params);
void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params);
void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params);
void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
#endif // USE_FLASH_ATTENTION

View file

@ -6,6 +6,7 @@
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "contrib_ops/cuda/bert/multihead_attention.h"
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
using namespace onnxruntime::cuda;
using namespace ::onnxruntime::common;
@ -42,8 +43,14 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
disable_fused_runner_ = sizeof(T) != 2 ||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
enable_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
enable_trt_flash_attention_ = sizeof(T) == 2 &&
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
#if USE_FLASH_ATTENTION
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
#else
disable_memory_efficient_attention_ = true;
#endif
disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
}
@ -85,7 +92,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
nullptr == key_padding_mask &&
parameters.hidden_size == parameters.v_hidden_size &&
@ -104,17 +110,18 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
bool use_fused_runner = !disable_fused_runner_ &&
fused_cross_attention_kernel == nullptr &&
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
parameters.hidden_size == parameters.v_hidden_size &&
parameters.sequence_length == parameters.kv_sequence_length &&
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
enable_flash_attention_, false);
enable_trt_flash_attention_, false);
if (use_fused_runner) {
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
if (nullptr == fused_fp16_runner_.get()) {
constexpr bool is_unidirectional = false;
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(
num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale));
num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
@ -124,6 +131,16 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
}
}
#if USE_FLASH_ATTENTION
bool use_memory_efficient_attention = fused_runner == nullptr &&
fused_cross_attention_kernel == nullptr &&
!disable_memory_efficient_attention_ &&
nullptr == key_padding_mask && // TODO: support 1D mask
has_memory_efficient_attention(sm, sizeof(T) == 2);
#else
constexpr bool use_memory_efficient_attention = false;
#endif
constexpr size_t element_size = sizeof(T);
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
parameters.batch_size,
@ -133,7 +150,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
parameters.sequence_length,
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner);
fused_runner,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
typedef typename ToCudaType<T>::MappedType CudaT;
@ -152,6 +170,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
data.present = nullptr;
data.fused_runner = reinterpret_cast<void*>(fused_runner);
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
data.use_memory_efficient_attention = use_memory_efficient_attention;
cublasHandle_t cublas = GetCublasHandle(context);
return QkvToContext<CudaT>(

View file

@ -24,8 +24,9 @@ class MultiHeadAttention final : public CudaKernel {
int num_heads_; // number of attention heads
float mask_filter_value_;
bool disable_fused_runner_;
bool enable_flash_attention_;
bool enable_trt_flash_attention_;
bool disable_fused_cross_attention_;
bool disable_memory_efficient_attention_;
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
};

View file

@ -174,6 +174,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
Tensor* present = context->Output(1, present_shape);
void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up
bool use_memory_efficient_attention = false;
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
batch_size,
parameters.num_heads,
@ -182,7 +183,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
sequence_length,
parameters.kv_sequence_length,
parameters.total_sequence_length,
fused_runner);
fused_runner,
use_memory_efficient_attention);
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
@ -202,6 +204,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
data.fused_runner = fused_runner;
data.fused_cross_attention_kernel = nullptr;
data.use_memory_efficient_attention = use_memory_efficient_attention;
return QkvToContext<CudaT>(GetDeviceProp(), cublas, Stream(context), parameters, data);
}

View file

@ -544,7 +544,8 @@ def get_ort_environment_variables():
env_names = [
"ORT_DISABLE_FUSED_ATTENTION",
"ORT_DISABLE_FUSED_CROSS_ATTENTION",
"ORT_DISABLE_FLASH_ATTENTION",
"ORT_DISABLE_TRT_FLASH_ATTENTION",
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
"ORT_TRANSFORMER_OPTIONS",
"ORT_CUDA_GEMM_OPTIONS",
]

View file

@ -929,7 +929,7 @@ TEST(AttentionTest, Causal_EmptyPastState) {
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
@ -940,7 +940,7 @@ TEST(AttentionTest, Causal_EmptyPastState) {
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
@ -951,7 +951,7 @@ TEST(AttentionTest, Causal_EmptyPastState) {
{
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,

View file

@ -1902,6 +1902,392 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data) {
1.2402344f, 2.2792969f, 0.33398438f, 2.2519531f, 0.67041016f, -0.55957031f, 0.20666504f, 1.3583984f, -1.9716797f, 2.6074219f, 2.2832031f, -2.0546875f, -2.4335938f, 0.53515625f, -0.15100098f, 1.9599609f, -0.51513672f, 0.31030273f, -0.49169922f, 1.4677734f, 2.234375f, 0.87451172f, 0.54736328f, -1.8681641f, -4.2265625f, -0.97509766f, -7.296875f, -1.3486328f, 1.3769531f, -1.8427734f, 3.1601562f, -2.4238281f, -0.82421875f, -2.7324219f, -0.52734375f, 2.2089844f, 0.66796875f, -0.42236328f, -3.03125f, -0.047302246f};
}
}
void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d) {
data.hidden_size = 64;
data.v_hidden_size = 64;
data.num_heads = 2;
data.batch_size = 2;
data.sequence_length = 2;
data.kv_sequence_length = 3;
if (is_mask_1d) {
data.mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
data.key_padding_mask_data = {1, 2};
} else {
data.mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
data.key_padding_mask_data = {1, 0, 0,
1, 1, 0};
}
data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention,
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention};
{
data.query_data = {
0.66417468f, -2.82039404f, 1.66603971f, 4.84341049f, -1.63285708f, 3.61133432f,
-1.07151258f, -0.41698062f, -1.38491797f, -3.79137778f, 1.34514475f, -2.97253704f,
2.12579250f, -0.02954102f, 2.30081463f, 0.21410012f, 1.84038579f, 0.46486610f,
-4.49463224f, 0.69027799f, 1.01090157f, 0.04715919f, -1.60957003f, 0.10730582f,
-5.77672052f, 0.37593889f, 2.04825425f, -1.00890708f, -3.88195300f, -2.69047785f,
1.15699422f, -1.13536406f, -0.42816854f, 3.12039518f, 3.21898699f, -0.51998949f,
-4.72336435f, -0.78055519f, -0.72722042f, 3.17147565f, -1.31066322f, -3.09425855f,
-3.54743338f, -0.07284085f, 1.10525322f, 1.82087338f, -2.03681397f, -4.27978802f,
0.26408362f, 0.58637118f, -2.07128787f, -3.48036027f, -0.03049034f, -1.99293542f,
-0.67289937f, 1.17342246f, -4.84998703f, -2.43558168f, 1.16422236f, 0.26511097f,
-1.98199308f, -1.86423326f, 1.61366916f, -0.35201707f,
-1.43554640f, -1.37493825f, 2.32563400f, -1.31762123f, -1.46716797f, 0.18536982f,
0.85819042f, -3.11506653f, -1.25773919f, 1.30177450f, 0.58314162f, -1.72039497f,
-4.55264997f, 0.02031951f, -2.83490133f, 2.69835496f, -0.07102034f, -2.05412841f,
-1.26518285f, 3.30601740f, -4.54173231f, 0.80148667f, -1.36685658f, -2.26921320f,
-0.94192690f, -2.77439642f, 0.43918809f, 1.44727242f, 1.53386545f, 2.67014980f,
3.30231142f, -1.60745978f, -1.26032567f, 1.27801156f, 0.31288767f, 3.04471421f,
-1.09798527f, -2.76303077f, -1.68329728f, -4.78179169f, -0.86371553f, -1.57159030f,
-1.06435764f, 3.61700702f, 0.71459293f, -0.25048330f, 1.31865597f, -1.83117080f,
-1.10344386f, 2.94894052f, -1.33930528f, 1.94855583f, -1.94283628f, -0.64020038f,
2.24100995f, 1.06447530f, -0.03809617f, 3.47241497f, -2.55227089f, 0.12048072f,
2.88777542f, -1.73300576f, 3.10077643f, -0.37158102f,
-0.76705527f, -1.27237630f, 3.55744553f, 0.84103155f, -2.37726879f, 0.20218298f,
-3.41723180f, 1.26160014f, 1.45791709f, -1.47226799f, -2.36974764f, 1.49916458f,
1.68845606f, -1.33727181f, -2.18113089f, -0.64312577f, -1.06002951f, -0.98938328f,
1.95285964f, 3.08321524f, 1.28492856f, 2.28907299f, 1.14324796f, -0.11273877f,
-5.96574259f, -1.80337310f, 3.86340094f, -2.42390299f, -1.29642844f, 0.14276078f,
-1.23373103f, -0.51519167f, -1.04046988f, 0.60624832f, -0.93274558f, 2.46919179f,
-0.58201206f, -3.43382907f, 1.63227773f, 1.92112875f, -0.17216301f, 2.79771209f,
2.67759442f, 1.73900354f, -0.00557053f, -0.63086307f, -0.37115061f, 0.82691956f,
1.81370568f, -0.48766607f, -1.05545425f, -2.79009533f, -7.64374399f, -2.65407372f,
-0.84429693f, 1.35677493f, -1.25277543f, 2.26928639f, -1.77852845f, 2.31752825f,
-1.28869593f, -2.97340727f, -2.87103486f, 2.17401385f,
0.20970306f, -1.19119942f, 1.11263359f, 0.21227169f, -5.30872822f, -2.15851903f,
0.63067430f, -0.49583313f, 3.05784941f, 0.09588236f, 0.76925617f, 1.18900692f,
0.35771871f, -0.97235727f, 1.14949071f, -1.25595427f, 2.37192512f, -0.32522821f,
1.42988098f, -0.38017935f, 2.49831486f, -0.30629224f, 1.08675146f, -1.02598715f,
-0.17971759f, -0.55683851f, 1.04535389f, 1.54741859f, -0.05179391f, 0.73957652f,
0.54304504f, 1.95280874f, -1.19504929f, -1.19528544f, 1.33258319f, 0.13532166f,
-1.87509251f, 0.99605685f, 2.69439840f, 1.03421521f, 1.79539657f, 0.15001571f,
0.55184591f, -0.84038037f, -2.08177447f, -1.43082356f, -1.52199960f, 1.69448102f,
2.12475252f, -2.64191580f, 0.10776700f, -4.01538181f, 1.15558016f, -0.09849232f,
0.33533198f, 3.34633803f, -2.89805937f, -2.51580763f, 0.94939411f, 1.36254668f,
0.47172806f, 4.40817642f, -0.11368597f, -2.70789719f};
}
{
data.key_data = {
1.18319833f, -0.20700163f, -0.64873743f, 3.88316822f, -2.82827115f, 4.12166834f, 0.84225285f, -1.11044288f,
-1.75086212f, -1.66724730f, 2.22730064f, -3.22617316f, -0.14071584f, 0.58066225f, 3.04375815f, -1.43881261f,
-2.39294887f, 1.03637624f, -0.98744214f, 1.13576865f, -0.23876363f, 0.27395499f, -0.51450062f, -2.23614597f,
-2.12345290f, -0.68864471f, 2.56223369f, -1.14069867f, -2.14457107f, -1.32647824f, -1.20575166f, -0.98427975f,
0.43083039f, -1.72496212f, 0.89925444f, -0.33879194f, -1.01836991f, 0.06260723f, -4.40405083f, 1.51136112f,
-1.57057071f, -2.49242449f, -0.37187487f, -3.55319405f, 1.50083232f, 0.37271553f, 1.00157571f, -0.50416815f,
1.28753221f, -0.82453167f, -1.13294256f, -1.49514699f, 0.11243388f, 1.89696264f, -1.46173263f, 3.32755566f,
-0.54521537f, -2.61305809f, -0.43132567f, -0.33066380f, -0.47485363f, 3.62707257f, -0.61352783f, 2.21147466f,
-2.39673638f, 0.89925957f, -2.58643913f, -0.81968069f, 3.34945726f, 0.73745269f, -1.62732553f, -4.55126476f,
2.78017616f, 0.33757699f, 2.50468874f, -4.14928627f, 0.20017165f, 3.62233806f, -4.17984772f, 2.60447359f,
2.16826940f, 1.70457518f, 1.03199887f, 2.66712570f, 0.50808340f, -3.47132921f, -2.60008478f, 1.03852415f,
-0.53876096f, 3.36212158f, -5.49142551f, 1.69825470f, -2.98179603f, -3.39561105f, -2.33971524f, 1.23642313f,
2.13283253f, -0.56307364f, -2.49120903f, 2.97641850f, -1.28758216f, 3.43342829f, 2.49575281f, 0.09292871f,
-0.46469527f, -3.95696974f, 2.16474032f, -2.15254521f, -2.24547267f, 2.34235692f, -1.02470589f, 3.97816467f,
3.60425544f, 1.87994969f, -2.46964216f, 1.47802746f, -1.81441534f, -1.56946301f, 0.56189334f, -1.69905055f,
-1.83049631f, 4.64296293f, 3.36173010f, 1.17065477f, 0.62365234f, 1.23748016f, 0.63865232f, -2.90434527f,
1.80253839f, 3.11227179f, -3.96782875f, -2.78780794f, 3.76587057f, -1.66908360f, 1.83301187f, -1.74414611f,
-2.83874130f, -2.00238085f, -6.45539570f, 0.56152177f, 2.52830791f, -4.32480669f, 1.40038610f, 0.83278954f,
0.16065764f, -0.13457650f, 2.17216778f, -4.28218699f, 0.75475001f, -0.67497885f, -0.95346600f, 3.29623652f,
1.84325528f, 1.18348145f, -0.23741919f, 2.49520302f, 0.88820332f, 1.15528166f, 0.75733638f, 2.09371948f,
-1.16427231f, 1.36415648f, -1.17721760f, 0.19180456f, -3.83617687f, -0.22694540f, 5.14728260f, -0.43242604f,
-2.59039426f, -1.40904129f, 0.58194822f, -2.59625196f, -3.60205126f, 1.45633197f, 3.66319609f, -4.45727873f,
3.95457315f, -0.17875004f, 2.43404126f, 2.83592010f, 0.87342203f, 1.24538708f, 3.10003138f, 2.63025975f,
4.57258415f, -5.20645714f, -2.55821514f, 0.60136455f, -4.13579988f, -2.04082966f, 2.21142578f, -1.05740535f,
1.78609943f, -3.10438013f, -0.13040465f, -3.02957106f, 0.91924584f, 0.45405358f, -1.90627027f, -1.05065346f,
-1.21743047f, -1.65989709f, -0.51138550f, 2.04327297f, 0.65217698f, 0.77914226f, 1.86315429f, 0.75791669f,
-0.55304748f, -1.23857486f, 2.63207936f, -0.51371288f, 5.48993397f, -2.35509205f, -2.30255723f, 3.88706803f,
-1.93575382f, 0.03364474f, -1.61156952f, -2.74172544f, 1.64667726f, 0.04652762f, 2.88130736f, -2.00066185f,
0.74907655f, -3.35894132f, -1.85703170f, 1.78695405f, 0.16497552f, 0.94382036f, 3.04452896f, -4.42404556f,
-1.67239439f, 0.93356639f, 0.08288032f, -0.11422639f, -3.94759631f, 0.35302341f, -1.20778334f, -1.92491865f,
-1.86599326f, -1.29324412f, -1.12795746f, 0.24268979f, -0.50242394f, 2.26449108f, 0.91289425f, -2.48235416f,
-1.12685704f, -0.32806787f, 3.28139257f, 3.19231367f, 0.99441254f, -1.86975384f, -3.57600951f, 0.07424650f,
-0.45312887f, 5.02197504f, -3.93365264f, -3.30742884f, -1.48101401f, 1.03335130f, 2.79531693f, -3.71739435f,
1.58574414f, -4.52857542f, 1.99908066f, 1.53755212f, 1.60631371f, -2.46801257f, -1.85840714f, 5.07508087f,
1.69143867f, -1.04688716f, -3.17096090f, -4.08357859f, -0.02436948f, -1.26299214f, 1.55509603f, 3.11954260f,
3.55844116f, 0.10080734f, -0.57031679f, 2.01342750f, -0.66671205f, -1.89724469f, 2.52388906f, 3.71421099f,
0.77953398f, -1.63364959f, -1.90900147f, -3.60591793f, 1.17604601f, -1.69456589f, -1.62096381f, -1.44886708f,
-1.09821022f, -1.27646899f, 2.73696446f, -2.21802664f, -0.22022307f, 1.76918471f, -1.55524099f, 0.27310926f,
-0.56175643f, -0.59620953f, 2.34752941f, -0.74946308f, -2.33520174f, 1.37984359f, -1.82466078f, -0.04973821f,
-4.77387571f, -0.85034770f, 3.39579129f, -2.82413197f, -2.37980723f, 0.10482252f, 0.10614476f, 0.38176090f,
-0.03948998f, -3.33898020f, 0.33013302f, -0.24926627f, 1.82249093f, 0.57584983f, -0.68790460f, -0.62760007f,
0.17052543f, -0.54540014f, 1.66043472f, -0.29917845f, 3.31803465f, 0.86704284f, -0.26854402f, 2.23795938f,
-0.65058500f, -2.01540327f, -2.32472515f, -2.85143948f, -3.76564598f, -0.25596800f, -2.08064461f, -0.60812098f,
3.64154029f, -2.58636141f, -0.25312662f, -2.22530699f, -1.24763203f, -3.08458424f, 0.69228125f, -1.84211481f,
1.09744453f, -1.35679579f, 1.68044925f, 0.89537722f, 3.56465936f, -0.64790231f, -1.42140329f, -2.85126376f,
0.88302374f, -0.77923191f, -0.61865216f, -3.08081675f, 0.87791818f, -0.27943787f, 0.46918952f, 1.50163293f,
3.43236423f, 1.99953759f, -2.42805409f, 4.97383118f, -2.13942194f, 1.45409000f, -1.14207470f, 0.63804722f,
-4.23801470f, 1.23076391f, 2.71176004f, 1.13607812f, 2.27742863f, 1.64165723f, 1.20048785f, -0.66269439f};
}
{
data.value_data = {
2.52855659f, 1.00436294f, 0.83871710f, 0.97005701f, 1.33615291f, -2.07353282f, 0.14190522f, -1.42923164f,
-0.05781263f, -3.81081843f, 1.15263164f, 0.62601233f, -0.93824124f, 1.21525323f, -0.17992918f, 2.08717370f,
3.61659431f, -0.16836943f, 2.17779160f, -0.63968349f, 0.32170480f, 1.74428463f, -0.46570981f, -0.07432288f,
-0.21569058f, 0.65559602f, 3.58669281f, 0.40837619f, 2.40912223f, 1.31780922f, -4.45945454f, 0.64903581f,
-1.10752177f, -1.79390311f, 0.89312351f, -1.84512544f, -1.13948750f, 3.87221098f, -2.74163318f, 2.90849519f,
-0.31782085f, 3.12108278f, 0.80056298f, 1.02164125f, -0.07995117f, -0.96148860f, 3.49803638f, -4.48321056f,
-1.50024915f, -2.58987570f, 0.61711067f, 4.13532829f, -4.38111591f, -2.48988461f, -0.43977243f, -3.93134618f,
-2.67314148f, 2.64455128f, 0.11041284f, 1.26786041f, -0.24446392f, -0.86178148f, 2.35680771f, -1.69236851f,
-1.22143269f, 1.99185669f, 2.99625540f, -2.32311869f, -2.26162481f, 3.13980794f, 0.37014920f, 3.22335911f,
2.55935216f, 2.19479871f, 4.89236355f, 1.76135564f, -2.74285603f, 1.39842391f, -0.25135490f, -4.76257038f,
-0.80362052f, -1.75548995f, -4.70487833f, 1.72763062f, 3.14491320f, 3.97562551f, -0.64091396f, -0.49683607f,
1.09094775f, -0.04886785f, -0.20181555f, 2.22182846f, 3.00734067f, -0.52149582f, -1.55592132f, 4.41542721f,
4.68795204f, -1.03364658f, 1.12266266f, -1.50595415f, -4.82583904f, -0.65535200f, -1.44525290f, -0.24540535f,
-0.44778955f, 2.32284093f, 1.60033488f, 0.12583408f, -4.42107201f, -1.32412672f, -1.84733653f, -1.53440499f,
3.21279287f, -0.37051341f, 0.26685789f, 2.25037003f, 0.01608747f, 1.66141725f, -0.53394145f, 1.35017800f,
1.35997009f, -2.73341703f, 5.47488451f, 5.49519920f, -1.90401053f, 3.37626982f, -1.97467375f, 1.91208827f,
-0.39609963f, -3.46037388f, -1.47946858f, 3.59935665f, 2.36377144f, -2.32310963f, 1.95714176f, -3.10615826f,
-1.72878003f, 0.37169266f, -5.95610952f, -1.32819366f, -1.24326205f, 0.17746472f, 2.59834385f, 1.83808351f,
2.94952321f, 3.01939392f, 1.37281823f, 2.67180538f, -0.32547897f, 1.11373281f, -0.26456773f, 0.30103314f,
-1.05465972f, -1.74858260f, 4.66243505f, -0.58474910f, 1.26216507f, 1.28856802f, 0.30135399f, -3.24127388f,
1.57217860f, -3.84659171f, 1.52000761f, -0.57999939f, 7.80852032f, 2.83661318f, -1.72516418f, 0.70036685f,
5.33224869f, 3.27205563f, 0.22613347f, 1.27628899f, 0.63828707f, 0.60137266f, 2.23047280f, -3.12771320f,
-0.03023779f, 0.80765182f, -2.25078392f, -2.55701947f, -1.01789987f, -4.81986141f, 5.08153057f, -1.74439597f,
-2.12658811f, -0.01458025f, -2.19556737f, 0.66254830f, -0.97602153f, -0.09858370f, -2.05090475f, -3.57909155f,
4.57896709f, -1.96923888f, -3.86827421f, 3.18770289f, -5.16361237f, 1.42594528f, -1.43490076f, 1.62748218f,
0.91413617f, -0.27147734f, 0.89311242f, 0.39315015f, 1.18184900f, 4.30172014f, -2.32771754f, 1.61144018f,
1.31702828f, 1.47999883f, -0.20565452f, 0.75846130f, -0.13237280f, -2.10059071f, 0.12025893f, -0.58277643f,
1.93927395f, -3.11170292f, 0.84666562f, 0.08490577f, -0.36315954f, -3.13071823f, 0.12070303f, -0.10385191f,
-2.37523723f, 2.28944397f, 0.12518460f, -1.10043252f, -1.94665289f, 3.44240570f, 1.14374518f, 3.27769613f,
1.40222466f, 0.68902296f, 2.48193359f, 1.85469973f, 0.53099388f, -2.16307211f, 0.67865700f, -0.05084896f,
0.09825261f, 1.40057099f, -0.74452353f, 0.81515837f, 1.51540780f, -1.30754757f, -1.50317430f, -2.04524612f,
-0.49154273f, 0.75809133f, -0.25134420f, 0.36961895f, -0.01882899f, -1.72547066f, 1.12012851f, -6.72828960f,
1.76177442f, 1.19128907f, -0.77717477f, -1.97159290f, -2.30860472f, 2.01583147f, 5.43375349f, 2.58655977f,
0.71099019f, 0.71843386f, 3.10709906f, 1.48128355f, 0.22561067f, -4.27442265f, -2.49249840f, 4.71605539f,
2.19818974f, -1.96133125f, 0.41619009f, 0.66834581f, -3.74457240f, -0.48215276f, -1.28305256f, -1.83142948f,
-0.72452945f, -1.97440028f, -0.14068973f, 0.11765432f, 0.49793118f, 0.40227121f, -1.34390569f, 0.92099732f,
-1.21718168f, -1.95382285f, 1.37468243f, -0.72062874f, 2.66714525f, 1.06695974f, -2.86761045f, 1.34743905f,
3.30500460f, -0.91894615f, -0.09608981f, -4.09408808f, -2.57941151f, -0.36501098f, 1.93333972f, 1.54577386f,
-2.96415496f, -2.09494066f, 1.63500857f, -1.51829720f, -0.98314112f, -1.89401948f, -0.54314089f, -3.68928242f,
1.07439506f, 1.70869648f, 0.86973846f, 1.71959770f, 1.78241849f, -4.29455566f, -1.55857742f, -3.32966399f,
0.20903873f, 1.40176547f, -6.08825064f, 2.12755013f, 3.84799123f, -0.83979988f, -1.64312506f, -0.69876713f,
4.00779629f, -2.85212469f, 0.09145057f, 1.72984874f, -0.77233994f, 1.21815240f, -1.75377214f, 4.08561277f,
-1.20909250f, -1.24881196f, 4.37579060f, 4.27434301f, -2.01065826f, 2.96602201f, 3.07406378f, 1.22374272f,
0.06376281f, -1.60328245f, -1.32239270f, 1.00765312f, 1.27593243f, -2.14843464f, -3.47884607f, -0.32401958f,
-2.52805567f, -1.01782882f, 0.74270618f, 1.47170806f, -2.56010485f, -1.49985540f, 0.92767721f, 3.42378139f,
5.23711205f, 0.47062784f, -0.26747131f, -2.06014609f, -0.20237172f, -1.60944867f, -2.51956654f, 0.59529293f,
2.63805699f, 0.43868792f, -5.84081888f, 3.25271368f, -4.44406748f, -3.80642724f, -1.59846020f, -2.59634686f,
0.11074528f, 2.04441738f, -1.51878321f, -2.59639883f, 2.23697233f, 0.07920718f, 1.31056094f, -8.10540771f};
}
{
data.bias_data = {
-0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, -0.34948170f, -0.19608477f, 0.19725692f, 0.39987487f,
0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, -0.23178342f, -0.13692456f, -0.04889601f, 0.48739988f,
-0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, 0.48096961f, -0.47906545f, 0.43613154f, -0.23511401f,
-0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, -0.45826656f, 0.23286396f, -0.43407962f, 0.40754890f,
0.23778325f, 0.34850210f, -0.01385659f, 0.32141626f, -0.27738628f, 0.27683002f, 0.31886810f, -0.24781504f,
-0.25476855f, -0.46742713f, -0.12478521f, 0.39731556f, -0.12087554f, 0.40822440f, 0.13202906f, -0.23747686f,
0.30502868f, 0.27182943f, -0.03640261f, -0.39626551f, -0.22411832f, 0.17324352f, -0.49959660f, -0.49318257f,
0.31363028f, 0.05469471f, -0.00390345f, -0.46100286f, -0.27253938f, 0.17251462f, 0.46564627f, 0.21038425f,
0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, 0.27376485f, -0.38174152f, -0.43700469f, 0.38040614f,
-0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f, 0.07120579f, -0.08055863f, 0.12095112f, -0.27988660f,
0.06004709f, -0.05600315f, -0.25510073f, 0.41887105f, -0.19016314f, 0.47241372f, 0.12890404f, -0.24272856f,
0.21106839f, -0.40523255f, 0.10336459f, -0.11084765f, 0.42408967f, -0.15285304f, -0.28945464f, -0.25714916f,
0.40978593f, -0.09138483f, -0.02013114f, -0.39042589f, -0.19557095f, 0.07540411f, 0.33955890f, 0.41873980f,
-0.27744853f, -0.33097768f, -0.44587523f, -0.01648277f, 0.34952271f, -0.48838940f, -0.17273578f, 0.37286615f,
-0.10157353f, -0.08097187f, 0.23243034f, 0.25516337f, -0.45793599f, 0.08089012f, 0.17673731f, 0.03000754f,
0.48834521f, 0.35069120f, -0.32989410f, 0.20729345f, 0.24406803f, 0.35393929f, -0.16146761f, 0.04258209f,
-0.10567203f, 0.26791072f, -0.08976898f, 0.31341976f, 0.06027532f, 0.14307594f, 0.31587386f, 0.16180152f,
0.34785229f, 0.00531715f, -0.35168743f, -0.11641458f, 0.39196932f, 0.44535065f, 0.43545735f, 0.15593112f,
0.06171834f, -0.42181283f, -0.41170910f, 0.40969193f, -0.01510030f, 0.07973170f, -0.18156880f, 0.21522856f,
0.03915739f, -0.20913908f, -0.47068381f, 0.35633272f, -0.35124153f, 0.36624825f, -0.05567622f, -0.35343069f,
0.12821168f, 0.35526341f, -0.23420528f, -0.46328634f, -0.21994811f, -0.27556795f, 0.01653767f, 0.42626363f,
0.23239774f, 0.39632857f, 0.32416028f, -0.48494491f, -0.05365932f, -0.10860911f, 0.06893444f, 0.46116674f,
0.34345043f, -0.02719739f, -0.39574289f, -0.39339882f, 0.23044002f, -0.06155324f, 0.23292047f, 0.39775699f,
0.12789404f, -0.44719657f, 0.12020230f, 0.26871282f, -0.10917315f, -0.29244915f, 0.09059817f, -0.19613290f};
}
{
data.fp32_output_data = {
2.42288446f, 1.27227366f, 0.74894810f, 1.28347683f, 1.39642823f, -1.93045688f, 0.45777908f, -1.26743007f,
0.29003966f, -3.80550122f, 0.80094421f, 0.50959778f, -0.54627192f, 1.66060388f, 0.25552815f, 2.24310493f,
3.67831278f, -0.59018224f, 1.76608253f, -0.22999156f, 0.30660450f, 1.82401633f, -0.64727861f, 0.14090568f,
-0.17653319f, 0.44645694f, 3.11600900f, 0.76470888f, 2.05788064f, 1.68405747f, -4.51513100f, 0.29560512f,
-0.97931010f, -1.43863964f, 0.65891826f, -2.30841184f, -1.35943556f, 3.59664297f, -2.72509551f, 3.33475876f,
-0.08542311f, 3.51741123f, 1.12472320f, 0.53669631f, -0.13361049f, -1.07009768f, 3.56697083f, -4.02204370f,
-1.15679872f, -2.61707306f, 0.22136778f, 3.74192953f, -4.15067577f, -2.55143785f, -0.20685196f, -3.53358912f,
-2.54524755f, 2.19735479f, 0.23061514f, 1.53657317f, -0.35363707f, -1.15423059f, 2.44740582f, -1.88850141f,
2.42288446f, 1.27227366f, 0.74894810f, 1.28347683f, 1.39642823f, -1.93045688f, 0.45777908f, -1.26743007f,
0.29003966f, -3.80550122f, 0.80094421f, 0.50959778f, -0.54627192f, 1.66060388f, 0.25552815f, 2.24310493f,
3.67831278f, -0.59018224f, 1.76608253f, -0.22999156f, 0.30660450f, 1.82401633f, -0.64727861f, 0.14090568f,
-0.17653319f, 0.44645694f, 3.11600900f, 0.76470888f, 2.05788064f, 1.68405747f, -4.51513100f, 0.29560512f,
-0.97931010f, -1.43863964f, 0.65891826f, -2.30841184f, -1.35943556f, 3.59664297f, -2.72509551f, 3.33475876f,
-0.08542311f, 3.51741123f, 1.12472320f, 0.53669631f, -0.13361049f, -1.07009768f, 3.56697083f, -4.02204370f,
-1.15679872f, -2.61707306f, 0.22136778f, 3.74192953f, -4.15067577f, -2.55143785f, -0.20685196f, -3.53358912f,
-2.54524755f, 2.19735479f, 0.23061514f, 1.53657317f, -0.35363707f, -1.15423059f, 2.44740582f, -1.88850141f,
4.47329473f, -1.70132744f, -3.95804238f, 3.50112128f, -5.10333633f, 1.56902146f, -1.11902511f, 1.78928399f,
1.26198828f, -0.26615992f, 0.54142559f, 0.27673587f, 1.57381809f, 4.74706888f, -1.89226031f, 1.76737213f,
1.37874687f, 1.05818522f, -0.61736351f, 1.16815329f, -0.14747408f, -2.02085853f, -0.06131025f, -0.36754823f,
1.97843063f, -3.32084179f, 0.37598154f, 0.44123849f, -0.71440083f, -2.76446915f, 0.06502641f, -0.45728233f,
-1.93884647f, 1.51549935f, 0.22349268f, -1.46264625f, -0.93878794f, 2.53468966f, 0.09279048f, 3.19028425f,
2.14098549f, 0.65744257f, 2.12003636f, -0.21332240f, -0.35039914f, -1.79318547f, 1.08148456f, 0.83520722f,
-0.37325758f, 0.44315636f, -0.50703102f, -0.19921407f, 1.08093989f, -1.52517128f, -1.01477206f, -2.08499599f,
0.05307493f, 0.56386751f, 0.16719794f, 0.99758488f, 0.35134155f, -2.70159864f, 0.49787593f, -6.01998806f,
1.88393891f, 1.20359635f, -1.11693203f, -1.24092197f, -2.47922421f, 2.11120105f, 5.19413376f, 2.67079711f,
1.07527149f, 0.64369327f, 2.57635832f, 1.27686763f, 0.69491446f, -3.13548803f, -2.04371452f, 4.62090492f,
2.18864536f, -2.10483122f, -0.04580984f, 1.08532572f, -3.46754074f, -0.53330994f, -1.35113037f, -1.51521778f,
-0.46994060f, -2.27551699f, -0.53152251f, 0.47133854f, 0.07705012f, 0.48279381f, -1.28113365f, 0.48468336f,
-1.54411674f, 0.06915778f, 0.64939111f, -1.33318806f, 0.63385141f, 1.72500539f, -1.27450287f, 2.53234506f,
2.78955889f, 0.10935718f, 1.24130249f, -2.24100065f, -1.41059852f, -1.18030620f, 1.50915027f, 1.37942517f,
-1.41709673f, -0.74830860f, 0.30404601f, -0.99458563f, 0.22929534f, -1.72507358f, -0.68753922f, -2.64537501f,
0.58683372f, 0.88788664f, 0.54932535f, 1.45773280f, 0.96530700f, -3.57728553f, -0.41517627f, -4.86154747f};
}
{
data.fp16_output_data = data.fp32_output_data;
}
}
void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) {
data.hidden_size = 32;
data.v_hidden_size = 32;
data.num_heads = 1;
data.batch_size = 2;
data.sequence_length = 2;
data.kv_sequence_length = 3;
data.mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
data.key_padding_mask_data = {0, 1, 1, // first key sequence has one padding on the left
0, 0, 1}; // second key sequence has two paddings on the left
data.skip_kernel_types = {
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention,
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention
};
{
data.query_data = {
2.88765883f, 1.27536213f, -0.57580215f, 2.73696542f, 2.19016314f, 0.42629790f, 1.55081677f, -2.01307678f,
-0.80203497f, -1.23206115f, 1.78565156f, -2.09875321f, -2.22730732f, -0.98120236f, -0.25774139f, 0.75868356f,
-2.87585187f, -0.41810805f, -2.11528730f, 0.50642025f, -0.29446256f, -3.69675803f, -2.73721838f, -1.51089072f,
0.74300194f, 0.27352047f, -0.88251829f, 2.82622814f, 0.73837662f, -2.14588642f, 0.37608737f, -0.06190044f,
-1.97659302f, -2.22348428f, 2.25573063f, -2.24459195f, -2.28073978f, -0.52412349f, -0.57297325f, 3.29259396f,
1.35617173f, -0.83082151f, 0.03767079f, 1.82568312f, 0.88193995f, 1.15579486f, 1.87845564f, -0.15923920f,
2.37435389f, 1.49093378f, 1.95134592f, -1.67609048f, -0.45959851f, 1.63960719f, 3.44909906f, -0.23531833f,
-0.57074630f, 1.38279045f, 0.58870834f, 0.85297751f, -1.44973445f, 1.56243801f, -0.67229253f, -0.16198707f,
-0.23966503f, -0.15329531f, -3.22765136f, 0.60538405f, -0.33244422f, -1.34865439f, -0.24373266f, -1.78808010f,
-1.53090763f, 1.75037694f, -0.71890754f, 0.12527336f, 1.26654553f, -0.86477917f, -1.49822962f, 1.67973542f,
0.99763191f, -0.07183220f, 1.55289185f, 1.62626481f, -0.04283767f, -2.55072594f, -1.95238030f, 0.60994428f,
-2.53714681f, 1.54605150f, 0.05900350f, 1.42194426f, 0.33801061f, 1.25557244f, 0.67291188f, -1.36867523f,
1.86936152f, -1.19588101f, 0.75778806f, 1.85271311f, 0.02081686f, 2.65807819f, 0.78890860f, -1.07388866f,
4.18109226f, 0.06373940f, 2.86840463f, 0.90427721f, -0.09531648f, -0.40835506f, 1.60812938f, -1.61683714f,
-0.45421624f, -2.25537109f, -1.35910070f, -0.25111723f, -0.71782172f, 0.62597942f, -0.42838976f, 0.23198499f,
1.29250073f, -2.01550317f, 0.14619158f, -0.03868395f, -0.74211842f, -3.17291188f, -1.90475547f, 2.02544284f};
}
{
data.key_data = {
1.14242256f, 1.08148384f, -0.00962424f, -1.62719429f, 0.86478198f, 0.16862091f, 1.01692820f, -1.15278327f,
-1.13622630f, 1.78038371f, 0.58222097f, 0.39166588f, 1.75063372f, -1.20408881f, 0.75154918f, 0.58156419f,
-0.98975772f, -0.82555556f, -0.72656512f, -2.42399549f, 2.19217968f, 2.18518472f, -1.72216129f, 1.35098433f,
-0.34989786f, -0.69064844f, -0.98365444f, 3.10148478f, 0.64813483f, 1.78129303f, -0.47006512f, 2.53122735f,
0.09757380f, 0.04077591f, -0.81791472f, -0.19737752f, 1.13775492f, -1.51351953f, 0.59109330f, 2.86624002f,
-0.09282493f, -1.69204521f, 1.27087700f, 3.53944731f, 0.59776509f, -0.90838081f, -0.15813766f, -1.86199224f,
0.18734205f, -0.76110429f, -0.02243887f, -0.94068182f, 1.32443166f, 0.03512055f, -0.13194422f, -1.50401211f,
0.92001319f, 0.20918207f, -1.34839189f, 1.56431675f, -0.61030018f, 2.39562368f, -1.56722510f, -0.96874726f,
-0.48726845f, -1.41476154f, -1.45116997f, 0.53907454f, -2.14415288f, 1.14340270f, -0.21846619f, -2.72349358f,
2.99664998f, -2.38684058f, 0.95269018f, 0.04208702f, -1.75080788f, 1.24652982f, -1.76879966f, 3.10814905f,
2.48754454f, -0.62601894f, 1.41356945f, 0.10340121f, 1.09059846f, -0.78241473f, -0.61477584f, -0.19339988f,
-0.48253334f, -2.41782594f, 1.04690075f, 0.14725411f, -0.20820639f, -1.95920563f, 0.96303236f, -1.20068836f,
-1.71051037f, -1.90946770f, -2.07985783f, 2.35042953f, 0.35059446f, -0.44228595f, 4.08558750f, -0.60121447f,
0.78836018f, 0.35280651f, 0.23129070f, -0.21523762f, 0.12277550f, 0.12348226f, -1.62759030f, -2.78246498f,
4.04853964f, 0.29263157f, -0.38621908f, -1.07599223f, -1.99170423f, 1.41409016f, 2.19121861f, -3.53451037f,
3.63692737f, 0.68270516f, 2.51469731f, 2.57543731f, -2.39040112f, -3.97164130f, 1.28371549f, 1.64144099f,
-0.70385075f, 2.55361128f, 1.60707259f, 0.84735453f, -2.07756495f, -1.99240303f, -3.60991144f, 2.87136865f,
2.31296396f, 2.30251813f, -1.05624914f, -2.43777156f, -0.27048296f, 2.39037871f, -2.04504776f, 1.65183067f,
-0.38970214f, 0.16808379f, -1.30286717f, 1.90201700f, -2.71696734f, -0.66445369f, 1.27085483f, -0.60816145f,
1.81054437f, -1.55584621f, -2.19360781f, -4.52794456f, -0.90534067f, 0.94724411f, 2.40401077f, -2.94815230f,
-3.19650269f, 2.50638890f, 1.02038431f, 1.50519919f, 0.47196171f, -1.89026380f, -1.86559379f, 0.82210326f,
0.10818237f, 1.45290673f, 1.62321615f, -0.61283481f, -1.42501950f, 2.10349464f, -1.65715265f, 0.30090189f,
-3.81919909f, -2.44903922f, -1.20557833f, -0.69951278f, -1.31475580f, -3.73842764f, 1.49299407f, -0.70933276f,
-1.49021530f, 0.71776378f, -1.23052382f, -2.13119912f, -1.20718014f, 2.30572701f, 1.78386402f, -1.57122159f};
}
{
data.value_data = {
1.79297853f, 0.96909231f, 1.23087275f, -0.61933923f, -0.56477690f, 1.47813499f, 0.51474279f, -3.44743419f,
0.95816678f, -0.20553169f, -0.76906109f, -4.60927439f, 0.40629998f, 0.91934747f, -1.09594405f, -1.45653892f,
-0.59282207f, 0.05621797f, -2.26634383f, -1.30799258f, 1.22072279f, -3.60811162f, 1.70111597f, 0.47336632f,
-1.43857694f, -0.13917151f, -1.34617388f, 1.07960105f, -1.77342618f, 0.31946269f, 1.19137061f, 2.59346104f,
-1.82395399f, 0.73557752f, 2.32600021f, -0.22650969f, -0.48526058f, 1.40349376f, -0.33553454f, 0.45531431f,
0.73859257f, 0.37798560f, 0.85344458f, -1.30447221f, 1.23349071f, -0.26439479f, 1.18636096f, -0.33328748f,
-0.50939041f, 0.53500950f, 1.33486223f, -1.54447496f, -2.88690519f, -0.06809106f, -0.00597921f, -1.07510388f,
0.62182164f, 0.50033569f, -0.88293070f, 2.56142712f, 0.37708595f, 1.59349704f, -1.17139614f, 0.89580274f,
0.69456708f, 2.91441655f, -0.25431669f, -1.20305562f, 2.06701255f, -0.86700624f, -2.23615170f, 0.13303493f,
-2.97540593f, 0.08654684f, 1.40381706f, 3.54294443f, -2.07661867f, -1.33181918f, 2.24228764f, 1.79975545f,
2.14695477f, 1.40222490f, -0.29813689f, 1.94485068f, 1.99623775f, 1.53450203f, 0.28755581f, -0.67934704f,
-0.92102510f, -1.52764773f, 1.11267352f, -3.90122724f, 0.22128634f, 0.14945325f, -4.38529491f, -1.58423281f,
-2.45574522f, -1.91599977f, 5.05240345f, 2.24617362f, 3.99182248f, 0.92924285f, -0.39660916f, -0.08696688f,
0.24855530f, 0.71378094f, 0.92413902f, 1.73599064f, 1.03852975f, 2.44676781f, 0.35013664f, 0.98107171f,
1.62946916f, 0.41239718f, -1.41385484f, 2.49293518f, 2.32976985f, 2.89612579f, 2.66875219f, 1.47379971f,
1.31164551f, -1.82183075f, -5.15272474f, 0.28575048f, 0.16861364f, -0.47264135f, 0.22565089f, -0.37727535f,
-1.13935280f, 0.38051969f, -2.38735437f, -2.80645251f, 0.18637873f, 2.13938355f, 2.92260599f, -0.38653925f,
0.58366799f, -1.67636371f, -2.29396892f, -1.31527638f, 2.39795637f, 0.39815575f, -0.98530269f, -1.29227996f,
0.14452982f, -0.38186538f, -1.71267688f, 0.18121701f, -2.26441002f, -0.94511753f, 0.27371156f, -2.44858527f,
-0.21510160f, -2.65228534f, -2.16755104f, 0.86151361f, 0.77589297f, -1.06628847f, 0.73745233f, 1.15778029f,
-0.73659700f, 0.74325305f, -1.97666430f, -1.07301974f, 0.17534591f, -1.66584718f, 1.21820331f, 0.67675018f,
-1.08938253f, 1.78010321f, 0.39817584f, -0.02914053f, 1.13571596f, -0.44081455f, 1.70561552f, -2.12085509f,
-0.69322622f, -1.87331009f, -2.15000772f, 2.08436966f, 1.70494926f, -3.69169927f, -1.22119129f, -1.60190558f,
-2.09093666f, -1.02816033f, -1.78743768f, 2.34501553f, 2.79939008f, 1.82245076f, 1.47408092f, 1.10063124f};
}
{
data.bias_data = {
-0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, -0.34948170f, -0.19608477f, 0.19725692f, 0.39987487f,
0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, -0.23178342f, -0.13692456f, -0.04889601f, 0.48739988f,
-0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, 0.48096961f, -0.47906545f, 0.43613154f, -0.23511401f,
-0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, -0.45826656f, 0.23286396f, -0.43407962f, 0.40754890f,
0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, 0.27376485f, -0.38174152f, -0.43700469f, 0.38040614f,
-0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f, 0.07120579f, -0.08055863f, 0.12095112f, -0.27988660f,
0.06004709f, -0.05600315f, -0.25510073f, 0.41887105f, -0.19016314f, 0.47241372f, 0.12890404f, -0.24272856f,
0.21106839f, -0.40523255f, 0.10336459f, -0.11084765f, 0.42408967f, -0.15285304f, -0.28945464f, -0.25714916f,
-0.10567203f, 0.26791072f, -0.08976898f, 0.31341976f, 0.06027532f, 0.14307594f, 0.31587386f, 0.16180152f,
0.34785229f, 0.00531715f, -0.35168743f, -0.11641458f, 0.39196932f, 0.44535065f, 0.43545735f, 0.15593112f,
0.06171834f, -0.42181283f, -0.41170910f, 0.40969193f, -0.01510030f, 0.07973170f, -0.18156880f, 0.21522856f,
0.03915739f, -0.20913908f, -0.47068381f, 0.35633272f, -0.35124153f, 0.36624825f, -0.05567622f, -0.35343069f};
}
{
data.fp32_output_data = {
0.23503941f, 2.87619758f, 0.01845241f, -0.75242990f, 1.76869011f, -0.40492195f, -1.65323853f, 0.34011719f,
-2.10573196f, 0.13281155f, 0.97480160f, 2.74546146f, -1.21957457f, -0.73649400f, 2.52938581f, 1.65599120f,
1.83545303f, 0.85856718f, -0.48040742f, 1.86428785f, 1.29504943f, 1.38906729f, 0.06474495f, -0.51972288f,
-0.66509569f, -1.45185244f, 0.36160457f, -2.63688278f, -0.10806514f, 0.71859169f, -3.98941422f, -1.58921516f,
-1.89806330f, 1.03079379f, 2.20389438f, 0.07467184f, -0.39299977f, 1.51811528f, -0.04347950f, 0.61307698f,
1.03990030f, 0.37965038f, 0.50865448f, -1.36013806f, 1.58397710f, 0.16757873f, 1.63505113f, -0.15062472f,
-0.41438234f, 0.12406474f, 0.90268815f, -1.09105420f, -2.84080887f, 0.03172458f, -0.18386938f, -0.85491556f,
0.64164376f, 0.26578158f, -1.32860518f, 2.83676863f, 0.02389192f, 1.94164813f, -1.26734924f, 0.51129180f,
-0.84226906f, 1.01116371f, -2.06643319f, -0.75959998f, 0.23562123f, -1.52277124f, 1.53407717f, 0.83855170f,
-0.74153024f, 1.78542042f, 0.04648840f, -0.14555511f, 1.52768528f, 0.00453609f, 2.14107275f, -1.96492398f,
-0.63150787f, -2.29512286f, -2.56171679f, 2.49406147f, 1.68984890f, -3.61196756f, -1.40276003f, -1.38667703f,
-2.05177927f, -1.23729944f, -2.25812149f, 2.70134830f, 2.44814849f, 2.18869901f, 1.41840470f, 0.74720055f,
-0.84226906f, 1.01116371f, -2.06643319f, -0.75959998f, 0.23562123f, -1.52277124f, 1.53407717f, 0.83855170f,
-0.74153024f, 1.78542042f, 0.04648840f, -0.14555511f, 1.52768528f, 0.00453609f, 2.14107275f, -1.96492398f,
-0.63150787f, -2.29512286f, -2.56171679f, 2.49406147f, 1.68984890f, -3.61196756f, -1.40276003f, -1.38667703f,
-2.05177927f, -1.23729944f, -2.25812149f, 2.70134830f, 2.44814849f, 2.18869901f, 1.41840470f, 0.74720055f};
}
{
data.fp16_output_data = data.fp32_output_data;
}
}
#endif
void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
@ -1913,9 +2299,8 @@ void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
data.kv_sequence_length = 3;
data.mask_type = AttentionMaskType::MASK_NONE;
data.skip_kernel_types = {
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention
};
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
AttentionKernelType::AttentionKernel_TrtFusedAttention};
{
data.query_data = {

View file

@ -28,11 +28,15 @@ struct AttentionTestData{
std::vector<AttentionKernelType> skip_kernel_types; // skip some kernels if they do not supported this test case.
};
// Disable some tests in Windows since prefast build might crash with large test data.
#ifndef _MSC_VER
// Return packed weights and bias for input projection.
void GetAttentionWeight(std::vector<float>& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step=1);
void GetAttentionBias(std::vector<float>& bias_data, int elements = 3 * 64, int offset = 0, int step=1);
void GetCrossAttentionData_HeadSize40(AttentionTestData& data);
void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d);
void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data);
#endif
void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data);

View file

@ -74,7 +74,7 @@ static void RunMultiHeadAttentionTest(
constexpr float rel_error = 0.0f;
constexpr float abs_error = 0.05f;
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data), /*sort*/false, rel_error, abs_error);
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data), /*sort*/ false, rel_error, abs_error);
} else {
tester.AddInput<float>("query", query_dims, query_data);
tester.AddInput<float>("key", key_dims, key_data);
@ -94,7 +94,7 @@ static void RunMultiHeadAttentionTest(
constexpr float rel_error = 0.0f;
constexpr float abs_error = 0.02f;
tester.AddOutput<float>("output", output_dims, output_data, /*sort*/false, rel_error, abs_error);
tester.AddOutput<float>("output", output_dims, output_data, /*sort*/ false, rel_error, abs_error);
}
if (enable_cuda) {
@ -124,7 +124,7 @@ static void RunMultiHeadAttentionKernel(
const std::vector<float>& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size]
const std::vector<int32_t>& key_padding_mask_data, // key_padding_mask: see below
AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length]
const std::vector<float>& output_data, // output: [batch_size, sequence_length, v_hidden_size]
const std::vector<float>& output_data, // output: [batch_size, sequence_length, v_hidden_size]
int num_heads,
int batch_size,
int sequence_length,
@ -136,13 +136,13 @@ static void RunMultiHeadAttentionKernel(
bool disable_cpu = true, // not supported in cpu right now.
bool disable_cuda = false,
bool disable_rocm = true) {
if (kernel_type == AttentionKernelType::AttentionKernel_Default) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}}};
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@ -150,13 +150,13 @@ static void RunMultiHeadAttentionKernel(
return;
}
if (kernel_type == AttentionKernelType::AttentionKernel_Unfused)
{
if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}}};
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@ -164,13 +164,13 @@ static void RunMultiHeadAttentionKernel(
return;
}
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention)
{
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}}};
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@ -178,13 +178,29 @@ static void RunMultiHeadAttentionKernel(
return;
}
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention)
{
#if USE_FLASH_ATTENTION
if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
RunMultiHeadAttentionTest(
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
use_float16, disable_cpu, disable_cuda, disable_rocm);
return;
}
#endif
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) {
ScopedEnvironmentVariables scoped_env_vars{
EnvVarMap{
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}}};
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
RunMultiHeadAttentionTest(
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
@ -204,11 +220,21 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
kernel_type = AttentionKernelType::AttentionKernel_Default;
RunMultiHeadAttentionKernel(
#if USE_FLASH_ATTENTION
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
#endif
kernel_type = AttentionKernelType::AttentionKernel_Default;
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
if (data.fp16_output_data.size() > 0) {
@ -229,6 +255,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
#if USE_FLASH_ATTENTION
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
if (!SkipAttentionKernel(data, kernel_type)) {
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
}
#endif
kernel_type = AttentionKernelType::AttentionKernel_Default;
RunMultiHeadAttentionKernel(
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
@ -245,6 +281,24 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
GetCrossAttentionData_HeadSize40(data);
RunMultiHeadAttentionTests(data);
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
AttentionTestData data;
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true);
RunMultiHeadAttentionTests(data);
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
AttentionTestData data;
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false);
RunMultiHeadAttentionTests(data);
}
TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
AttentionTestData data;
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
RunMultiHeadAttentionTests(data);
}
#endif
// This tests qk_head_size != k_head_size
@ -260,6 +314,5 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
RunMultiHeadAttentionTests(data);
}
} // namespace test
} // namespace onnxruntime

View file

@ -36,6 +36,8 @@ class Attention(nn.Module):
self.value = nn.Linear(hidden_dim, self.v_hidden_size)
self.is_decoder = is_decoder
# Do not reshape output for pretty print.
self.reshape_output = False
self.verbose = False
def transpose_for_scores(self, x: torch.Tensor, head_size) -> torch.Tensor:
@ -43,7 +45,7 @@ class Attention(nn.Module):
x = x.view(new_x_shape)
return x.permute(0, 2, 1, 3)
def get_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
def get_extended_attention_mask(self, attention_mask: Tensor, dtype: torch.dtype) -> Tensor:
assert attention_mask.dim() == 2 or attention_mask.dim() == 3
extended_attention_mask = (
attention_mask[:, None, :, :] if attention_mask.dim() == 3 else attention_mask[:, None, None, :]
@ -120,7 +122,7 @@ class Attention(nn.Module):
if attention_mask is not None:
# Apply the attention mask is (precomputed for all layers in BertModel forward() function)
attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask, hidden_states.dtype)
# Normalize the attention scores to probabilities.
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
@ -131,8 +133,9 @@ class Attention(nn.Module):
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
# new_context_layer_shape = context_layer.size()[:-2] + (self.v_hidden_size,)
# context_layer = context_layer.view(new_context_layer_shape)
if self.reshape_output:
new_context_layer_shape = context_layer.size()[:-2] + (self.v_hidden_size,)
context_layer = context_layer.view(new_context_layer_shape)
print("output", context_layer)
@ -144,7 +147,7 @@ class Attention(nn.Module):
return outputs
def generate_test_data(
def run_cross_attention(
hidden_dim,
q_head_size,
v_head_size,
@ -161,7 +164,8 @@ def generate_test_data(
device = torch.device("cuda:0")
mha = Attention(num_heads, hidden_dim, q_head_size, v_head_size, is_decoder=False).to(device).eval()
if key_padding_mask is not None:
key_padding_mask = key_padding_mask.to(device)
torch.nn.init.uniform_(mha.query.weight, -0.5, 0.5)
torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5)
torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5)
@ -205,10 +209,9 @@ def generate_test_data(
past_key_value=None,
output_attentions=False,
)
print("output", output)
def CrossAttention_Batch2_HeadSize40():
def run_cross_batch2_headsize_40():
hidden_dim = 80
q_head_size = 40
v_head_size = 40
@ -216,10 +219,12 @@ def CrossAttention_Batch2_HeadSize40():
batch_size = 2
sequence_length = 3
kv_sequence_length = 5
generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length)
run_cross_attention(
hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length
)
def CrossAttention_Batch1_HeadSize16():
def run_cross_batch1_headsize_16():
hidden_dim = 32
q_head_size = 16
v_head_size = 16
@ -227,10 +232,12 @@ def CrossAttention_Batch1_HeadSize16():
batch_size = 1
sequence_length = 2
kv_sequence_length = 3
generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length)
run_cross_attention(
hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length
)
def CrossAttention_Batch2_HeadSize16_8():
def run_cross_batch2_headsize_16_8():
hidden_dim = 32
q_head_size = 16
v_head_size = 8
@ -238,15 +245,73 @@ def CrossAttention_Batch2_HeadSize16_8():
batch_size = 2
sequence_length = 1
kv_sequence_length = 3
generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length)
run_cross_attention(
hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length
)
def run_cross_batch2_headsize_32_right_side_padding():
hidden_dim = 64
q_head_size = 32
v_head_size = 32
num_heads = 2
batch_size = 2
sequence_length = 2
kv_sequence_length = 3
key_padding_mask = torch.tensor([[1, 0, 0], [1, 1, 0]], dtype=torch.int32).cuda()
run_cross_attention(
hidden_dim,
q_head_size,
v_head_size,
num_heads,
batch_size,
sequence_length,
kv_sequence_length,
key_padding_mask,
)
def run_cross_batch1_headsize_32_left_side_padding():
hidden_dim = 32
q_head_size = 32
v_head_size = 32
num_heads = 1
batch_size = 2
sequence_length = 2
kv_sequence_length = 3
key_padding_mask = torch.tensor([[0, 1, 1], [0, 0, 1]], dtype=torch.int32).cuda()
run_cross_attention(
hidden_dim,
q_head_size,
v_head_size,
num_heads,
batch_size,
sequence_length,
kv_sequence_length,
key_padding_mask,
)
def create_cross_attention_test_data():
"""
Create test data used in attention_op_test_helper.cc and multihead_attention_op_test.cc
"""
print("CrossAttention_Batch2_HeadSize40")
run_cross_batch2_headsize_40()
print("CrossAttention_Batch1_HeadSize16")
run_cross_batch1_headsize_16()
print("CrossAttention_Batch2_HeadSize16_8")
run_cross_batch2_headsize_16_8()
print("CrossAttention_Batch2_HeadSize32_RightSidePadding")
run_cross_batch2_headsize_32_right_side_padding()
print("CrossAttention_Batch1_HeadSize32_LeftSidePadding")
run_cross_batch1_headsize_32_left_side_padding()
with torch.no_grad():
print("CrossAttention_Batch2_HeadSize40")
CrossAttention_Batch2_HeadSize40()
rint("CrossAttention_Batch1_HeadSize16")
CrossAttention_Batch1_HeadSize16()
print("CrossAttention_Batch2_HeadSize16_8")
CrossAttention_Batch2_HeadSize16_8()
create_cross_attention_test_data()