mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-13 01:09:22 +00:00
Add memory efficient attention from CUTLASS (#14343)
### Description Add memory efficient attention from CUTLASS. TODO (in next pull request): (1) Need performance tests on different GPUs, then add a sequence length threshold (only activate it for long sequence length). (2) Merge changes from https://github.com/NVIDIA/cutlass/pull/773 when it is in cutlass master.
This commit is contained in:
parent
e64f357ad4
commit
414b012f42
29 changed files with 2143 additions and 78 deletions
|
|
@ -2753,6 +2753,37 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
_____
|
||||
nvidia/cutlass
|
||||
|
||||
Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
1. Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
3. Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
_____
|
||||
Boost
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,8 @@ option(onnxruntime_USE_PREINSTALLED_EIGEN "Use pre-installed EIGEN. Need to prov
|
|||
option(onnxruntime_BUILD_BENCHMARKS "Build ONNXRuntime micro-benchmarks" OFF)
|
||||
option(onnxruntime_USE_LLVM "Build TVM with LLVM" OFF)
|
||||
|
||||
option(onnxruntime_USE_FLASH_ATTENTION "Build memory efficient attention kernel for scaled dot product attention" ON)
|
||||
|
||||
option(onnxruntime_BUILD_FOR_NATIVE_MACHINE "Enable this option for turning on optimization specific to this machine" OFF)
|
||||
option(onnxruntime_USE_AVX "Use AVX instructions" OFF)
|
||||
option(onnxruntime_USE_AVX2 "Use AVX2 instructions" OFF)
|
||||
|
|
@ -595,10 +597,31 @@ set(ONNXRUNTIME_PROVIDER_NAMES cpu)
|
|||
set(ORT_PROVIDER_FLAGS)
|
||||
set(ORT_PROVIDER_CMAKE_FLAGS)
|
||||
|
||||
# CUDA execution provider setup.
# onnxruntime_USE_FLASH_ATTENTION is only left enabled when all of the following hold:
#   * the CUDA execution provider is being built,
#   * contrib ops are not disabled,
#   * the CUDA compiler is at least version 11.6.
if(NOT onnxruntime_USE_CUDA)
  # Without CUDA there is nothing to build the attention kernel for.
  set(onnxruntime_USE_FLASH_ATTENTION OFF)
else()
  # The CUDA language must be enabled before CMAKE_CUDA_COMPILER_VERSION can be inspected.
  enable_language(CUDA)
  message(STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")

  if(onnxruntime_DISABLE_CONTRIB_OPS)
    set(onnxruntime_USE_FLASH_ATTENTION OFF)
  endif()

  if(CMAKE_CUDA_COMPILER_VERSION VERSION_LESS 11.6)
    message(STATUS "Turn off flash attention since CUDA compiler version < 11.6")
    set(onnxruntime_USE_FLASH_ATTENTION OFF)
  endif()

  # Register the CUDA provider: compile definition, forwarded CMake flag and provider name.
  list(APPEND ORT_PROVIDER_FLAGS -DUSE_CUDA=1)
  list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_CUDA=1)
  list(APPEND ONNXRUNTIME_PROVIDER_NAMES cuda)

  # Expose the attention switch to the sources (USE_FLASH_ATTENTION) and to forwarded CMake flags.
  if(onnxruntime_USE_FLASH_ATTENTION)
    message(STATUS "Enable flash attention for CUDA EP")
    list(APPEND ORT_PROVIDER_FLAGS -DUSE_FLASH_ATTENTION=1)
    list(APPEND ORT_PROVIDER_CMAKE_FLAGS -Donnxruntime_USE_FLASH_ATTENTION=1)
  endif()
endif()
|
||||
if (onnxruntime_USE_VITISAI)
|
||||
list(APPEND ORT_PROVIDER_FLAGS -DUSE_VITISAI=1)
|
||||
|
|
@ -1234,8 +1257,6 @@ endif()
|
|||
|
||||
if (onnxruntime_USE_CUDA)
|
||||
set(CMAKE_CUDA_RUNTIME_LIBRARY Shared)
|
||||
enable_language(CUDA)
|
||||
message( STATUS "CMAKE_CUDA_COMPILER_VERSION: ${CMAKE_CUDA_COMPILER_VERSION}")
|
||||
set(CMAKE_CUDA_STANDARD 17)
|
||||
if(onnxruntime_CUDNN_HOME)
|
||||
file(TO_CMAKE_PATH ${onnxruntime_CUDNN_HOME} onnxruntime_CUDNN_HOME)
|
||||
|
|
|
|||
12
cmake/external/cutlass.cmake
vendored
Normal file
12
cmake/external/cutlass.cmake
vendored
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
# Fetch the NVIDIA CUTLASS sources at a pinned commit.
# The sources are only downloaded (populated), not added as a subdirectory; consumers
# read cutlass_SOURCE_DIR afterwards. Nothing happens unless onnxruntime_USE_FLASH_ATTENTION is on.
if(onnxruntime_USE_FLASH_ATTENTION)
  include(FetchContent)

  FetchContent_Declare(
    cutlass
    GIT_REPOSITORY https://github.com/nvidia/cutlass.git
    GIT_TAG        8b42e751c63ba219755c8ed91af5f6ec1ecc1ee6
  )

  # Populate at most once, even if this module is included several times.
  FetchContent_GetProperties(cutlass)
  if(NOT cutlass_POPULATED)
    FetchContent_Populate(cutlass)
  endif()
endif()
|
||||
|
|
@ -484,6 +484,12 @@ if (onnxruntime_USE_CUDA)
|
|||
if(onnxruntime_CUDNN_HOME)
|
||||
target_include_directories(onnxruntime_providers_cuda PRIVATE ${onnxruntime_CUDNN_HOME}/include)
|
||||
endif()
|
||||
|
||||
if (onnxruntime_USE_FLASH_ATTENTION)
|
||||
include(cutlass)
|
||||
target_include_directories(onnxruntime_providers_cuda PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples)
|
||||
endif()
|
||||
|
||||
target_include_directories(onnxruntime_providers_cuda PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES} PUBLIC ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
|
||||
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
|
||||
set_target_properties(onnxruntime_providers_cuda PROPERTIES LINKER_LANGUAGE CUDA)
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ set(contrib_ops_excluded_files
|
|||
"bert/skip_layer_norm.h"
|
||||
"bert/skip_layer_norm_impl.cu"
|
||||
"bert/skip_layer_norm_impl.h"
|
||||
"bert/cutlass_fmha/*"
|
||||
"bert/tensorrt_fused_multihead_attention/*"
|
||||
"bert/transformer_common.h"
|
||||
"bert/transformer_common.cc"
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ enum AttentionMaskType {
|
|||
|
||||
enum AttentionQkvFormat {
|
||||
Q_K_V_BNSH, // for unfused attention
|
||||
Q_K_V_BSNH, // input format of query, key and value for MultiHeadAttention
|
||||
Q_K_V_BSNH, // for memory efficient attention, or format of query, key and value for MultiHeadAttention
|
||||
QKV_BSN3H, // for TRT fused attention, qkv are packed
|
||||
Q_K_V_BNSH_QKV_BS3NH, // for TRT fused causal attention, data has two formats (qkv is 3BNSH, gemm_buffer is BS3NH)
|
||||
Q_KV_BSNH_BSN2H, // for TRT fused cross attention, kv are packed
|
||||
|
|
@ -30,6 +30,7 @@ enum AttentionKernelType{
|
|||
AttentionKernel_TrtFusedAttention,
|
||||
AttentionKernel_TrtFlashAttention,
|
||||
AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernel_CutlassMemoryEfficientAttention,
|
||||
AttentionKernel_Default
|
||||
};
|
||||
|
||||
|
|
@ -61,8 +62,11 @@ constexpr const char* kDisableFusedAttention = "ORT_DISABLE_FUSED_ATTENTION";
|
|||
// Environment variable to enable or disable fused cross attention kernel. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFusedCrossAttention = "ORT_DISABLE_FUSED_CROSS_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable flash attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableFlashAttention = "ORT_DISABLE_FLASH_ATTENTION";
|
||||
// Environment variable to enable or disable TRT flash attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableTrtFlashAttention = "ORT_DISABLE_TRT_FLASH_ATTENTION";
|
||||
|
||||
// Environment variable to enable or disable cutlass memory efficient attention. Default is 0 (enabled).
|
||||
constexpr const char* kDisableMemoryEfficientAttention = "ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION";
|
||||
|
||||
} // namespace attention
|
||||
|
||||
|
|
|
|||
|
|
@ -320,6 +320,110 @@ __global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, co
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
__global__ void AddBiasTransposeCutlass(const T* input, const T* biases, T* output, int v_head_size) {
  // Format 3 (cutlass memory efficient attention), variant where V may use a different head size.
  //   Input : B x S x (N*H + N*H + N*H_v)   packed Q, K, V per token
  //   Output: Q (B x S x N x H), K (B x S x N x H), V (B x S x N x H_v), stored one after another;
  //           the start of matrix m is computed as m * (N * H * S * B).
  //   Bias  : Q, K, V biases stored one after another; the start of matrix m is m * (N * H).
  // B is batch_size, S is sequence_length, N is num_heads, H is qk_head_size, H_v is v_head_size.
  //
  // Launch geometry read by this kernel:
  //   grid  = (S, B, matrix id)   block = (H, N)
  // NOTE(review): blockDim.x is taken as qk_head_size, so elements of V at index >= qk_head_size
  // would not be visited if v_head_size > qk_head_size -- confirm against the launcher.
  const int elem_id = threadIdx.x;    // element inside a head
  const int head_id = threadIdx.y;    // head index
  const int seq_id = blockIdx.x;      // token position
  const int batch_id = blockIdx.y;    // batch index
  const int matrix_id = blockIdx.z;   // Q=0, K=1, V=2

  const int qk_head_size = blockDim.x;
  const int num_heads = blockDim.y;
  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  // Only V (matrix 2) uses v_head_size.
  const int head_size = (matrix_id == 2) ? v_head_size : qk_head_size;
  if (elem_id >= head_size) {
    return;  // thread falls outside the current matrix's head
  }

  const int qk_hidden = num_heads * qk_head_size;                                   // N * H
  const int packed_hidden = num_heads * (qk_head_size + qk_head_size + v_head_size);  // one packed token
  const int hidden = num_heads * head_size;                                         // N * head_size
  const int in_head = head_id * head_size + elem_id;                                // offset inside a token

  const int src = batch_id * (packed_hidden * sequence_length) +
                  seq_id * packed_hidden +
                  matrix_id * qk_hidden +
                  in_head;

  const int dst = matrix_id * (qk_hidden * sequence_length * batch_size) +
                  batch_id * (hidden * sequence_length) +
                  seq_id * hidden +
                  in_head;

  const int bias = matrix_id * qk_hidden + in_head;

  output[dst] = input[src] + biases[bias];
}
|
||||
|
||||
template <typename T>
__global__ void AddBiasTransposeCutlass(int M, const T* input, const T* biases, T* output) {
  // Format 3 (cutlass memory efficient attention), all matrices share one head size.
  //   Input : B x S x M x N x H
  //   Output: M x B x S x N x H
  //   Bias  : indexed as m * (N * H) + n * H + h
  // B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size.
  //
  // Launch geometry read by this kernel:
  //   grid  = (S, B, matrix id)   block = (H, N)
  const int elem_id = threadIdx.x;
  const int head_id = threadIdx.y;
  const int seq_id = blockIdx.x;
  const int batch_id = blockIdx.y;
  const int matrix_id = blockIdx.z;

  const int head_size = blockDim.x;
  const int num_heads = blockDim.y;
  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  const int hidden = num_heads * head_size;          // N * H
  const int seq_hidden = hidden * sequence_length;   // N * H * S

  if (elem_id < head_size) {
    const int head_base = head_id * head_size;
    const int src = head_base + (matrix_id + seq_id * M) * hidden + batch_id * seq_hidden * M + elem_id;
    const int dst = head_base + seq_id * hidden + batch_id * seq_hidden +
                    matrix_id * seq_hidden * batch_size + elem_id;
    output[dst] = input[src] + biases[matrix_id * hidden + head_base + elem_id];
  }
}
|
||||
|
||||
template <typename T>
__global__ void AddBiasTransposeCutlassLarge(const int head_size, const T* input, const T* biases, T* output,
                                             const int M) {
  // Format 3 (cutlass memory efficient attention) for packed QKV.
  //   Input : B x S x M x N x H
  //   Output: M x B x S x N x H
  // Unlike the non-"Large" kernel, head_size is passed explicitly and each thread walks the head
  // in steps of blockDim.x, so blockDim.x does not have to equal head_size.
  //
  // Launch geometry read by this kernel:
  //   grid  = (S, B, matrix id)   block = (step, N)
  const int head_id = threadIdx.y;
  const int seq_id = blockIdx.x;
  const int batch_id = blockIdx.y;
  const int matrix_id = blockIdx.z;

  const int step = blockDim.x;
  const int num_heads = blockDim.y;
  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  const int hidden = num_heads * head_size;          // N * H
  const int seq_hidden = hidden * sequence_length;   // N * H * S
  const int head_base = head_id * head_size;

  const int src_base = head_base + (matrix_id + seq_id * M) * hidden + batch_id * seq_hidden * M;
  const int dst_base = head_base + seq_id * hidden + batch_id * seq_hidden + matrix_id * seq_hidden * batch_size;
  const int bias_base = matrix_id * hidden + head_base;

  for (int elem_id = threadIdx.x; elem_id < head_size; elem_id += step) {
    output[dst_base + elem_id] = input[src_base + elem_id] + biases[bias_base + elem_id];
  }
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void AddBiasTranspose(const T* input, const T* biases, T* output) {
|
||||
// Format 0 for Separated Q, K, V (N*H <= 1024)
|
||||
|
|
@ -395,6 +499,13 @@ void InvokeAddBiasTranspose(
|
|||
ORT_ENFORCE(total_matrix_count == 3);
|
||||
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
|
||||
}
|
||||
} else if (format == 3) {
|
||||
if (v_head_size == -1 || qk_head_size == v_head_size) {
|
||||
AddBiasTransposeCutlass<T><<<grid, block, 0, stream>>>(total_matrix_count, input, biases, output);
|
||||
} else {
|
||||
ORT_ENFORCE(total_matrix_count == 3);
|
||||
AddBiasTransposeCutlass<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
|
||||
}
|
||||
} else { // format == 0
|
||||
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
|
||||
}
|
||||
|
|
@ -410,6 +521,13 @@ void InvokeAddBiasTranspose(
|
|||
// It is rare for hidden size > 4096 (for half precision) and qk_head_size != v_head_size.
|
||||
ORT_THROW("AddBiasTranspose (format 1) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
|
||||
}
|
||||
} else if (format == 3) {
|
||||
if (v_head_size == -1 || qk_head_size == v_head_size) {
|
||||
AddBiasTransposeCutlassLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output,
|
||||
total_matrix_count);
|
||||
} else {
|
||||
ORT_THROW("AddBiasTranspose (format 3) not implemented for hidden_size > max_threads_per_block when qk_head_size != v_head_size");
|
||||
}
|
||||
} else { // format 0
|
||||
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@ namespace cuda {
|
|||
// format 2:
|
||||
// input : (batch_size, sequence_length, num_matrices, num_heads, head_size)
|
||||
// output: (batch_size, sequence_length, num_heads, num_matrices, head_size)
|
||||
// format 3: (requires sequence_length = kv_sequence_length and qk_head_size = v_head_size when num_matrices == 3)
|
||||
// input: (batch_size, sequence_length, num_matrices, num_heads, head_size)
|
||||
// output: (num_matrices, batch_size, sequence_length, num_heads, head_size)
|
||||
template <typename T>
|
||||
void LaunchAddBiasTranspose(
|
||||
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@
|
|||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/attention.h"
|
||||
#include "contrib_ops/cuda/bert/bert_padding.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -41,8 +42,14 @@ Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB
|
|||
disable_fused_runner_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
|
||||
|
||||
enable_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
|
@ -102,12 +109,12 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
parameters.sequence_length == parameters.kv_sequence_length &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_flash_attention_, true);
|
||||
enable_trt_flash_attention_, true);
|
||||
if (use_causal_fused_runner) {
|
||||
// Here we assume that num_heads, head_size and is_unidirectional do not change for an Attention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
|
||||
enable_flash_attention_, parameters.scale));
|
||||
enable_trt_flash_attention_, parameters.scale));
|
||||
}
|
||||
|
||||
// Here we assume all causal kernels can be loaded into shared memory. TODO: add a function to check.
|
||||
|
|
@ -122,13 +129,13 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
parameters.sequence_length == parameters.kv_sequence_length &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_flash_attention_, false);
|
||||
enable_trt_flash_attention_, false);
|
||||
|
||||
if (use_fused_runner) {
|
||||
// Here we assume that num_heads, head_size and is_unidirectional do not change for an Attention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, parameters.head_size, sm, is_unidirectional_,
|
||||
enable_flash_attention_, parameters.scale));
|
||||
enable_trt_flash_attention_, parameters.scale));
|
||||
}
|
||||
|
||||
// In case some kernels are not loaded due to the shared memory limit, we need to double check here.
|
||||
|
|
@ -139,6 +146,18 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
bool use_memory_efficient_attention = fused_runner == nullptr &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
nullptr == mask_index && // TODO: support 1D mask
|
||||
nullptr == past &&
|
||||
nullptr == present &&
|
||||
nullptr == extra_add_qk &&
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
#else
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
#endif
|
||||
|
||||
cublasHandle_t cublas = GetCublasHandle(context);
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
|
@ -169,7 +188,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.sequence_length,
|
||||
parameters.kv_sequence_length,
|
||||
parameters.total_sequence_length,
|
||||
fused_runner);
|
||||
fused_runner,
|
||||
use_memory_efficient_attention);
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
|
@ -188,6 +208,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
data.fused_runner = reinterpret_cast<void*>(fused_runner);
|
||||
data.fused_cross_attention_kernel = nullptr;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
|
||||
return QkvToContext<CudaT>(device_prop, cublas, Stream(context), parameters, data);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@ class Attention final : public CudaKernel, public AttentionBase {
|
|||
|
||||
protected:
|
||||
bool disable_fused_runner_;
|
||||
bool enable_flash_attention_;
|
||||
bool enable_trt_flash_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ limitations under the License.
|
|||
#include "contrib_ops/cpu/bert/attention_base.h"
|
||||
#include "contrib_ops/cuda/bert/bert_padding.h"
|
||||
#include "contrib_ops/cuda/transformers/dump_cuda_tensor.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace cub;
|
||||
|
|
@ -101,11 +102,25 @@ size_t GetAttentionWorkspaceSize(
|
|||
size_t sequence_length,
|
||||
size_t kv_sequence_length,
|
||||
size_t total_sequence_length,
|
||||
void* fused_runner) {
|
||||
void* fused_runner,
|
||||
bool use_memory_efficient_attention) {
|
||||
// Note that q, k and v might need alignment for fused attention kernels.
|
||||
const size_t qkv_bytes = element_size * batch_size * num_heads *
|
||||
((sequence_length + kv_sequence_length) * qk_head_size + kv_sequence_length * v_head_size);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
if (use_memory_efficient_attention) {
|
||||
size_t fmha_buffer_bytes = 0;
|
||||
if (MemoryEfficientAttentionParams::need_workspace(v_head_size, element_size == sizeof(float))) {
|
||||
fmha_buffer_bytes = batch_size * sequence_length * num_heads * v_head_size * sizeof(float);
|
||||
}
|
||||
|
||||
return qkv_bytes + fmha_buffer_bytes;
|
||||
}
|
||||
#else
|
||||
ORT_UNUSED_PARAMETER(use_memory_efficient_attention);
|
||||
#endif
|
||||
|
||||
if (fused_runner != nullptr) {
|
||||
size_t sequence_offset_bytes = GetSequenceOffsetSize(static_cast<int>(batch_size), true);
|
||||
return qkv_bytes + sequence_offset_bytes;
|
||||
|
|
@ -259,12 +274,15 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
|
|||
const int v_head_size = parameters.v_head_size;
|
||||
const bool past_present_share_buffer = parameters.past_present_share_buffer;
|
||||
void* fused_runner = data.fused_runner;
|
||||
bool use_memory_efficient_attention = data.use_memory_efficient_attention;
|
||||
|
||||
T* qkv = data.workspace;
|
||||
|
||||
bool use_fused_kernel = (nullptr != fused_runner && data.bias != nullptr && !parameters.is_unidirectional);
|
||||
bool use_fused_causal = (nullptr != fused_runner && parameters.is_unidirectional);
|
||||
|
||||
// Default format for memory efficient attention.
|
||||
// When there is past state, the format shall be BxNxSxH, so we disable memory efficient attention when there is past.
|
||||
DUMP_ATTENTION_INIT();
|
||||
if (nullptr != data.gemm_buffer) {
|
||||
if (data.bias == nullptr) {
|
||||
|
|
@ -277,13 +295,16 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
|
|||
qkv_format = AttentionQkvFormat::Q_K_V_BNSH;
|
||||
} else {
|
||||
// For fused TRT attention, transpose qkv to BxSxNx3xH (format 2)
|
||||
// For memory efficient attention, transpose to 3xBxSxNxH (format 3)
|
||||
// For unfused kernel, transpose to 3xBxNxSxH (format 1)
|
||||
// For fused causal kernel, use format 1 since we need to have K and V in BNSH format to update present state,
|
||||
// we also update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
|
||||
const int format = (use_fused_kernel ? 2 : 1);
|
||||
// For fused causal kernel, use format 1 since we need to have K and V to update present state,
|
||||
// at the same time, we update gemm_buffer BxSx3xNxH with bias which is used as input for fused causal kernel.
|
||||
const int format = (use_fused_kernel ? 2 : (use_memory_efficient_attention ? 3 : 1));
|
||||
qkv_format = use_fused_kernel
|
||||
? AttentionQkvFormat::QKV_BSN3H
|
||||
: (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH);
|
||||
: (use_memory_efficient_attention
|
||||
? AttentionQkvFormat::Q_K_V_BSNH
|
||||
: (use_fused_causal ? AttentionQkvFormat::Q_K_V_BNSH_QKV_BS3NH : AttentionQkvFormat::Q_K_V_BNSH));
|
||||
|
||||
// For fused causal, we will update gemm_buffer with bias directly.
|
||||
T* qkv_add_bias = use_fused_causal ? data.gemm_buffer : nullptr;
|
||||
|
|
@ -318,7 +339,21 @@ Status PrepareQkv(contrib::AttentionParameters& parameters,
|
|||
data.bias, data.query, data.key, data.value, qkv, true, kv_sequence_length);
|
||||
|
||||
qkv_format = AttentionQkvFormat::Q_KV_BSNH_BSN2H;
|
||||
} else if (use_fused_kernel) {
|
||||
}
|
||||
#if USE_FLASH_ATTENTION
|
||||
else if (use_memory_efficient_attention) {
|
||||
LaunchAddBias(stream, max_threads_per_block,
|
||||
batch_size, sequence_length, kv_sequence_length,
|
||||
num_heads, qk_head_size, v_head_size,
|
||||
data.bias, data.query, data.key, data.value, q, k, v);
|
||||
|
||||
DUMP_ATTENTION_D("q(BSNH)", q, batch_size * sequence_length, num_heads, qk_head_size);
|
||||
DUMP_ATTENTION_D("k(BSNH)", k, batch_size * kv_sequence_length, num_heads, qk_head_size);
|
||||
DUMP_ATTENTION_D("v(BSNH)", v, batch_size * kv_sequence_length, num_heads, v_head_size);
|
||||
qkv_format = AttentionQkvFormat::Q_K_V_BSNH;
|
||||
}
|
||||
#endif
|
||||
else if (use_fused_kernel) {
|
||||
assert(qk_head_size == v_head_size);
|
||||
|
||||
// Q (BxSxNxH), K (BxSxNxH), V (BxSxNxH) => BxSxNx(H + H + H)
|
||||
|
|
@ -382,9 +417,10 @@ Status QkvToContext(
|
|||
const bool past_present_share_buffer = parameters.past_present_share_buffer;
|
||||
const float mask_filter_value = parameters.mask_filter_value;
|
||||
void* fused_runner = data.fused_runner;
|
||||
bool use_memory_efficient_attention = data.use_memory_efficient_attention;
|
||||
|
||||
// At most one fused kernel is enabled.
|
||||
assert(int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
|
||||
assert(int(use_memory_efficient_attention) + int(fused_runner != nullptr) + int(data.fused_cross_attention_kernel != nullptr) <= 1);
|
||||
|
||||
const int batches = batch_size * num_heads;
|
||||
const int size_per_batch_q = sequence_length * qk_head_size;
|
||||
|
|
@ -433,6 +469,7 @@ Status QkvToContext(
|
|||
assert(data.fused_cross_attention_kernel == nullptr);
|
||||
assert(!use_fused_kernel);
|
||||
assert(data.gemm_buffer != nullptr);
|
||||
assert(!use_memory_efficient_attention);
|
||||
|
||||
if (data.present != data.past) {
|
||||
// For easy testing. Production should better avoid this path.
|
||||
|
|
@ -526,6 +563,38 @@ Status QkvToContext(
|
|||
return Status::OK();
|
||||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
if (use_memory_efficient_attention) {
|
||||
// We only enable fused cross attention when there is no key padding mask.
|
||||
// Otherwise, key have effective batch size 2 * batch_size, which is different from batch_size of query.
|
||||
assert(data.mask_index == nullptr);
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BSNH);
|
||||
|
||||
MemoryEfficientAttentionParams p;
|
||||
p.sm = device_prop.major * 10 + device_prop.minor;
|
||||
p.is_half = sizeof(T) == 2;
|
||||
p.batch_size = data.mask_index == nullptr ? parameters.batch_size : 2 * parameters.batch_size;
|
||||
p.num_heads = parameters.num_heads;
|
||||
p.sequence_length = parameters.sequence_length;
|
||||
p.kv_sequence_length = parameters.total_sequence_length;
|
||||
p.qk_head_size = parameters.head_size;
|
||||
p.v_head_size = parameters.v_head_size;
|
||||
p.causal = parameters.is_unidirectional;
|
||||
p.cu_seqlens_q = nullptr;
|
||||
p.cu_seqlens_k = nullptr;
|
||||
p.query = q;
|
||||
p.key = k;
|
||||
p.value = v;
|
||||
p.output = data.output;
|
||||
p.workspace = MemoryEfficientAttentionParams::need_workspace(v_head_size, sizeof(T) == sizeof(float)) ? scratch1 : nullptr;
|
||||
p.stream = stream;
|
||||
run_memory_efficient_attention(p);
|
||||
|
||||
DUMP_ATTENTION("cutlass output", data.output, batch_size * sequence_length, num_heads, v_head_size);
|
||||
return Status::OK();
|
||||
}
|
||||
#endif
|
||||
|
||||
// The following are unfused attention.
|
||||
assert(qkv_format == AttentionQkvFormat::Q_K_V_BNSH);
|
||||
const int* mask_index = data.mask_index;
|
||||
|
|
|
|||
|
|
@ -27,7 +27,8 @@ size_t GetAttentionWorkspaceSize(
|
|||
size_t sequence_length,
|
||||
size_t kv_sequence_length,
|
||||
size_t total_sequence_length,
|
||||
void* fused_runner);
|
||||
void* fused_runner,
|
||||
bool use_memory_efficient_attention = false);
|
||||
|
||||
template <typename T>
|
||||
struct AttentionData {
|
||||
|
|
@ -48,6 +49,8 @@ struct AttentionData {
|
|||
|
||||
void* fused_runner;
|
||||
const void* fused_cross_attention_kernel;
|
||||
|
||||
bool use_memory_efficient_attention;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
|
|
|
|||
|
|
@ -0,0 +1,116 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic push
|
||||
#pragma GCC diagnostic ignored "-Wunused-parameter"
|
||||
#pragma GCC diagnostic ignored "-Wstrict-aliasing"
|
||||
#endif
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
// Fills the CUTLASS attention kernel parameters from `params` and launches the kernel on params.stream.
//
// Template arguments select the concrete AttentionKernel instantiation (element type, target
// architecture, alignment assumption, tile sizes and single_value_iteration mode).
//
// Throws (via ORT_ENFORCE) when:
//   * the kernel needs an output accumulator buffer but params.workspace is null,
//   * the kernel needs more than 0xc000 bytes of shared memory and params.sm < 70,
//   * raising the dynamic shared memory limit for the kernel fails,
//   * Attention::check_supported rejects the parameters.
template <typename T, typename ArchTag, bool is_aligned, int queries_per_block, int keys_per_block, bool single_value_iteration>
void LaunchCutlassFmha(const MemoryEfficientAttentionParams& params) {
  using Attention = AttentionKernel<T, ArchTag, is_aligned, queries_per_block, keys_per_block, single_value_iteration>;
  typename Attention::Params p;
  {  // set parameters
    p.query_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.query));
    p.key_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.key));
    p.value_ptr = const_cast<T*>(reinterpret_cast<const T*>(params.value));
    p.cu_seqlens_q_ptr = params.cu_seqlens_q;
    p.cu_seqlens_k_ptr = params.cu_seqlens_k;

    p.logsumexp_ptr = nullptr;  // [num_heads, num_queries] for backward or nullptr for forward
    p.output_ptr = reinterpret_cast<T*>(params.output);
    if (Attention::kNeedsOutputAccumulatorBuffer) {
      using Acc = typename Attention::accum_t;
      // workspace size: batch_size * sequence_length * num_heads * v_head_size * sizeof(float)
      ORT_ENFORCE(params.workspace != nullptr, "Need output accumulator buffer but no workspace provided");
      p.output_accum_ptr = reinterpret_cast<Acc*>(params.workspace);
    } else {
      p.output_accum_ptr = nullptr;
    }
    p.num_heads = params.num_heads;
    p.num_batches = params.batch_size;
    p.head_dim = params.qk_head_size;
    p.head_dim_value = params.v_head_size;

    // When params.cu_seqlens_q is provided, num_queries is max_seq_q and num_keys will be set inside the kernel
    p.num_queries = params.sequence_length;
    p.num_keys = params.kv_sequence_length;

    p.causal = params.causal;

    // Input format is BxSxNxH, output is BxSxNxH
    p.q_strideH = params.qk_head_size;
    p.k_strideH = params.qk_head_size;
    p.v_strideH = params.v_head_size;
    p.o_strideH = params.v_head_size;

    p.q_strideM = params.num_heads * params.qk_head_size;
    p.k_strideM = params.num_heads * params.qk_head_size;
    p.v_strideM = params.num_heads * params.v_head_size;

    // Batch strides are computed in 64 bits to avoid overflowing int for large inputs.
    p.q_strideB = static_cast<int64_t>(p.q_strideM) * params.sequence_length;
    p.k_strideB = static_cast<int64_t>(p.k_strideM) * params.kv_sequence_length;
    p.v_strideB = static_cast<int64_t>(p.v_strideM) * params.kv_sequence_length;
    p.o_strideB = static_cast<int64_t>(params.num_heads) * params.v_head_size * params.sequence_length;
  }

  constexpr auto kernel_fn = attention_kernel_batched_impl<Attention>;
  int smem_bytes = static_cast<int>(sizeof(typename Attention::SharedStorage));
  if (smem_bytes > 0xc000) {
    // Requests above 0xc000 bytes (48 KB) need an explicit opt-in on the kernel function.
    ORT_ENFORCE(params.sm >= 70, "This kernel requires too much shared memory on this machine!");
    // The static initializer runs the opt-in once per template instantiation.
    static bool once = [&]() {
      // Surface a failed opt-in here instead of letting the later launch fail silently.
      ORT_ENFORCE(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes) ==
                      cudaSuccess,
                  "Failed to set max dynamic shared memory size for memory efficient attention kernel");
      return true;
    }();
    static_cast<void>(once);  // only the initializer's side effect is needed
  }

  ORT_ENFORCE(Attention::check_supported(p));
  kernel_fn<<<p.getBlocksGrid(), p.getThreadsGrid(), smem_bytes, params.stream>>>(p);
}
|
||||
|
||||
template <typename T, typename ArchTag, int queries_per_block, int keys_per_block, bool single_value_iteration>
|
||||
void DispatchIsAligned(const MemoryEfficientAttentionParams& params) {
|
||||
using AlignedAK = AttentionKernel<T, ArchTag, true, queries_per_block, keys_per_block, single_value_iteration>;
|
||||
|
||||
// Run a more efficient kernel with `isAligned=True` when memory is correctly aligned.
|
||||
bool is_aligned = params.qk_head_size % AlignedAK::kAlignmentQ == 0 &&
|
||||
params.qk_head_size % AlignedAK::kAlignmentK == 0 &&
|
||||
params.v_head_size % AlignedAK::kAlignmentV == 0;
|
||||
|
||||
DISPATCH_BOOL(is_aligned, kIsAligned, ([&]() {
|
||||
LaunchCutlassFmha<T, ArchTag, kIsAligned, queries_per_block, keys_per_block, single_value_iteration>(params);
|
||||
}));
|
||||
}
|
||||
|
||||
template <typename T, typename ArchTag>
|
||||
void DispatchBlockSize(const MemoryEfficientAttentionParams& params) {
|
||||
if (params.v_head_size <= 64) {
|
||||
DispatchIsAligned<T, ArchTag, 64, 64, true>(params);
|
||||
} else if (params.v_head_size <= 128) {
|
||||
DispatchIsAligned<T, ArchTag, 32, 128, true>(params);
|
||||
} else {
|
||||
DispatchIsAligned<T, ArchTag, 32, 128, false>(params);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#if defined(__GNUC__)
|
||||
#pragma GCC diagnostic pop
|
||||
#endif
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu
Normal file
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm50.cu
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params) {
|
||||
if (params.is_half) {
|
||||
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm50>(params);
|
||||
} else {
|
||||
DispatchBlockSize<float, cutlass::arch::Sm50>(params);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu
Normal file
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm70.cu
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params) {
|
||||
if (params.is_half) {
|
||||
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm70>(params);
|
||||
} else {
|
||||
DispatchBlockSize<float, cutlass::arch::Sm70>(params);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu
Normal file
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm75.cu
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params) {
|
||||
if (params.is_half) {
|
||||
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm75>(params);
|
||||
} else {
|
||||
DispatchBlockSize<float, cutlass::arch::Sm75>(params);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu
Normal file
24
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/fmha_sm80.cu
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/fmha_launch_template.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params) {
|
||||
if (params.is_half) {
|
||||
DispatchBlockSize<cutlass::half_t, cutlass::arch::Sm80>(params);
|
||||
} else {
|
||||
DispatchBlockSize<float, cutlass::arch::Sm80>(params);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
947
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
Normal file
947
onnxruntime/contrib_ops/cuda/bert/cutlass_fmha/kernel_forward.h
Normal file
|
|
@ -0,0 +1,947 @@
|
|||
/***************************************************************************************************
|
||||
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
* SPDX-License-Identifier: BSD-3-Clause
|
||||
*
|
||||
* Redistribution and use in source and binary forms, with or without
|
||||
* modification, are permitted provided that the following conditions are met:
|
||||
*
|
||||
* 1. Redistributions of source code must retain the above copyright notice, this
|
||||
* list of conditions and the following disclaimer.
|
||||
*
|
||||
* 2. Redistributions in binary form must reproduce the above copyright notice,
|
||||
* this list of conditions and the following disclaimer in the documentation
|
||||
* and/or other materials provided with the distribution.
|
||||
*
|
||||
* 3. Neither the name of the copyright holder nor the names of its
|
||||
* contributors may be used to endorse or promote products derived from
|
||||
* this software without specific prior written permission.
|
||||
*
|
||||
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
*
|
||||
**************************************************************************************************/
|
||||
|
||||
#pragma once
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include <cmath>
|
||||
#include <vector>
|
||||
|
||||
#include "cutlass/bfloat16.h"
|
||||
#include "cutlass/gemm/gemm.h"
|
||||
#include "cutlass/layout/matrix.h"
|
||||
#include "cutlass/layout/vector.h"
|
||||
#include "cutlass/numeric_types.h"
|
||||
|
||||
#include "41_fused_multi_head_attention/attention_scaling_coefs_updater.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_simt.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h"
|
||||
#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h"
|
||||
#include "cutlass/gemm/device/default_gemm_configuration.h"
|
||||
#include "cutlass/gemm/kernel/default_gemm.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_simt.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm70.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm75.h"
|
||||
#include "cutlass/gemm/threadblock/default_mma_core_sm80.h"
|
||||
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
|
||||
#include "cutlass/matrix_shape.h"
|
||||
#include "cutlass/platform/platform.h"
|
||||
#include "cutlass/transform/threadblock/predicated_tile_iterator.h"
|
||||
#include "41_fused_multi_head_attention/debug_utils.h"
|
||||
#include "41_fused_multi_head_attention/epilogue_pipelined.h"
|
||||
#include "41_fused_multi_head_attention/epilogue_rescale_output.h"
|
||||
#include "41_fused_multi_head_attention/find_default_mma.h"
|
||||
#include "41_fused_multi_head_attention/gemm_kernel_utils.h"
|
||||
#include "41_fused_multi_head_attention/mma_from_smem.h"
|
||||
|
||||
#include <inttypes.h>
|
||||
|
||||
using namespace gemm_kernel_utils;
|
||||
|
||||
namespace {
|
||||
template <typename scalar_t, typename Arch>
|
||||
constexpr int getWarpsPerSm() {
|
||||
return (
|
||||
Arch::kMinComputeCapability >= 80 &&
|
||||
!cutlass::platform::is_same<scalar_t, float>::value
|
||||
? 16
|
||||
: 12);
|
||||
}
|
||||
} // namespace
|
||||
|
||||
template <
|
||||
// The datatype of Q/K/V
|
||||
typename scalar_t_,
|
||||
// Architecture we are targeting (eg `cutlass::arch::Sm80`)
|
||||
typename ArchTag,
|
||||
// If Q/K/V are correctly aligned in memory and we can run a fast kernel
|
||||
bool isAligned_,
|
||||
int kQueriesPerBlock,
|
||||
int kKeysPerBlock,
|
||||
bool kSingleValueIteration // = `value.shape[-1] <= kKeysPerBlock`
|
||||
>
|
||||
struct AttentionKernel {
|
||||
using scalar_t = scalar_t_;
|
||||
using accum_t = float;
|
||||
using lse_scalar_t = float;
|
||||
using output_t = scalar_t;
|
||||
// Accumulator between 2 iterations
|
||||
// Using `accum_t` improves perf on f16 at the cost of
|
||||
// numerical errors
|
||||
using output_accum_t = accum_t;
|
||||
static constexpr bool kIsAligned = isAligned_;
|
||||
static constexpr int32_t kAlignLSE = 32; // block size of backward
|
||||
static constexpr bool kPreloadV = ArchTag::kMinComputeCapability >= 80 &&
|
||||
cutlass::sizeof_bits<scalar_t>::value == 16;
|
||||
static constexpr bool kKeepOutputInRF = kSingleValueIteration;
|
||||
static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF &&
|
||||
!cutlass::platform::is_same<output_accum_t, output_t>::value;
|
||||
|
||||
static_assert(kQueriesPerBlock % 32 == 0, "");
|
||||
static_assert(kKeysPerBlock % 32 == 0, "");
|
||||
static constexpr int kNumWarpsPerBlock =
|
||||
kQueriesPerBlock * kKeysPerBlock / (32 * 32);
|
||||
static constexpr int kWarpSize = 32;
|
||||
|
||||
// Launch bounds
|
||||
static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock;
|
||||
static constexpr int kMinBlocksPerSm =
|
||||
getWarpsPerSm<scalar_t, ArchTag>() / kNumWarpsPerBlock;
|
||||
|
||||
struct Params {
|
||||
// Input tensors
|
||||
scalar_t* query_ptr; // [num_queries, num_heads, head_dim]
|
||||
scalar_t* key_ptr; // [num_keys, num_heads, head_dim]
|
||||
scalar_t* value_ptr; // [num_keys, num_heads, head_dim_value]
|
||||
int32_t* cu_seqlens_q_ptr = nullptr;
|
||||
int32_t* cu_seqlens_k_ptr = nullptr;
|
||||
|
||||
// Output tensors
|
||||
output_t* output_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
output_accum_t*
|
||||
output_accum_ptr; // [num_queries, num_heads, head_dim_value]
|
||||
lse_scalar_t* logsumexp_ptr; // [num_heads, num_queries] - can be null
|
||||
|
||||
// Dimensions/strides
|
||||
int32_t head_dim;
|
||||
int32_t head_dim_value;
|
||||
int32_t num_queries;
|
||||
int32_t num_keys;
|
||||
|
||||
bool causal;
|
||||
|
||||
int32_t q_strideM;
|
||||
int32_t k_strideM;
|
||||
int32_t v_strideM;
|
||||
|
||||
// Everything below is only used in `advance_to_block`
|
||||
// and shouldn't use registers
|
||||
int32_t q_strideH;
|
||||
int32_t k_strideH;
|
||||
int32_t v_strideH;
|
||||
int32_t o_strideH;
|
||||
int64_t q_strideB;
|
||||
int64_t k_strideB;
|
||||
int64_t v_strideB;
|
||||
int64_t o_strideB;
|
||||
int32_t num_batches;
|
||||
int32_t num_heads;
|
||||
|
||||
// https://github.com/NVIDIA/cutlass/issues/771
|
||||
CUTLASS_HOST_DEVICE int32_t o_strideM() const {
|
||||
return head_dim_value * num_heads;
|
||||
}
|
||||
|
||||
// Moves pointers to what we should process
|
||||
// Returns "false" if there is no work to do
|
||||
CUTLASS_DEVICE bool advance_to_block() {
|
||||
auto batch_id = blockIdx.z;
|
||||
auto head_id = blockIdx.y;
|
||||
auto query_start = blockIdx.x * kQueriesPerBlock;
|
||||
|
||||
auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE;
|
||||
|
||||
int64_t q_start, k_start;
|
||||
// Advance to current batch - in case of different sequence lengths
|
||||
if (cu_seqlens_q_ptr != nullptr) {
|
||||
assert(cu_seqlens_k_ptr != nullptr);
|
||||
cu_seqlens_q_ptr += batch_id;
|
||||
cu_seqlens_k_ptr += batch_id;
|
||||
q_start = cu_seqlens_q_ptr[0];
|
||||
k_start = cu_seqlens_k_ptr[0];
|
||||
int64_t q_next_start = cu_seqlens_q_ptr[1];
|
||||
int64_t k_next_start = cu_seqlens_k_ptr[1];
|
||||
num_queries = q_next_start - q_start;
|
||||
num_keys = k_next_start - k_start;
|
||||
|
||||
if (query_start >= num_queries) {
|
||||
return false;
|
||||
}
|
||||
} else {
|
||||
query_ptr += batch_id * q_strideB;
|
||||
key_ptr += batch_id * k_strideB;
|
||||
value_ptr += batch_id * v_strideB;
|
||||
output_ptr += batch_id * o_strideB;
|
||||
if (output_accum_ptr != nullptr) {
|
||||
output_accum_ptr += batch_id * o_strideB;
|
||||
}
|
||||
q_start = 0;
|
||||
k_start = 0;
|
||||
}
|
||||
|
||||
// Advance to the current batch / head / query_start
|
||||
query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH;
|
||||
key_ptr += k_start * k_strideM + head_id * k_strideH;
|
||||
value_ptr += k_start * v_strideM + head_id * v_strideH;
|
||||
output_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||
head_id * o_strideH;
|
||||
|
||||
if (output_accum_ptr != nullptr) {
|
||||
output_accum_ptr += int64_t(q_start + query_start) * o_strideM() +
|
||||
head_id * o_strideH;
|
||||
} else {
|
||||
// Accumulate directly in the destination buffer (eg for f32)
|
||||
output_accum_ptr = (accum_t*)output_ptr;
|
||||
}
|
||||
if (logsumexp_ptr != nullptr) {
|
||||
// lse[batch_id, head_id, query_start]
|
||||
logsumexp_ptr +=
|
||||
batch_id * lse_dim * num_heads + head_id * lse_dim + query_start;
|
||||
}
|
||||
|
||||
num_queries -= query_start;
|
||||
if (causal) {
|
||||
num_keys = cutlass::fast_min(
|
||||
int32_t(query_start + kQueriesPerBlock), num_keys);
|
||||
}
|
||||
num_batches = 0; // no longer used after
|
||||
|
||||
// Make sure the compiler knows these variables are the same on all
|
||||
// the threads of the warp.
|
||||
query_ptr = warp_uniform(query_ptr);
|
||||
key_ptr = warp_uniform(key_ptr);
|
||||
value_ptr = warp_uniform(value_ptr);
|
||||
output_ptr = warp_uniform(output_ptr);
|
||||
output_accum_ptr = warp_uniform(output_accum_ptr);
|
||||
logsumexp_ptr = warp_uniform(logsumexp_ptr);
|
||||
num_queries = warp_uniform(num_queries);
|
||||
num_keys = warp_uniform(num_keys);
|
||||
head_dim = warp_uniform(head_dim);
|
||||
head_dim_value = warp_uniform(head_dim_value);
|
||||
return true;
|
||||
}
|
||||
|
||||
__host__ dim3 getBlocksGrid() const {
|
||||
return dim3(
|
||||
ceil_div(num_queries, (int32_t)kQueriesPerBlock),
|
||||
num_heads,
|
||||
num_batches);
|
||||
}
|
||||
__host__ dim3 getThreadsGrid() const {
|
||||
return dim3(kWarpSize, kNumWarpsPerBlock, 1);
|
||||
}
|
||||
};
|
||||
|
||||
struct MM0 {
|
||||
/*
|
||||
In this first matmul, we compute a block of `Q @ K.T`.
|
||||
While the calculation result is still hot in registers, we update
|
||||
`mi`, `m_prime`, `s_prime` in shared-memory, and then store this value
|
||||
into a shared-memory ("AccumulatorSharedStorage") that is used later as
|
||||
operand A for the second matmul (see MM1)
|
||||
*/
|
||||
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
|
||||
|
||||
using OpClass = typename GemmType::OpClass;
|
||||
using DefaultConfig =
|
||||
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
OpClass,
|
||||
ArchTag,
|
||||
scalar_t,
|
||||
scalar_t,
|
||||
scalar_t, // ElementC
|
||||
accum_t // ElementAccumulator
|
||||
>;
|
||||
static constexpr int kAlignmentA =
|
||||
kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment;
|
||||
static constexpr int kAlignmentB =
|
||||
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
|
||||
using ThreadblockShape = cutlass::gemm::
|
||||
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
|
||||
using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma<
|
||||
scalar_t, // ElementA,
|
||||
cutlass::layout::RowMajor, // LayoutA,
|
||||
kAlignmentA,
|
||||
scalar_t, // ElementB,
|
||||
cutlass::layout::ColumnMajor, // LayoutB,
|
||||
kAlignmentB,
|
||||
accum_t,
|
||||
cutlass::layout::RowMajor, // LayoutC,
|
||||
OpClass,
|
||||
ArchTag, // ArchTag
|
||||
ThreadblockShape, // ThreadblockShape
|
||||
WarpShape, // WarpShape
|
||||
typename GemmType::InstructionShape, // InstructionShape
|
||||
DefaultConfig::kStages, // Should use `DefaultConfig::kStages`, but that
|
||||
// uses too much smem
|
||||
typename GemmType::Operator // Operator
|
||||
>::DefaultMma;
|
||||
using MmaCore = typename DefaultMma::MmaCore;
|
||||
using IteratorA = typename DefaultMma::IteratorA;
|
||||
using IteratorB = typename DefaultMma::IteratorB;
|
||||
using Mma = typename DefaultMma::ThreadblockMma;
|
||||
using ScalingCoefsUpdater = typename DefaultAttentionScalingCoefsUpdater<
|
||||
typename Mma::Operator::IteratorC,
|
||||
accum_t,
|
||||
kWarpSize>::Updater;
|
||||
static_assert(
|
||||
MmaCore::WarpCount::kM * MmaCore::WarpCount::kN *
|
||||
MmaCore::WarpCount::kK ==
|
||||
kNumWarpsPerBlock,
|
||||
"");
|
||||
|
||||
// Epilogue to store to shared-memory in a format that we can use later for
|
||||
// the second matmul
|
||||
using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm<
|
||||
typename Mma::Operator::IteratorC,
|
||||
typename Mma::Operator,
|
||||
scalar_t,
|
||||
WarpShape,
|
||||
ThreadblockShape>;
|
||||
using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage;
|
||||
};
|
||||
|
||||
struct MM1 {
|
||||
/**
|
||||
Second matmul: perform `attn @ V` where `attn` is the attention (not
|
||||
normalized) and stored in shared memory
|
||||
*/
|
||||
using GemmType = DefaultGemmType<ArchTag, scalar_t>;
|
||||
|
||||
using OpClass = typename GemmType::OpClass;
|
||||
using DefaultConfig =
|
||||
typename cutlass::gemm::device::DefaultGemmConfiguration<
|
||||
OpClass,
|
||||
ArchTag,
|
||||
scalar_t,
|
||||
scalar_t,
|
||||
output_accum_t, // ElementC
|
||||
accum_t // ElementAccumulator
|
||||
>;
|
||||
static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem
|
||||
static constexpr int kAlignmentB =
|
||||
kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment;
|
||||
using ThreadblockShape = cutlass::gemm::
|
||||
GemmShape<kQueriesPerBlock, kKeysPerBlock, GemmType::ThreadK>;
|
||||
using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
|
||||
using InstructionShape = typename GemmType::InstructionShape;
|
||||
|
||||
using LayoutB = cutlass::layout::RowMajor;
|
||||
using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
|
||||
scalar_t, // ElementA,
|
||||
cutlass::layout::RowMajor, // LayoutA,
|
||||
kAlignmentA,
|
||||
scalar_t, // ElementB,
|
||||
LayoutB, // LayoutB,
|
||||
kAlignmentB,
|
||||
output_accum_t,
|
||||
cutlass::layout::RowMajor, // LayoutC,
|
||||
accum_t,
|
||||
OpClass,
|
||||
ArchTag,
|
||||
ThreadblockShape,
|
||||
WarpShape,
|
||||
typename GemmType::InstructionShape,
|
||||
typename DefaultConfig::EpilogueOutputOp,
|
||||
void, // ThreadblockSwizzle - not used
|
||||
DefaultConfig::kStages,
|
||||
false, // SplitKSerial
|
||||
typename GemmType::Operator>;
|
||||
|
||||
using DefaultMmaFromSmem =
|
||||
typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory<
|
||||
typename DefaultGemm::Mma,
|
||||
typename MM0::AccumulatorSharedStorage>;
|
||||
using Mma = typename DefaultMmaFromSmem::Mma;
|
||||
using IteratorB = typename Mma::IteratorB;
|
||||
using WarpCount = typename Mma::WarpCount;
|
||||
static_assert(
|
||||
WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock,
|
||||
"");
|
||||
|
||||
using DefaultEpilogue = typename DefaultGemm::Epilogue;
|
||||
using OutputTileIterator =
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_t>;
|
||||
using OutputTileIteratorAccum =
|
||||
typename cutlass::epilogue::threadblock::PredicatedTileIterator<
|
||||
typename DefaultEpilogue::OutputTileIterator::ThreadMap,
|
||||
output_accum_t>;
|
||||
|
||||
struct SharedStorageMM1 {
|
||||
typename Mma::SharedStorage mm;
|
||||
};
|
||||
};
|
||||
|
||||
static constexpr int64_t kAlignmentQ = MM0::kAlignmentA;
|
||||
static constexpr int64_t kAlignmentK = MM0::kAlignmentB;
|
||||
static constexpr int64_t kAlignmentV = 1;
|
||||
|
||||
// Shared storage - depends on kernel params
|
||||
struct ScalingCoefs {
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> m_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> s_prime;
|
||||
cutlass::Array<accum_t, kQueriesPerBlock> mi;
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueAtEnd : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
};
|
||||
|
||||
union {
|
||||
typename MM0::Mma::SharedStorage mm0;
|
||||
SharedStorageAfterMM0 after_mm0;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
||||
epilogue_shared_storage() {
|
||||
return epilogue;
|
||||
}
|
||||
};
|
||||
|
||||
struct SharedStorageEpilogueInLoop : ScalingCoefs {
|
||||
struct SharedStorageAfterMM0 {
|
||||
// Everything here might be overwritten during MM0
|
||||
typename MM0::AccumulatorSharedStorage si;
|
||||
typename MM1::SharedStorageMM1 mm1;
|
||||
typename MM1::DefaultEpilogue::SharedStorage epilogue;
|
||||
};
|
||||
|
||||
union {
|
||||
typename MM0::Mma::SharedStorage mm0;
|
||||
SharedStorageAfterMM0 after_mm0;
|
||||
};
|
||||
|
||||
CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage&
|
||||
epilogue_shared_storage() {
|
||||
return after_mm0.epilogue;
|
||||
}
|
||||
};
|
||||
|
||||
using SharedStorage = typename cutlass::platform::conditional<
|
||||
kSingleValueIteration || kKeepOutputInRF,
|
||||
SharedStorageEpilogueAtEnd,
|
||||
SharedStorageEpilogueInLoop>::type;
|
||||
|
||||
static bool __host__ check_supported(Params const& p) {
|
||||
CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ);
|
||||
CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK);
|
||||
CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV);
|
||||
XFORMERS_CHECK(
|
||||
p.q_strideM % kAlignmentQ == 0, "query is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.k_strideM % kAlignmentK == 0, "key is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.v_strideM % kAlignmentV == 0, "value is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.q_strideH % kAlignmentQ == 0, "query is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.k_strideH % kAlignmentK == 0, "key is not correctly aligned");
|
||||
XFORMERS_CHECK(
|
||||
p.v_strideH % kAlignmentV == 0, "value is not correctly aligned");
|
||||
return true;
|
||||
}
|
||||
|
||||
static void CUTLASS_DEVICE attention_kernel(Params& p) {
|
||||
// In this block, we will only ever:
|
||||
// - read query[query_start:query_end, :]
|
||||
// - write to output[query_start:query_end, :]
|
||||
|
||||
extern __shared__ char smem_buffer[];
|
||||
SharedStorage& shared_storage = *((SharedStorage*)smem_buffer);
|
||||
auto& m_prime = shared_storage.m_prime;
|
||||
auto& s_prime = shared_storage.s_prime;
|
||||
auto& mi = shared_storage.mi;
|
||||
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (thread_id() < kQueriesPerBlock) {
|
||||
s_prime[thread_id()] = accum_t(0);
|
||||
m_prime[thread_id()] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
mi[thread_id()] = -cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
typename MM1::Mma::FragmentC accum_o;
|
||||
accum_o.clear();
|
||||
|
||||
auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator {
|
||||
using OutputTileIterator = typename MM1::OutputTileIterator;
|
||||
return OutputTileIterator(
|
||||
typename OutputTileIterator::Params{(int32_t)p.o_strideM()},
|
||||
p.output_ptr,
|
||||
typename OutputTileIterator::TensorCoord{
|
||||
p.num_queries, p.head_dim_value},
|
||||
thread_id(),
|
||||
{0, col});
|
||||
};
|
||||
|
||||
auto createOutputAccumIter = [&](int col) ->
|
||||
typename MM1::OutputTileIteratorAccum {
|
||||
using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum;
|
||||
return OutputTileIteratorAccum(
|
||||
typename OutputTileIteratorAccum::Params{(int32_t)p.o_strideM()},
|
||||
p.output_accum_ptr,
|
||||
typename OutputTileIteratorAccum::TensorCoord{
|
||||
p.num_queries, p.head_dim_value},
|
||||
thread_id(),
|
||||
{0, col});
|
||||
};
|
||||
|
||||
// Iterate through keys
|
||||
for (int32_t iter_key_start = 0; iter_key_start < p.num_keys;
|
||||
iter_key_start += kKeysPerBlock) {
|
||||
int32_t problem_size_0_m =
|
||||
cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries);
|
||||
int32_t problem_size_0_n = cutlass::fast_min(
|
||||
int32_t(kKeysPerBlock), p.num_keys - iter_key_start);
|
||||
int32_t const& problem_size_0_k = p.head_dim;
|
||||
int32_t const& problem_size_1_n = p.head_dim_value;
|
||||
int32_t const& problem_size_1_k = problem_size_0_n;
|
||||
|
||||
auto prologueV = [&](int blockN) {
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
|
||||
p.value_ptr + iter_key_start * p.v_strideM,
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
MM1::Mma::prologue(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
iterator_V,
|
||||
thread_id(),
|
||||
problem_size_1_k);
|
||||
};
|
||||
|
||||
__syncthreads(); // Need to have shared memory initialized, and `m_prime`
|
||||
// updated from end of prev iter
|
||||
//
|
||||
// MATMUL: Q.K_t
|
||||
//
|
||||
// Computes the block-matrix product of:
|
||||
// (a) query[query_start:query_end, :]
|
||||
// with
|
||||
// (b) key[iter_key_start:iter_key_start + kKeysPerBlock]
|
||||
// and stores that into `shared_storage.si`
|
||||
//
|
||||
|
||||
// Compute threadblock location
|
||||
cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_A{
|
||||
tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()};
|
||||
|
||||
cutlass::MatrixCoord tb_offset_B{
|
||||
tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN};
|
||||
|
||||
// Construct iterators to A and B operands
|
||||
typename MM0::IteratorA iterator_A(
|
||||
typename MM0::IteratorA::Params(
|
||||
typename MM0::MmaCore::LayoutA(p.q_strideM)),
|
||||
p.query_ptr,
|
||||
{problem_size_0_m, problem_size_0_k},
|
||||
thread_id(),
|
||||
tb_offset_A);
|
||||
|
||||
typename MM0::IteratorB iterator_B(
|
||||
typename MM0::IteratorB::Params(
|
||||
typename MM0::MmaCore::LayoutB(p.k_strideM)),
|
||||
p.key_ptr + iter_key_start * p.k_strideM,
|
||||
{problem_size_0_k, problem_size_0_n},
|
||||
thread_id(),
|
||||
tb_offset_B);
|
||||
|
||||
auto my_warp_id = warp_id();
|
||||
auto my_lane_id = lane_id();
|
||||
|
||||
// Construct thread-scoped matrix multiply
|
||||
typename MM0::Mma mma(
|
||||
shared_storage.mm0, thread_id(), my_warp_id, my_lane_id);
|
||||
|
||||
typename MM0::Mma::FragmentC accum;
|
||||
|
||||
accum.clear();
|
||||
|
||||
auto gemm_k_iterations =
|
||||
(problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add
|
||||
mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum);
|
||||
__syncthreads();
|
||||
|
||||
if (kPreloadV) {
|
||||
prologueV(0);
|
||||
}
|
||||
|
||||
typename MM0::Mma::Operator::IteratorC::TensorCoord
|
||||
iteratorC_tile_offset = {
|
||||
(tb_tile_offset.m() * MM0::Mma::WarpCount::kM) +
|
||||
(my_warp_id % MM0::Mma::WarpCount::kM),
|
||||
(tb_tile_offset.n() * MM0::Mma::WarpCount::kN) +
|
||||
(my_warp_id / MM0::Mma::WarpCount::kM)};
|
||||
|
||||
// Mask out last if causal
|
||||
if (p.causal && p.num_keys - iter_key_start <= kKeysPerBlock) {
|
||||
auto query_start = blockIdx.x * kQueriesPerBlock;
|
||||
auto lane_offset = MM0::ScalingCoefsUpdater::get_lane_offset(
|
||||
lane_id(), warp_id(), iteratorC_tile_offset);
|
||||
int32_t last_col;
|
||||
MM0::ScalingCoefsUpdater::iterateRows(
|
||||
lane_offset,
|
||||
[&](int accum_m) {
|
||||
last_col = query_start + accum_m - iter_key_start;
|
||||
},
|
||||
[&](int accum_m, int accum_n, int idx) {
|
||||
if (accum_n > last_col) {
|
||||
accum[idx] =
|
||||
-cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
},
|
||||
[&](int accum_m) {});
|
||||
}
|
||||
DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
p.num_keys - iter_key_start >= kKeysPerBlock,
|
||||
kFullColumns,
|
||||
([&] {
|
||||
// Update `mi` from accum stored in registers
|
||||
// Also updates `accum` with accum[i] <-
|
||||
// exp(accum[i] * scale
|
||||
// - mi)
|
||||
MM0::ScalingCoefsUpdater::update<
|
||||
kQueriesPerBlock,
|
||||
kFullColumns,
|
||||
kIsFirst,
|
||||
kKeepOutputInRF>(
|
||||
accum_o,
|
||||
accum,
|
||||
mi,
|
||||
m_prime,
|
||||
s_prime,
|
||||
lane_id(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
p.num_keys - iter_key_start,
|
||||
iteratorC_tile_offset,
|
||||
1.0f / cutlass::fast_sqrt(float(p.head_dim)));
|
||||
}));
|
||||
}));
|
||||
|
||||
// Output results to shared-memory
|
||||
int warp_idx_mn_0 = my_warp_id %
|
||||
(MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN);
|
||||
auto output_tile_coords = cutlass::MatrixCoord{
|
||||
warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM,
|
||||
warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM};
|
||||
|
||||
MM0::B2bGemm::accumToSmem(
|
||||
shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords);
|
||||
|
||||
__syncthreads();
|
||||
|
||||
//
|
||||
// MATMUL: Attn . V
|
||||
// Run the matmul `attn @ V` for a block of attn and V.
|
||||
// `attn` is read from shared memory (in `shared_storage_si`)
|
||||
// `V` is read from global memory (with iterator_B)
|
||||
//
|
||||
|
||||
const int64_t nBlockN = kSingleValueIteration
|
||||
? 1
|
||||
: ceil_div(
|
||||
(int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN));
|
||||
for (int blockN = 0; blockN < nBlockN; ++blockN) {
|
||||
int gemm_k_iterations =
|
||||
(problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK;
|
||||
|
||||
// Compute threadblock-scoped matrix multiply-add and store it in accum
|
||||
// (in registers)
|
||||
if (!kPreloadV) {
|
||||
__syncthreads(); // we share shmem between mma and epilogue
|
||||
}
|
||||
|
||||
typename MM1::Mma::IteratorB iterator_V(
|
||||
typename MM1::IteratorB::Params{MM1::LayoutB(p.v_strideM)},
|
||||
p.value_ptr + iter_key_start * p.v_strideM,
|
||||
{problem_size_1_k, problem_size_1_n},
|
||||
thread_id(),
|
||||
cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN});
|
||||
typename MM1::Mma mma_pv(
|
||||
shared_storage.after_mm0.mm1.mm,
|
||||
shared_storage.after_mm0.si,
|
||||
(int)thread_id(),
|
||||
(int)warp_id(),
|
||||
(int)lane_id(),
|
||||
(int)problem_size_1_k);
|
||||
mma_pv.set_prologue_done(kPreloadV);
|
||||
if (!kKeepOutputInRF) {
|
||||
accum_o.clear();
|
||||
}
|
||||
mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o);
|
||||
__syncthreads();
|
||||
|
||||
if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) {
|
||||
prologueV(blockN + 1);
|
||||
}
|
||||
|
||||
if (!kKeepOutputInRF) {
|
||||
DISPATCH_BOOL(
|
||||
iter_key_start == 0, kIsFirst, ([&] {
|
||||
DISPATCH_BOOL(
|
||||
(iter_key_start + kKeysPerBlock) >= p.num_keys,
|
||||
kIsLast,
|
||||
([&] {
|
||||
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
||||
using DefaultOp =
|
||||
typename MM1::DefaultConfig::EpilogueOutputOp;
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
using EpilogueOutputOp = typename cutlass::epilogue::
|
||||
thread::MemoryEfficientAttentionNormalize<
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
output_t,
|
||||
output_accum_t>::type,
|
||||
output_accum_t,
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator,
|
||||
ElementCompute,
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue = typename cutlass::epilogue::threadblock::
|
||||
EpiloguePipelined<
|
||||
typename DefaultEpilogue::Shape,
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename cutlass::platform::conditional<
|
||||
kIsLast,
|
||||
typename MM1::OutputTileIterator,
|
||||
typename MM1::OutputTileIteratorAccum>::type,
|
||||
typename DefaultEpilogue::
|
||||
AccumulatorFragmentIterator,
|
||||
typename DefaultEpilogue::WarpTileIterator,
|
||||
typename DefaultEpilogue::SharedLoadIterator,
|
||||
EpilogueOutputOp,
|
||||
typename DefaultEpilogue::Padding,
|
||||
DefaultEpilogue::kFragmentsPerIteration,
|
||||
true, // IterationsUnroll
|
||||
typename MM1::OutputTileIteratorAccum // Read
|
||||
// iterator
|
||||
>;
|
||||
|
||||
int col = blockN * MM1::Mma::Shape::kN;
|
||||
auto source_iter = createOutputAccumIter(col);
|
||||
auto dest_iter = call_conditional<
|
||||
kIsLast,
|
||||
decltype(createOutputIter),
|
||||
decltype(createOutputAccumIter)>::
|
||||
apply(createOutputIter, createOutputAccumIter, col);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
epilogue(rescale, dest_iter, accum_o, source_iter);
|
||||
}));
|
||||
}));
|
||||
if (!kSingleValueIteration) {
|
||||
__syncthreads();
|
||||
}
|
||||
}
|
||||
}
|
||||
__syncthreads(); // we modify `m_prime` after
|
||||
}
|
||||
|
||||
if (kKeepOutputInRF) {
|
||||
constexpr bool kIsFirst = true;
|
||||
constexpr bool kIsLast = true;
|
||||
using DefaultEpilogue = typename MM1::DefaultEpilogue;
|
||||
using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp;
|
||||
using ElementCompute = typename DefaultOp::ElementCompute;
|
||||
using EpilogueOutputOp =
|
||||
typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize<
|
||||
output_t, // output
|
||||
output_accum_t, // source
|
||||
DefaultOp::kCount,
|
||||
typename DefaultOp::ElementAccumulator, // accum
|
||||
output_accum_t, // compute
|
||||
kIsFirst,
|
||||
kIsLast,
|
||||
cutlass::Array<ElementCompute, kQueriesPerBlock>>;
|
||||
using Epilogue =
|
||||
typename cutlass::epilogue::threadblock::EpiloguePipelined<
|
||||
typename DefaultEpilogue::Shape,
|
||||
typename MM1::Mma::Operator,
|
||||
DefaultEpilogue::kPartitionsK,
|
||||
typename MM1::OutputTileIterator, // destination
|
||||
typename DefaultEpilogue::AccumulatorFragmentIterator,
|
||||
typename DefaultEpilogue::WarpTileIterator,
|
||||
typename DefaultEpilogue::SharedLoadIterator,
|
||||
EpilogueOutputOp,
|
||||
typename DefaultEpilogue::Padding,
|
||||
DefaultEpilogue::kFragmentsPerIteration,
|
||||
true, // IterationsUnroll
|
||||
typename MM1::OutputTileIteratorAccum // source tile
|
||||
>;
|
||||
auto dest_iter = createOutputIter(0);
|
||||
EpilogueOutputOp rescale(s_prime, m_prime);
|
||||
Epilogue epilogue(
|
||||
shared_storage.epilogue_shared_storage(),
|
||||
thread_id(),
|
||||
warp_id(),
|
||||
lane_id());
|
||||
epilogue(rescale, dest_iter, accum_o);
|
||||
}
|
||||
|
||||
// 7. Calculate logsumexp
|
||||
// To make the backward easier, we pad logsumexp with `inf`
|
||||
// this avoids a few bound checks, and is not more expensive during fwd
|
||||
static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, "");
|
||||
if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) {
|
||||
auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE;
|
||||
if (thread_id() < p.num_queries) {
|
||||
p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()]) +
|
||||
cutlass::fast_log(accum_t(s_prime[thread_id()]));
|
||||
} else if (thread_id() < lse_dim) {
|
||||
p.logsumexp_ptr[thread_id()] =
|
||||
cutlass::platform::numeric_limits<accum_t>::infinity();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Returns threadIdx.x converted to int8_t.
static CUTLASS_DEVICE int8_t lane_id() {
  const auto x_index = threadIdx.x;
  return static_cast<int8_t>(x_index);
}
|
||||
// Returns threadIdx.y converted to int8_t.
static CUTLASS_DEVICE int8_t warp_id() {
  const auto y_index = threadIdx.y;
  return static_cast<int8_t>(y_index);
}
|
||||
// Returns the linear index threadIdx.x + threadIdx.y * blockDim.x,
// converted to int16_t.
static CUTLASS_DEVICE int16_t thread_id() {
  const auto linear_index = threadIdx.x + threadIdx.y * blockDim.x;
  return static_cast<int16_t>(linear_index);
}
|
||||
};
|
||||
|
||||
template <typename AK>
|
||||
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
||||
attention_kernel_batched_impl(typename AK::Params p) {
|
||||
if (!p.advance_to_block()) {
|
||||
return;
|
||||
}
|
||||
AK::attention_kernel(p);
|
||||
}
|
||||
|
||||
// Primary template, declared only. Explicit specializations
// (`template <>` ... attention_kernel_batched<Kernel>) are generated by the
// _ATTENTION_KERNEL_FORWARD_BEGIN / INSTANTIATE_ATTENTION_KERNEL_FORWARD*
// macros defined below in this file.
template <typename AK>
|
||||
__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm)
|
||||
attention_kernel_batched(typename AK::Params params);
|
||||
|
||||
// Opens an explicit specialization of attention_kernel_batched for the kernel
// type passed as the macro arguments. Inside the body, `Kernel` aliases that
// type and `p` is its Params argument. Must be closed with
// _ATTENTION_KERNEL_FORWARD_END().
#define _ATTENTION_KERNEL_FORWARD_BEGIN(...)                                 \
  template <>                                                                \
  __global__ void                                                            \
  __launch_bounds__(__VA_ARGS__::kNumThreads, __VA_ARGS__::kMinBlocksPerSm)  \
  attention_kernel_batched<__VA_ARGS__>(typename __VA_ARGS__::Params p) {    \
    using Kernel = __VA_ARGS__;

// Closes the function body opened by _ATTENTION_KERNEL_FORWARD_BEGIN.
#define _ATTENTION_KERNEL_FORWARD_END() }
|
||||
|
||||
// __CUDA_ARCH_OR_ZERO__ expands to __CUDA_ARCH__ when that macro is defined,
// and to 0 otherwise, so it can always be used as an integer expression.
#if defined(__CUDA_ARCH__)
#define __CUDA_ARCH_OR_ZERO__ __CUDA_ARCH__
#else
#define __CUDA_ARCH_OR_ZERO__ 0
#endif
|
||||
|
||||
// Emits the attention_kernel_batched specialization for
// AttentionKernel<SCALAR_T, cutlass::arch::Sm<ARCH>, IS_ALIGNED,
//                 QUERIES_PER_BLOCK, KEYS_PER_BLOCK, SINGLE_VALUE_ITER>.
// The generated body calls Kernel::attention_kernel(p) only when
// p.advance_to_block() returns true.
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD(                                 \
    ARCH, SCALAR_T, IS_ALIGNED, QUERIES_PER_BLOCK, KEYS_PER_BLOCK,            \
    SINGLE_VALUE_ITER)                                                        \
  _ATTENTION_KERNEL_FORWARD_BEGIN(                                            \
      AttentionKernel<SCALAR_T, cutlass::arch::Sm##ARCH, IS_ALIGNED,          \
                      QUERIES_PER_BLOCK, KEYS_PER_BLOCK, SINGLE_VALUE_ITER>)  \
  if (p.advance_to_block()) {                                                 \
    Kernel::attention_kernel(p);                                              \
  }                                                                           \
  _ATTENTION_KERNEL_FORWARD_END();
|
||||
|
||||
// Emits a stub attention_kernel_batched specialization for the same kernel
// type as INSTANTIATE_ATTENTION_KERNEL_FORWARD. The stub does no attention
// work: it only prints a message naming ARCH and the value of
// __CUDA_ARCH_OR_ZERO__ for the current compilation pass.
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(                        \
    ARCH, SCALAR_T, IS_ALIGNED, QUERIES_PER_BLOCK, KEYS_PER_BLOCK,            \
    SINGLE_VALUE_ITER)                                                        \
  _ATTENTION_KERNEL_FORWARD_BEGIN(                                            \
      AttentionKernel<SCALAR_T, cutlass::arch::Sm##ARCH, IS_ALIGNED,          \
                      QUERIES_PER_BLOCK, KEYS_PER_BLOCK, SINGLE_VALUE_ITER>)  \
  printf("FATAL: this function is for sm%d, but was built for sm%d\n",        \
         int(ARCH), int(__CUDA_ARCH_OR_ZERO__));                              \
  _ATTENTION_KERNEL_FORWARD_END();
|
||||
|
||||
// Default state: each per-architecture instantiation macro expands to the
// disabled stub.
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(50, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(70, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(75, __VA_ARGS__)
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD_DISABLED(80, __VA_ARGS__)

// When __CUDA_ARCH__ is defined, re-point exactly one of the macros above at
// the real kernel:
//   [500, 700) -> SM50, [700, 750) -> SM70, [750, 800) -> SM75, >= 800 -> SM80.
// If __CUDA_ARCH__ is undefined or below 500, every macro stays disabled
// (no #error is raised for architectures below 5.0).
#if defined(__CUDA_ARCH__)
#if __CUDA_ARCH__ >= 800
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM80(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(80, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 750
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM75(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(75, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 700
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM70(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(70, __VA_ARGS__)
#elif __CUDA_ARCH__ >= 500
#undef INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50
#define INSTANTIATE_ATTENTION_KERNEL_FORWARD_SM50(...) \
  INSTANTIATE_ATTENTION_KERNEL_FORWARD(50, __VA_ARGS__)
#endif
#endif
|
||||
|
||||
#endif
|
||||
|
|
@ -0,0 +1,30 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params) {
|
||||
const int32_t& sm = params.sm;
|
||||
if (sm >= 80) {
|
||||
run_memory_efficient_attention_sm80(params);
|
||||
} else if (sm >= 75) {
|
||||
run_memory_efficient_attention_sm75(params);
|
||||
} else if (sm >= 70) {
|
||||
run_memory_efficient_attention_sm70(params);
|
||||
} else if (sm >= 50) {
|
||||
run_memory_efficient_attention_sm50(params);
|
||||
} else {
|
||||
assert(false); // shall not reach here.
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
|
|
@ -0,0 +1,55 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
#pragma once
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cpu/bert/attention_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
struct MemoryEfficientAttentionParams {
|
||||
int32_t sm;
|
||||
bool is_half;
|
||||
int32_t batch_size;
|
||||
int32_t num_heads;
|
||||
int32_t sequence_length;
|
||||
int32_t kv_sequence_length;
|
||||
int32_t qk_head_size;
|
||||
int32_t v_head_size;
|
||||
bool causal;
|
||||
|
||||
int32_t* cu_seqlens_q;
|
||||
int32_t* cu_seqlens_k;
|
||||
|
||||
const void* query; // [B, S, N, H]
|
||||
const void* key; // [B, L, N, H], where L is kv_sequence_length
|
||||
const void* value; // [B, L, N, H_v]
|
||||
void* output; // [B, S, N, H_v]
|
||||
void* workspace; // [B, S, N, H_v] when kNeedsOutputAccumulatorBuffer, nullptr otherwise
|
||||
cudaStream_t stream;
|
||||
|
||||
static bool need_workspace(size_t v_head_size, bool is_float) {
|
||||
return (v_head_size > 128 && !is_float);
|
||||
}
|
||||
};
|
||||
|
||||
// Entry point: dispatches to one of the run_memory_efficient_attention_sm*
// functions declared below, based on params.sm.
void run_memory_efficient_attention(const MemoryEfficientAttentionParams& params);
|
||||
|
||||
// Returns true when `sm` meets the minimum required value:
// 53 when is_half is true, 50 otherwise.
inline bool has_memory_efficient_attention(int32_t sm, bool is_half) {
  int32_t min_sm = 50;
  if (is_half) {
    min_sm = 53;
  }
  return sm >= min_sm;
}
|
||||
|
||||
// Per-architecture entry points (declarations only; the definitions are not in
// this header). run_memory_efficient_attention selects one of them from
// params.sm.
void run_memory_efficient_attention_sm80(const MemoryEfficientAttentionParams& params);
|
||||
void run_memory_efficient_attention_sm75(const MemoryEfficientAttentionParams& params);
|
||||
void run_memory_efficient_attention_sm70(const MemoryEfficientAttentionParams& params);
|
||||
void run_memory_efficient_attention_sm50(const MemoryEfficientAttentionParams& params);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
||||
#endif // USE_FLASH_ATTENTION
|
||||
|
|
@ -6,6 +6,7 @@
|
|||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "contrib_ops/cuda/bert/multihead_attention.h"
|
||||
#include "contrib_ops/cpu/bert/multihead_attention_helper.h"
|
||||
#include "contrib_ops/cuda/bert/cutlass_fmha/memory_efficient_attention.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace ::onnxruntime::common;
|
||||
|
|
@ -42,8 +43,14 @@ MultiHeadAttention<T>::MultiHeadAttention(const OpKernelInfo& info)
|
|||
disable_fused_runner_ = sizeof(T) != 2 ||
|
||||
ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedAttention, false);
|
||||
|
||||
enable_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFlashAttention, false);
|
||||
enable_trt_flash_attention_ = sizeof(T) == 2 &&
|
||||
!ParseEnvironmentVariableWithDefault<bool>(attention::kDisableTrtFlashAttention, false);
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
disable_memory_efficient_attention_ = ParseEnvironmentVariableWithDefault<bool>(attention::kDisableMemoryEfficientAttention, false);
|
||||
#else
|
||||
disable_memory_efficient_attention_ = true;
|
||||
#endif
|
||||
|
||||
disable_fused_cross_attention_ = sizeof(T) != 2 || ParseEnvironmentVariableWithDefault<bool>(attention::kDisableFusedCrossAttention, false);
|
||||
}
|
||||
|
|
@ -85,7 +92,6 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
bool is_mask_1d_seq_len = parameters.mask_type == AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
|
||||
|
||||
bool use_fused_cross_attention = !disable_fused_cross_attention_ &&
|
||||
nullptr == key_padding_mask &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
|
|
@ -104,17 +110,18 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
|
||||
bool use_fused_runner = !disable_fused_runner_ &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
(nullptr == key_padding_mask || is_mask_1d_seq_len) &&
|
||||
parameters.hidden_size == parameters.v_hidden_size &&
|
||||
parameters.sequence_length == parameters.kv_sequence_length &&
|
||||
FusedMHARunnerFP16v2::is_supported(sm, parameters.head_size, sequence_length,
|
||||
enable_flash_attention_, false);
|
||||
enable_trt_flash_attention_, false);
|
||||
if (use_fused_runner) {
|
||||
// Here we assume that num_heads and head_size does not change for a MultiHeadAttention node.
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
constexpr bool is_unidirectional = false;
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(
|
||||
num_heads_, parameters.head_size, sm, is_unidirectional, enable_flash_attention_, parameters.scale));
|
||||
num_heads_, parameters.head_size, sm, is_unidirectional, enable_trt_flash_attention_, parameters.scale));
|
||||
}
|
||||
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
|
|
@ -124,6 +131,16 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
}
|
||||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
bool use_memory_efficient_attention = fused_runner == nullptr &&
|
||||
fused_cross_attention_kernel == nullptr &&
|
||||
!disable_memory_efficient_attention_ &&
|
||||
nullptr == key_padding_mask && // TODO: support 1D mask
|
||||
has_memory_efficient_attention(sm, sizeof(T) == 2);
|
||||
#else
|
||||
constexpr bool use_memory_efficient_attention = false;
|
||||
#endif
|
||||
|
||||
constexpr size_t element_size = sizeof(T);
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
parameters.batch_size,
|
||||
|
|
@ -133,7 +150,8 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
parameters.sequence_length,
|
||||
parameters.kv_sequence_length,
|
||||
parameters.total_sequence_length,
|
||||
fused_runner);
|
||||
fused_runner,
|
||||
use_memory_efficient_attention);
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
|
@ -152,6 +170,7 @@ Status MultiHeadAttention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.present = nullptr;
|
||||
data.fused_runner = reinterpret_cast<void*>(fused_runner);
|
||||
data.fused_cross_attention_kernel = fused_cross_attention_kernel;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
|
||||
cublasHandle_t cublas = GetCublasHandle(context);
|
||||
return QkvToContext<CudaT>(
|
||||
|
|
|
|||
|
|
@ -24,8 +24,9 @@ class MultiHeadAttention final : public CudaKernel {
|
|||
int num_heads_; // number of attention heads
|
||||
float mask_filter_value_;
|
||||
bool disable_fused_runner_;
|
||||
bool enable_flash_attention_;
|
||||
bool enable_trt_flash_attention_;
|
||||
bool disable_fused_cross_attention_;
|
||||
bool disable_memory_efficient_attention_;
|
||||
mutable std::unique_ptr<MHARunner> fused_fp16_runner_;
|
||||
mutable const FusedMultiHeadCrossAttentionKernel* fused_fp16_cross_attention_kernel_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -174,6 +174,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
Tensor* present = context->Output(1, present_shape);
|
||||
|
||||
void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up
|
||||
bool use_memory_efficient_attention = false;
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
batch_size,
|
||||
parameters.num_heads,
|
||||
|
|
@ -182,7 +183,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
sequence_length,
|
||||
parameters.kv_sequence_length,
|
||||
parameters.total_sequence_length,
|
||||
fused_runner);
|
||||
fused_runner,
|
||||
use_memory_efficient_attention);
|
||||
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize, context->GetComputeStream());
|
||||
|
||||
|
|
@ -202,6 +204,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
data.present = (nullptr == present) ? nullptr : reinterpret_cast<CudaT*>(present->MutableData<T>());
|
||||
data.fused_runner = fused_runner;
|
||||
data.fused_cross_attention_kernel = nullptr;
|
||||
data.use_memory_efficient_attention = use_memory_efficient_attention;
|
||||
|
||||
return QkvToContext<CudaT>(GetDeviceProp(), cublas, Stream(context), parameters, data);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -544,7 +544,8 @@ def get_ort_environment_variables():
|
|||
env_names = [
|
||||
"ORT_DISABLE_FUSED_ATTENTION",
|
||||
"ORT_DISABLE_FUSED_CROSS_ATTENTION",
|
||||
"ORT_DISABLE_FLASH_ATTENTION",
|
||||
"ORT_DISABLE_TRT_FLASH_ATTENTION",
|
||||
"ORT_DISABLE_MEMORY_EFFICIENT_ATTENTION",
|
||||
"ORT_TRANSFORMER_OPTIONS",
|
||||
"ORT_CUDA_GEMM_OPTIONS",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -929,7 +929,7 @@ TEST(AttentionTest, Causal_EmptyPastState) {
|
|||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"}}};
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
|
||||
|
|
@ -940,7 +940,7 @@ TEST(AttentionTest, Causal_EmptyPastState) {
|
|||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
|
||||
|
|
@ -951,7 +951,7 @@ TEST(AttentionTest, Causal_EmptyPastState) {
|
|||
{
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"}}};
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, use_float16, is_unidirectional,
|
||||
|
|
|
|||
|
|
@ -1902,6 +1902,392 @@ void GetCrossAttentionData_HeadSize40(AttentionTestData& data) {
|
|||
1.2402344f, 2.2792969f, 0.33398438f, 2.2519531f, 0.67041016f, -0.55957031f, 0.20666504f, 1.3583984f, -1.9716797f, 2.6074219f, 2.2832031f, -2.0546875f, -2.4335938f, 0.53515625f, -0.15100098f, 1.9599609f, -0.51513672f, 0.31030273f, -0.49169922f, 1.4677734f, 2.234375f, 0.87451172f, 0.54736328f, -1.8681641f, -4.2265625f, -0.97509766f, -7.296875f, -1.3486328f, 1.3769531f, -1.8427734f, 3.1601562f, -2.4238281f, -0.82421875f, -2.7324219f, -0.52734375f, 2.2089844f, 0.66796875f, -0.42236328f, -3.03125f, -0.047302246f};
|
||||
}
|
||||
}
|
||||
|
||||
// Fills `data` with a cross-attention test case: batch size 2, 2 heads, hidden size 64
// (64 / 2 = 32 per head), query sequence length 2 and key/value sequence length 3.
// `is_mask_1d` selects how the key padding mask is expressed:
//   true  -> AttentionMaskType::MASK_1D_KEY_SEQ_LEN with one entry per batch
//   false -> AttentionMaskType::MASK_2D_KEY_PADDING with batch_size x kv_sequence_length entries
// The remainder of the function assigns the query/key/value/bias inputs and the expected
// fp32 output; the fp16 expected output is copied from the fp32 one.
void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d) {
|
||||
data.hidden_size = 64;
|
||||
data.v_hidden_size = 64;
|
||||
data.num_heads = 2;
|
||||
data.batch_size = 2;
|
||||
data.sequence_length = 2;
|
||||
data.kv_sequence_length = 3;
|
||||
|
||||
if (is_mask_1d) {
|
||||
data.mask_type = AttentionMaskType::MASK_1D_KEY_SEQ_LEN;
|
||||
data.key_padding_mask_data = {1, 2};  // presumably the valid key length of each batch (matches the count of 1s per row in the 2D mask below) - confirm
|
||||
} else {
|
||||
data.mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
|
||||
data.key_padding_mask_data = {1, 0, 0,  // first row: the 1 is followed by two 0s (0s on the right)
|
||||
1, 1, 0};  // second row: two 1s followed by one 0
|
||||
}
|
||||
|
||||
// Kernel types listed here are skipped for this test case.
data.skip_kernel_types = {AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernelType::AttentionKernel_TrtFusedAttention,
|
||||
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention};
|
||||
|
||||
{
|
||||
data.query_data = {
|
||||
0.66417468f, -2.82039404f, 1.66603971f, 4.84341049f, -1.63285708f, 3.61133432f,
|
||||
-1.07151258f, -0.41698062f, -1.38491797f, -3.79137778f, 1.34514475f, -2.97253704f,
|
||||
2.12579250f, -0.02954102f, 2.30081463f, 0.21410012f, 1.84038579f, 0.46486610f,
|
||||
-4.49463224f, 0.69027799f, 1.01090157f, 0.04715919f, -1.60957003f, 0.10730582f,
|
||||
-5.77672052f, 0.37593889f, 2.04825425f, -1.00890708f, -3.88195300f, -2.69047785f,
|
||||
1.15699422f, -1.13536406f, -0.42816854f, 3.12039518f, 3.21898699f, -0.51998949f,
|
||||
-4.72336435f, -0.78055519f, -0.72722042f, 3.17147565f, -1.31066322f, -3.09425855f,
|
||||
-3.54743338f, -0.07284085f, 1.10525322f, 1.82087338f, -2.03681397f, -4.27978802f,
|
||||
0.26408362f, 0.58637118f, -2.07128787f, -3.48036027f, -0.03049034f, -1.99293542f,
|
||||
-0.67289937f, 1.17342246f, -4.84998703f, -2.43558168f, 1.16422236f, 0.26511097f,
|
||||
-1.98199308f, -1.86423326f, 1.61366916f, -0.35201707f,
|
||||
-1.43554640f, -1.37493825f, 2.32563400f, -1.31762123f, -1.46716797f, 0.18536982f,
|
||||
0.85819042f, -3.11506653f, -1.25773919f, 1.30177450f, 0.58314162f, -1.72039497f,
|
||||
-4.55264997f, 0.02031951f, -2.83490133f, 2.69835496f, -0.07102034f, -2.05412841f,
|
||||
-1.26518285f, 3.30601740f, -4.54173231f, 0.80148667f, -1.36685658f, -2.26921320f,
|
||||
-0.94192690f, -2.77439642f, 0.43918809f, 1.44727242f, 1.53386545f, 2.67014980f,
|
||||
3.30231142f, -1.60745978f, -1.26032567f, 1.27801156f, 0.31288767f, 3.04471421f,
|
||||
-1.09798527f, -2.76303077f, -1.68329728f, -4.78179169f, -0.86371553f, -1.57159030f,
|
||||
-1.06435764f, 3.61700702f, 0.71459293f, -0.25048330f, 1.31865597f, -1.83117080f,
|
||||
-1.10344386f, 2.94894052f, -1.33930528f, 1.94855583f, -1.94283628f, -0.64020038f,
|
||||
2.24100995f, 1.06447530f, -0.03809617f, 3.47241497f, -2.55227089f, 0.12048072f,
|
||||
2.88777542f, -1.73300576f, 3.10077643f, -0.37158102f,
|
||||
|
||||
-0.76705527f, -1.27237630f, 3.55744553f, 0.84103155f, -2.37726879f, 0.20218298f,
|
||||
-3.41723180f, 1.26160014f, 1.45791709f, -1.47226799f, -2.36974764f, 1.49916458f,
|
||||
1.68845606f, -1.33727181f, -2.18113089f, -0.64312577f, -1.06002951f, -0.98938328f,
|
||||
1.95285964f, 3.08321524f, 1.28492856f, 2.28907299f, 1.14324796f, -0.11273877f,
|
||||
-5.96574259f, -1.80337310f, 3.86340094f, -2.42390299f, -1.29642844f, 0.14276078f,
|
||||
-1.23373103f, -0.51519167f, -1.04046988f, 0.60624832f, -0.93274558f, 2.46919179f,
|
||||
-0.58201206f, -3.43382907f, 1.63227773f, 1.92112875f, -0.17216301f, 2.79771209f,
|
||||
2.67759442f, 1.73900354f, -0.00557053f, -0.63086307f, -0.37115061f, 0.82691956f,
|
||||
1.81370568f, -0.48766607f, -1.05545425f, -2.79009533f, -7.64374399f, -2.65407372f,
|
||||
-0.84429693f, 1.35677493f, -1.25277543f, 2.26928639f, -1.77852845f, 2.31752825f,
|
||||
-1.28869593f, -2.97340727f, -2.87103486f, 2.17401385f,
|
||||
0.20970306f, -1.19119942f, 1.11263359f, 0.21227169f, -5.30872822f, -2.15851903f,
|
||||
0.63067430f, -0.49583313f, 3.05784941f, 0.09588236f, 0.76925617f, 1.18900692f,
|
||||
0.35771871f, -0.97235727f, 1.14949071f, -1.25595427f, 2.37192512f, -0.32522821f,
|
||||
1.42988098f, -0.38017935f, 2.49831486f, -0.30629224f, 1.08675146f, -1.02598715f,
|
||||
-0.17971759f, -0.55683851f, 1.04535389f, 1.54741859f, -0.05179391f, 0.73957652f,
|
||||
0.54304504f, 1.95280874f, -1.19504929f, -1.19528544f, 1.33258319f, 0.13532166f,
|
||||
-1.87509251f, 0.99605685f, 2.69439840f, 1.03421521f, 1.79539657f, 0.15001571f,
|
||||
0.55184591f, -0.84038037f, -2.08177447f, -1.43082356f, -1.52199960f, 1.69448102f,
|
||||
2.12475252f, -2.64191580f, 0.10776700f, -4.01538181f, 1.15558016f, -0.09849232f,
|
||||
0.33533198f, 3.34633803f, -2.89805937f, -2.51580763f, 0.94939411f, 1.36254668f,
|
||||
0.47172806f, 4.40817642f, -0.11368597f, -2.70789719f};
|
||||
}
|
||||
{
|
||||
data.key_data = {
|
||||
1.18319833f, -0.20700163f, -0.64873743f, 3.88316822f, -2.82827115f, 4.12166834f, 0.84225285f, -1.11044288f,
|
||||
-1.75086212f, -1.66724730f, 2.22730064f, -3.22617316f, -0.14071584f, 0.58066225f, 3.04375815f, -1.43881261f,
|
||||
-2.39294887f, 1.03637624f, -0.98744214f, 1.13576865f, -0.23876363f, 0.27395499f, -0.51450062f, -2.23614597f,
|
||||
-2.12345290f, -0.68864471f, 2.56223369f, -1.14069867f, -2.14457107f, -1.32647824f, -1.20575166f, -0.98427975f,
|
||||
0.43083039f, -1.72496212f, 0.89925444f, -0.33879194f, -1.01836991f, 0.06260723f, -4.40405083f, 1.51136112f,
|
||||
-1.57057071f, -2.49242449f, -0.37187487f, -3.55319405f, 1.50083232f, 0.37271553f, 1.00157571f, -0.50416815f,
|
||||
1.28753221f, -0.82453167f, -1.13294256f, -1.49514699f, 0.11243388f, 1.89696264f, -1.46173263f, 3.32755566f,
|
||||
-0.54521537f, -2.61305809f, -0.43132567f, -0.33066380f, -0.47485363f, 3.62707257f, -0.61352783f, 2.21147466f,
|
||||
-2.39673638f, 0.89925957f, -2.58643913f, -0.81968069f, 3.34945726f, 0.73745269f, -1.62732553f, -4.55126476f,
|
||||
2.78017616f, 0.33757699f, 2.50468874f, -4.14928627f, 0.20017165f, 3.62233806f, -4.17984772f, 2.60447359f,
|
||||
2.16826940f, 1.70457518f, 1.03199887f, 2.66712570f, 0.50808340f, -3.47132921f, -2.60008478f, 1.03852415f,
|
||||
-0.53876096f, 3.36212158f, -5.49142551f, 1.69825470f, -2.98179603f, -3.39561105f, -2.33971524f, 1.23642313f,
|
||||
2.13283253f, -0.56307364f, -2.49120903f, 2.97641850f, -1.28758216f, 3.43342829f, 2.49575281f, 0.09292871f,
|
||||
-0.46469527f, -3.95696974f, 2.16474032f, -2.15254521f, -2.24547267f, 2.34235692f, -1.02470589f, 3.97816467f,
|
||||
3.60425544f, 1.87994969f, -2.46964216f, 1.47802746f, -1.81441534f, -1.56946301f, 0.56189334f, -1.69905055f,
|
||||
-1.83049631f, 4.64296293f, 3.36173010f, 1.17065477f, 0.62365234f, 1.23748016f, 0.63865232f, -2.90434527f,
|
||||
1.80253839f, 3.11227179f, -3.96782875f, -2.78780794f, 3.76587057f, -1.66908360f, 1.83301187f, -1.74414611f,
|
||||
-2.83874130f, -2.00238085f, -6.45539570f, 0.56152177f, 2.52830791f, -4.32480669f, 1.40038610f, 0.83278954f,
|
||||
0.16065764f, -0.13457650f, 2.17216778f, -4.28218699f, 0.75475001f, -0.67497885f, -0.95346600f, 3.29623652f,
|
||||
1.84325528f, 1.18348145f, -0.23741919f, 2.49520302f, 0.88820332f, 1.15528166f, 0.75733638f, 2.09371948f,
|
||||
-1.16427231f, 1.36415648f, -1.17721760f, 0.19180456f, -3.83617687f, -0.22694540f, 5.14728260f, -0.43242604f,
|
||||
-2.59039426f, -1.40904129f, 0.58194822f, -2.59625196f, -3.60205126f, 1.45633197f, 3.66319609f, -4.45727873f,
|
||||
3.95457315f, -0.17875004f, 2.43404126f, 2.83592010f, 0.87342203f, 1.24538708f, 3.10003138f, 2.63025975f,
|
||||
4.57258415f, -5.20645714f, -2.55821514f, 0.60136455f, -4.13579988f, -2.04082966f, 2.21142578f, -1.05740535f,
|
||||
|
||||
1.78609943f, -3.10438013f, -0.13040465f, -3.02957106f, 0.91924584f, 0.45405358f, -1.90627027f, -1.05065346f,
|
||||
-1.21743047f, -1.65989709f, -0.51138550f, 2.04327297f, 0.65217698f, 0.77914226f, 1.86315429f, 0.75791669f,
|
||||
-0.55304748f, -1.23857486f, 2.63207936f, -0.51371288f, 5.48993397f, -2.35509205f, -2.30255723f, 3.88706803f,
|
||||
-1.93575382f, 0.03364474f, -1.61156952f, -2.74172544f, 1.64667726f, 0.04652762f, 2.88130736f, -2.00066185f,
|
||||
0.74907655f, -3.35894132f, -1.85703170f, 1.78695405f, 0.16497552f, 0.94382036f, 3.04452896f, -4.42404556f,
|
||||
-1.67239439f, 0.93356639f, 0.08288032f, -0.11422639f, -3.94759631f, 0.35302341f, -1.20778334f, -1.92491865f,
|
||||
-1.86599326f, -1.29324412f, -1.12795746f, 0.24268979f, -0.50242394f, 2.26449108f, 0.91289425f, -2.48235416f,
|
||||
-1.12685704f, -0.32806787f, 3.28139257f, 3.19231367f, 0.99441254f, -1.86975384f, -3.57600951f, 0.07424650f,
|
||||
-0.45312887f, 5.02197504f, -3.93365264f, -3.30742884f, -1.48101401f, 1.03335130f, 2.79531693f, -3.71739435f,
|
||||
1.58574414f, -4.52857542f, 1.99908066f, 1.53755212f, 1.60631371f, -2.46801257f, -1.85840714f, 5.07508087f,
|
||||
1.69143867f, -1.04688716f, -3.17096090f, -4.08357859f, -0.02436948f, -1.26299214f, 1.55509603f, 3.11954260f,
|
||||
3.55844116f, 0.10080734f, -0.57031679f, 2.01342750f, -0.66671205f, -1.89724469f, 2.52388906f, 3.71421099f,
|
||||
0.77953398f, -1.63364959f, -1.90900147f, -3.60591793f, 1.17604601f, -1.69456589f, -1.62096381f, -1.44886708f,
|
||||
-1.09821022f, -1.27646899f, 2.73696446f, -2.21802664f, -0.22022307f, 1.76918471f, -1.55524099f, 0.27310926f,
|
||||
-0.56175643f, -0.59620953f, 2.34752941f, -0.74946308f, -2.33520174f, 1.37984359f, -1.82466078f, -0.04973821f,
|
||||
-4.77387571f, -0.85034770f, 3.39579129f, -2.82413197f, -2.37980723f, 0.10482252f, 0.10614476f, 0.38176090f,
|
||||
-0.03948998f, -3.33898020f, 0.33013302f, -0.24926627f, 1.82249093f, 0.57584983f, -0.68790460f, -0.62760007f,
|
||||
0.17052543f, -0.54540014f, 1.66043472f, -0.29917845f, 3.31803465f, 0.86704284f, -0.26854402f, 2.23795938f,
|
||||
-0.65058500f, -2.01540327f, -2.32472515f, -2.85143948f, -3.76564598f, -0.25596800f, -2.08064461f, -0.60812098f,
|
||||
3.64154029f, -2.58636141f, -0.25312662f, -2.22530699f, -1.24763203f, -3.08458424f, 0.69228125f, -1.84211481f,
|
||||
1.09744453f, -1.35679579f, 1.68044925f, 0.89537722f, 3.56465936f, -0.64790231f, -1.42140329f, -2.85126376f,
|
||||
0.88302374f, -0.77923191f, -0.61865216f, -3.08081675f, 0.87791818f, -0.27943787f, 0.46918952f, 1.50163293f,
|
||||
3.43236423f, 1.99953759f, -2.42805409f, 4.97383118f, -2.13942194f, 1.45409000f, -1.14207470f, 0.63804722f,
|
||||
-4.23801470f, 1.23076391f, 2.71176004f, 1.13607812f, 2.27742863f, 1.64165723f, 1.20048785f, -0.66269439f};
|
||||
}
|
||||
|
||||
{
|
||||
data.value_data = {
|
||||
2.52855659f, 1.00436294f, 0.83871710f, 0.97005701f, 1.33615291f, -2.07353282f, 0.14190522f, -1.42923164f,
|
||||
-0.05781263f, -3.81081843f, 1.15263164f, 0.62601233f, -0.93824124f, 1.21525323f, -0.17992918f, 2.08717370f,
|
||||
3.61659431f, -0.16836943f, 2.17779160f, -0.63968349f, 0.32170480f, 1.74428463f, -0.46570981f, -0.07432288f,
|
||||
-0.21569058f, 0.65559602f, 3.58669281f, 0.40837619f, 2.40912223f, 1.31780922f, -4.45945454f, 0.64903581f,
|
||||
-1.10752177f, -1.79390311f, 0.89312351f, -1.84512544f, -1.13948750f, 3.87221098f, -2.74163318f, 2.90849519f,
|
||||
-0.31782085f, 3.12108278f, 0.80056298f, 1.02164125f, -0.07995117f, -0.96148860f, 3.49803638f, -4.48321056f,
|
||||
-1.50024915f, -2.58987570f, 0.61711067f, 4.13532829f, -4.38111591f, -2.48988461f, -0.43977243f, -3.93134618f,
|
||||
-2.67314148f, 2.64455128f, 0.11041284f, 1.26786041f, -0.24446392f, -0.86178148f, 2.35680771f, -1.69236851f,
|
||||
-1.22143269f, 1.99185669f, 2.99625540f, -2.32311869f, -2.26162481f, 3.13980794f, 0.37014920f, 3.22335911f,
|
||||
2.55935216f, 2.19479871f, 4.89236355f, 1.76135564f, -2.74285603f, 1.39842391f, -0.25135490f, -4.76257038f,
|
||||
-0.80362052f, -1.75548995f, -4.70487833f, 1.72763062f, 3.14491320f, 3.97562551f, -0.64091396f, -0.49683607f,
|
||||
1.09094775f, -0.04886785f, -0.20181555f, 2.22182846f, 3.00734067f, -0.52149582f, -1.55592132f, 4.41542721f,
|
||||
4.68795204f, -1.03364658f, 1.12266266f, -1.50595415f, -4.82583904f, -0.65535200f, -1.44525290f, -0.24540535f,
|
||||
-0.44778955f, 2.32284093f, 1.60033488f, 0.12583408f, -4.42107201f, -1.32412672f, -1.84733653f, -1.53440499f,
|
||||
3.21279287f, -0.37051341f, 0.26685789f, 2.25037003f, 0.01608747f, 1.66141725f, -0.53394145f, 1.35017800f,
|
||||
1.35997009f, -2.73341703f, 5.47488451f, 5.49519920f, -1.90401053f, 3.37626982f, -1.97467375f, 1.91208827f,
|
||||
-0.39609963f, -3.46037388f, -1.47946858f, 3.59935665f, 2.36377144f, -2.32310963f, 1.95714176f, -3.10615826f,
|
||||
-1.72878003f, 0.37169266f, -5.95610952f, -1.32819366f, -1.24326205f, 0.17746472f, 2.59834385f, 1.83808351f,
|
||||
2.94952321f, 3.01939392f, 1.37281823f, 2.67180538f, -0.32547897f, 1.11373281f, -0.26456773f, 0.30103314f,
|
||||
-1.05465972f, -1.74858260f, 4.66243505f, -0.58474910f, 1.26216507f, 1.28856802f, 0.30135399f, -3.24127388f,
|
||||
1.57217860f, -3.84659171f, 1.52000761f, -0.57999939f, 7.80852032f, 2.83661318f, -1.72516418f, 0.70036685f,
|
||||
5.33224869f, 3.27205563f, 0.22613347f, 1.27628899f, 0.63828707f, 0.60137266f, 2.23047280f, -3.12771320f,
|
||||
-0.03023779f, 0.80765182f, -2.25078392f, -2.55701947f, -1.01789987f, -4.81986141f, 5.08153057f, -1.74439597f,
|
||||
-2.12658811f, -0.01458025f, -2.19556737f, 0.66254830f, -0.97602153f, -0.09858370f, -2.05090475f, -3.57909155f,
|
||||
|
||||
4.57896709f, -1.96923888f, -3.86827421f, 3.18770289f, -5.16361237f, 1.42594528f, -1.43490076f, 1.62748218f,
|
||||
0.91413617f, -0.27147734f, 0.89311242f, 0.39315015f, 1.18184900f, 4.30172014f, -2.32771754f, 1.61144018f,
|
||||
1.31702828f, 1.47999883f, -0.20565452f, 0.75846130f, -0.13237280f, -2.10059071f, 0.12025893f, -0.58277643f,
|
||||
1.93927395f, -3.11170292f, 0.84666562f, 0.08490577f, -0.36315954f, -3.13071823f, 0.12070303f, -0.10385191f,
|
||||
-2.37523723f, 2.28944397f, 0.12518460f, -1.10043252f, -1.94665289f, 3.44240570f, 1.14374518f, 3.27769613f,
|
||||
1.40222466f, 0.68902296f, 2.48193359f, 1.85469973f, 0.53099388f, -2.16307211f, 0.67865700f, -0.05084896f,
|
||||
0.09825261f, 1.40057099f, -0.74452353f, 0.81515837f, 1.51540780f, -1.30754757f, -1.50317430f, -2.04524612f,
|
||||
-0.49154273f, 0.75809133f, -0.25134420f, 0.36961895f, -0.01882899f, -1.72547066f, 1.12012851f, -6.72828960f,
|
||||
1.76177442f, 1.19128907f, -0.77717477f, -1.97159290f, -2.30860472f, 2.01583147f, 5.43375349f, 2.58655977f,
|
||||
0.71099019f, 0.71843386f, 3.10709906f, 1.48128355f, 0.22561067f, -4.27442265f, -2.49249840f, 4.71605539f,
|
||||
2.19818974f, -1.96133125f, 0.41619009f, 0.66834581f, -3.74457240f, -0.48215276f, -1.28305256f, -1.83142948f,
|
||||
-0.72452945f, -1.97440028f, -0.14068973f, 0.11765432f, 0.49793118f, 0.40227121f, -1.34390569f, 0.92099732f,
|
||||
-1.21718168f, -1.95382285f, 1.37468243f, -0.72062874f, 2.66714525f, 1.06695974f, -2.86761045f, 1.34743905f,
|
||||
3.30500460f, -0.91894615f, -0.09608981f, -4.09408808f, -2.57941151f, -0.36501098f, 1.93333972f, 1.54577386f,
|
||||
-2.96415496f, -2.09494066f, 1.63500857f, -1.51829720f, -0.98314112f, -1.89401948f, -0.54314089f, -3.68928242f,
|
||||
1.07439506f, 1.70869648f, 0.86973846f, 1.71959770f, 1.78241849f, -4.29455566f, -1.55857742f, -3.32966399f,
|
||||
0.20903873f, 1.40176547f, -6.08825064f, 2.12755013f, 3.84799123f, -0.83979988f, -1.64312506f, -0.69876713f,
|
||||
4.00779629f, -2.85212469f, 0.09145057f, 1.72984874f, -0.77233994f, 1.21815240f, -1.75377214f, 4.08561277f,
|
||||
-1.20909250f, -1.24881196f, 4.37579060f, 4.27434301f, -2.01065826f, 2.96602201f, 3.07406378f, 1.22374272f,
|
||||
0.06376281f, -1.60328245f, -1.32239270f, 1.00765312f, 1.27593243f, -2.14843464f, -3.47884607f, -0.32401958f,
|
||||
-2.52805567f, -1.01782882f, 0.74270618f, 1.47170806f, -2.56010485f, -1.49985540f, 0.92767721f, 3.42378139f,
|
||||
5.23711205f, 0.47062784f, -0.26747131f, -2.06014609f, -0.20237172f, -1.60944867f, -2.51956654f, 0.59529293f,
|
||||
2.63805699f, 0.43868792f, -5.84081888f, 3.25271368f, -4.44406748f, -3.80642724f, -1.59846020f, -2.59634686f,
|
||||
0.11074528f, 2.04441738f, -1.51878321f, -2.59639883f, 2.23697233f, 0.07920718f, 1.31056094f, -8.10540771f};
|
||||
}
|
||||
|
||||
{
|
||||
data.bias_data = {
|
||||
-0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, -0.34948170f, -0.19608477f, 0.19725692f, 0.39987487f,
|
||||
0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, -0.23178342f, -0.13692456f, -0.04889601f, 0.48739988f,
|
||||
-0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, 0.48096961f, -0.47906545f, 0.43613154f, -0.23511401f,
|
||||
-0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, -0.45826656f, 0.23286396f, -0.43407962f, 0.40754890f,
|
||||
0.23778325f, 0.34850210f, -0.01385659f, 0.32141626f, -0.27738628f, 0.27683002f, 0.31886810f, -0.24781504f,
|
||||
-0.25476855f, -0.46742713f, -0.12478521f, 0.39731556f, -0.12087554f, 0.40822440f, 0.13202906f, -0.23747686f,
|
||||
0.30502868f, 0.27182943f, -0.03640261f, -0.39626551f, -0.22411832f, 0.17324352f, -0.49959660f, -0.49318257f,
|
||||
0.31363028f, 0.05469471f, -0.00390345f, -0.46100286f, -0.27253938f, 0.17251462f, 0.46564627f, 0.21038425f,
|
||||
0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, 0.27376485f, -0.38174152f, -0.43700469f, 0.38040614f,
|
||||
-0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f, 0.07120579f, -0.08055863f, 0.12095112f, -0.27988660f,
|
||||
0.06004709f, -0.05600315f, -0.25510073f, 0.41887105f, -0.19016314f, 0.47241372f, 0.12890404f, -0.24272856f,
|
||||
0.21106839f, -0.40523255f, 0.10336459f, -0.11084765f, 0.42408967f, -0.15285304f, -0.28945464f, -0.25714916f,
|
||||
0.40978593f, -0.09138483f, -0.02013114f, -0.39042589f, -0.19557095f, 0.07540411f, 0.33955890f, 0.41873980f,
|
||||
-0.27744853f, -0.33097768f, -0.44587523f, -0.01648277f, 0.34952271f, -0.48838940f, -0.17273578f, 0.37286615f,
|
||||
-0.10157353f, -0.08097187f, 0.23243034f, 0.25516337f, -0.45793599f, 0.08089012f, 0.17673731f, 0.03000754f,
|
||||
0.48834521f, 0.35069120f, -0.32989410f, 0.20729345f, 0.24406803f, 0.35393929f, -0.16146761f, 0.04258209f,
|
||||
-0.10567203f, 0.26791072f, -0.08976898f, 0.31341976f, 0.06027532f, 0.14307594f, 0.31587386f, 0.16180152f,
|
||||
0.34785229f, 0.00531715f, -0.35168743f, -0.11641458f, 0.39196932f, 0.44535065f, 0.43545735f, 0.15593112f,
|
||||
0.06171834f, -0.42181283f, -0.41170910f, 0.40969193f, -0.01510030f, 0.07973170f, -0.18156880f, 0.21522856f,
|
||||
0.03915739f, -0.20913908f, -0.47068381f, 0.35633272f, -0.35124153f, 0.36624825f, -0.05567622f, -0.35343069f,
|
||||
0.12821168f, 0.35526341f, -0.23420528f, -0.46328634f, -0.21994811f, -0.27556795f, 0.01653767f, 0.42626363f,
|
||||
0.23239774f, 0.39632857f, 0.32416028f, -0.48494491f, -0.05365932f, -0.10860911f, 0.06893444f, 0.46116674f,
|
||||
0.34345043f, -0.02719739f, -0.39574289f, -0.39339882f, 0.23044002f, -0.06155324f, 0.23292047f, 0.39775699f,
|
||||
0.12789404f, -0.44719657f, 0.12020230f, 0.26871282f, -0.10917315f, -0.29244915f, 0.09059817f, -0.19613290f};
|
||||
}
|
||||
|
||||
{
|
||||
data.fp32_output_data = {
|
||||
2.42288446f, 1.27227366f, 0.74894810f, 1.28347683f, 1.39642823f, -1.93045688f, 0.45777908f, -1.26743007f,
|
||||
0.29003966f, -3.80550122f, 0.80094421f, 0.50959778f, -0.54627192f, 1.66060388f, 0.25552815f, 2.24310493f,
|
||||
3.67831278f, -0.59018224f, 1.76608253f, -0.22999156f, 0.30660450f, 1.82401633f, -0.64727861f, 0.14090568f,
|
||||
-0.17653319f, 0.44645694f, 3.11600900f, 0.76470888f, 2.05788064f, 1.68405747f, -4.51513100f, 0.29560512f,
|
||||
-0.97931010f, -1.43863964f, 0.65891826f, -2.30841184f, -1.35943556f, 3.59664297f, -2.72509551f, 3.33475876f,
|
||||
-0.08542311f, 3.51741123f, 1.12472320f, 0.53669631f, -0.13361049f, -1.07009768f, 3.56697083f, -4.02204370f,
|
||||
-1.15679872f, -2.61707306f, 0.22136778f, 3.74192953f, -4.15067577f, -2.55143785f, -0.20685196f, -3.53358912f,
|
||||
-2.54524755f, 2.19735479f, 0.23061514f, 1.53657317f, -0.35363707f, -1.15423059f, 2.44740582f, -1.88850141f,
|
||||
|
||||
2.42288446f, 1.27227366f, 0.74894810f, 1.28347683f, 1.39642823f, -1.93045688f, 0.45777908f, -1.26743007f,
|
||||
0.29003966f, -3.80550122f, 0.80094421f, 0.50959778f, -0.54627192f, 1.66060388f, 0.25552815f, 2.24310493f,
|
||||
3.67831278f, -0.59018224f, 1.76608253f, -0.22999156f, 0.30660450f, 1.82401633f, -0.64727861f, 0.14090568f,
|
||||
-0.17653319f, 0.44645694f, 3.11600900f, 0.76470888f, 2.05788064f, 1.68405747f, -4.51513100f, 0.29560512f,
|
||||
-0.97931010f, -1.43863964f, 0.65891826f, -2.30841184f, -1.35943556f, 3.59664297f, -2.72509551f, 3.33475876f,
|
||||
-0.08542311f, 3.51741123f, 1.12472320f, 0.53669631f, -0.13361049f, -1.07009768f, 3.56697083f, -4.02204370f,
|
||||
-1.15679872f, -2.61707306f, 0.22136778f, 3.74192953f, -4.15067577f, -2.55143785f, -0.20685196f, -3.53358912f,
|
||||
-2.54524755f, 2.19735479f, 0.23061514f, 1.53657317f, -0.35363707f, -1.15423059f, 2.44740582f, -1.88850141f,
|
||||
|
||||
4.47329473f, -1.70132744f, -3.95804238f, 3.50112128f, -5.10333633f, 1.56902146f, -1.11902511f, 1.78928399f,
|
||||
1.26198828f, -0.26615992f, 0.54142559f, 0.27673587f, 1.57381809f, 4.74706888f, -1.89226031f, 1.76737213f,
|
||||
1.37874687f, 1.05818522f, -0.61736351f, 1.16815329f, -0.14747408f, -2.02085853f, -0.06131025f, -0.36754823f,
|
||||
1.97843063f, -3.32084179f, 0.37598154f, 0.44123849f, -0.71440083f, -2.76446915f, 0.06502641f, -0.45728233f,
|
||||
-1.93884647f, 1.51549935f, 0.22349268f, -1.46264625f, -0.93878794f, 2.53468966f, 0.09279048f, 3.19028425f,
|
||||
2.14098549f, 0.65744257f, 2.12003636f, -0.21332240f, -0.35039914f, -1.79318547f, 1.08148456f, 0.83520722f,
|
||||
-0.37325758f, 0.44315636f, -0.50703102f, -0.19921407f, 1.08093989f, -1.52517128f, -1.01477206f, -2.08499599f,
|
||||
0.05307493f, 0.56386751f, 0.16719794f, 0.99758488f, 0.35134155f, -2.70159864f, 0.49787593f, -6.01998806f,
|
||||
|
||||
1.88393891f, 1.20359635f, -1.11693203f, -1.24092197f, -2.47922421f, 2.11120105f, 5.19413376f, 2.67079711f,
|
||||
1.07527149f, 0.64369327f, 2.57635832f, 1.27686763f, 0.69491446f, -3.13548803f, -2.04371452f, 4.62090492f,
|
||||
2.18864536f, -2.10483122f, -0.04580984f, 1.08532572f, -3.46754074f, -0.53330994f, -1.35113037f, -1.51521778f,
|
||||
-0.46994060f, -2.27551699f, -0.53152251f, 0.47133854f, 0.07705012f, 0.48279381f, -1.28113365f, 0.48468336f,
|
||||
-1.54411674f, 0.06915778f, 0.64939111f, -1.33318806f, 0.63385141f, 1.72500539f, -1.27450287f, 2.53234506f,
|
||||
2.78955889f, 0.10935718f, 1.24130249f, -2.24100065f, -1.41059852f, -1.18030620f, 1.50915027f, 1.37942517f,
|
||||
-1.41709673f, -0.74830860f, 0.30404601f, -0.99458563f, 0.22929534f, -1.72507358f, -0.68753922f, -2.64537501f,
|
||||
0.58683372f, 0.88788664f, 0.54932535f, 1.45773280f, 0.96530700f, -3.57728553f, -0.41517627f, -4.86154747f};
|
||||
}
|
||||
|
||||
{
|
||||
data.fp16_output_data = data.fp32_output_data;
|
||||
}
|
||||
}
|
||||
|
||||
// Fills `data` with a cross-attention test case: 1 head with hidden size 32, query sequence
// length 2, key/value sequence length 3, and a 2D key padding mask whose 0 entries are on the
// left of each row.
// The remainder of the function assigns the query/key/value/bias inputs and the expected
// fp32 output; the fp16 expected output is copied from the fp32 one.
void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data) {
|
||||
data.hidden_size = 32;
|
||||
data.v_hidden_size = 32;
|
||||
data.num_heads = 1;
|
||||
data.batch_size = 2;  // NOTE(review): the function name says "Batch1" but the batch size set here is 2 (the mask below also has two rows) - confirm the name is intentional
|
||||
data.sequence_length = 2;
|
||||
data.kv_sequence_length = 3;
|
||||
data.mask_type = AttentionMaskType::MASK_2D_KEY_PADDING;
|
||||
data.key_padding_mask_data = {0, 1, 1, // first key sequence has one padding on the left
|
||||
0, 0, 1}; // second key sequence has two paddings on the left
|
||||
|
||||
// Kernel types listed here are skipped for this test case.
data.skip_kernel_types = {
|
||||
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernelType::AttentionKernel_TrtFusedAttention,
|
||||
AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention
|
||||
};
|
||||
|
||||
{
|
||||
data.query_data = {
|
||||
2.88765883f, 1.27536213f, -0.57580215f, 2.73696542f, 2.19016314f, 0.42629790f, 1.55081677f, -2.01307678f,
|
||||
-0.80203497f, -1.23206115f, 1.78565156f, -2.09875321f, -2.22730732f, -0.98120236f, -0.25774139f, 0.75868356f,
|
||||
-2.87585187f, -0.41810805f, -2.11528730f, 0.50642025f, -0.29446256f, -3.69675803f, -2.73721838f, -1.51089072f,
|
||||
0.74300194f, 0.27352047f, -0.88251829f, 2.82622814f, 0.73837662f, -2.14588642f, 0.37608737f, -0.06190044f,
|
||||
-1.97659302f, -2.22348428f, 2.25573063f, -2.24459195f, -2.28073978f, -0.52412349f, -0.57297325f, 3.29259396f,
|
||||
1.35617173f, -0.83082151f, 0.03767079f, 1.82568312f, 0.88193995f, 1.15579486f, 1.87845564f, -0.15923920f,
|
||||
2.37435389f, 1.49093378f, 1.95134592f, -1.67609048f, -0.45959851f, 1.63960719f, 3.44909906f, -0.23531833f,
|
||||
-0.57074630f, 1.38279045f, 0.58870834f, 0.85297751f, -1.44973445f, 1.56243801f, -0.67229253f, -0.16198707f,
|
||||
|
||||
-0.23966503f, -0.15329531f, -3.22765136f, 0.60538405f, -0.33244422f, -1.34865439f, -0.24373266f, -1.78808010f,
|
||||
-1.53090763f, 1.75037694f, -0.71890754f, 0.12527336f, 1.26654553f, -0.86477917f, -1.49822962f, 1.67973542f,
|
||||
0.99763191f, -0.07183220f, 1.55289185f, 1.62626481f, -0.04283767f, -2.55072594f, -1.95238030f, 0.60994428f,
|
||||
-2.53714681f, 1.54605150f, 0.05900350f, 1.42194426f, 0.33801061f, 1.25557244f, 0.67291188f, -1.36867523f,
|
||||
1.86936152f, -1.19588101f, 0.75778806f, 1.85271311f, 0.02081686f, 2.65807819f, 0.78890860f, -1.07388866f,
|
||||
4.18109226f, 0.06373940f, 2.86840463f, 0.90427721f, -0.09531648f, -0.40835506f, 1.60812938f, -1.61683714f,
|
||||
-0.45421624f, -2.25537109f, -1.35910070f, -0.25111723f, -0.71782172f, 0.62597942f, -0.42838976f, 0.23198499f,
|
||||
1.29250073f, -2.01550317f, 0.14619158f, -0.03868395f, -0.74211842f, -3.17291188f, -1.90475547f, 2.02544284f};
|
||||
}
|
||||
{
|
||||
data.key_data = {
|
||||
1.14242256f, 1.08148384f, -0.00962424f, -1.62719429f, 0.86478198f, 0.16862091f, 1.01692820f, -1.15278327f,
|
||||
-1.13622630f, 1.78038371f, 0.58222097f, 0.39166588f, 1.75063372f, -1.20408881f, 0.75154918f, 0.58156419f,
|
||||
-0.98975772f, -0.82555556f, -0.72656512f, -2.42399549f, 2.19217968f, 2.18518472f, -1.72216129f, 1.35098433f,
|
||||
-0.34989786f, -0.69064844f, -0.98365444f, 3.10148478f, 0.64813483f, 1.78129303f, -0.47006512f, 2.53122735f,
|
||||
0.09757380f, 0.04077591f, -0.81791472f, -0.19737752f, 1.13775492f, -1.51351953f, 0.59109330f, 2.86624002f,
|
||||
-0.09282493f, -1.69204521f, 1.27087700f, 3.53944731f, 0.59776509f, -0.90838081f, -0.15813766f, -1.86199224f,
|
||||
0.18734205f, -0.76110429f, -0.02243887f, -0.94068182f, 1.32443166f, 0.03512055f, -0.13194422f, -1.50401211f,
|
||||
0.92001319f, 0.20918207f, -1.34839189f, 1.56431675f, -0.61030018f, 2.39562368f, -1.56722510f, -0.96874726f,
|
||||
-0.48726845f, -1.41476154f, -1.45116997f, 0.53907454f, -2.14415288f, 1.14340270f, -0.21846619f, -2.72349358f,
|
||||
2.99664998f, -2.38684058f, 0.95269018f, 0.04208702f, -1.75080788f, 1.24652982f, -1.76879966f, 3.10814905f,
|
||||
2.48754454f, -0.62601894f, 1.41356945f, 0.10340121f, 1.09059846f, -0.78241473f, -0.61477584f, -0.19339988f,
|
||||
-0.48253334f, -2.41782594f, 1.04690075f, 0.14725411f, -0.20820639f, -1.95920563f, 0.96303236f, -1.20068836f,
|
||||
|
||||
-1.71051037f, -1.90946770f, -2.07985783f, 2.35042953f, 0.35059446f, -0.44228595f, 4.08558750f, -0.60121447f,
|
||||
0.78836018f, 0.35280651f, 0.23129070f, -0.21523762f, 0.12277550f, 0.12348226f, -1.62759030f, -2.78246498f,
|
||||
4.04853964f, 0.29263157f, -0.38621908f, -1.07599223f, -1.99170423f, 1.41409016f, 2.19121861f, -3.53451037f,
|
||||
3.63692737f, 0.68270516f, 2.51469731f, 2.57543731f, -2.39040112f, -3.97164130f, 1.28371549f, 1.64144099f,
|
||||
-0.70385075f, 2.55361128f, 1.60707259f, 0.84735453f, -2.07756495f, -1.99240303f, -3.60991144f, 2.87136865f,
|
||||
2.31296396f, 2.30251813f, -1.05624914f, -2.43777156f, -0.27048296f, 2.39037871f, -2.04504776f, 1.65183067f,
|
||||
-0.38970214f, 0.16808379f, -1.30286717f, 1.90201700f, -2.71696734f, -0.66445369f, 1.27085483f, -0.60816145f,
|
||||
1.81054437f, -1.55584621f, -2.19360781f, -4.52794456f, -0.90534067f, 0.94724411f, 2.40401077f, -2.94815230f,
|
||||
-3.19650269f, 2.50638890f, 1.02038431f, 1.50519919f, 0.47196171f, -1.89026380f, -1.86559379f, 0.82210326f,
|
||||
0.10818237f, 1.45290673f, 1.62321615f, -0.61283481f, -1.42501950f, 2.10349464f, -1.65715265f, 0.30090189f,
|
||||
-3.81919909f, -2.44903922f, -1.20557833f, -0.69951278f, -1.31475580f, -3.73842764f, 1.49299407f, -0.70933276f,
|
||||
-1.49021530f, 0.71776378f, -1.23052382f, -2.13119912f, -1.20718014f, 2.30572701f, 1.78386402f, -1.57122159f};
|
||||
}
|
||||
|
||||
{
|
||||
data.value_data = {
|
||||
1.79297853f, 0.96909231f, 1.23087275f, -0.61933923f, -0.56477690f, 1.47813499f, 0.51474279f, -3.44743419f,
|
||||
0.95816678f, -0.20553169f, -0.76906109f, -4.60927439f, 0.40629998f, 0.91934747f, -1.09594405f, -1.45653892f,
|
||||
-0.59282207f, 0.05621797f, -2.26634383f, -1.30799258f, 1.22072279f, -3.60811162f, 1.70111597f, 0.47336632f,
|
||||
-1.43857694f, -0.13917151f, -1.34617388f, 1.07960105f, -1.77342618f, 0.31946269f, 1.19137061f, 2.59346104f,
|
||||
-1.82395399f, 0.73557752f, 2.32600021f, -0.22650969f, -0.48526058f, 1.40349376f, -0.33553454f, 0.45531431f,
|
||||
0.73859257f, 0.37798560f, 0.85344458f, -1.30447221f, 1.23349071f, -0.26439479f, 1.18636096f, -0.33328748f,
|
||||
-0.50939041f, 0.53500950f, 1.33486223f, -1.54447496f, -2.88690519f, -0.06809106f, -0.00597921f, -1.07510388f,
|
||||
0.62182164f, 0.50033569f, -0.88293070f, 2.56142712f, 0.37708595f, 1.59349704f, -1.17139614f, 0.89580274f,
|
||||
0.69456708f, 2.91441655f, -0.25431669f, -1.20305562f, 2.06701255f, -0.86700624f, -2.23615170f, 0.13303493f,
|
||||
-2.97540593f, 0.08654684f, 1.40381706f, 3.54294443f, -2.07661867f, -1.33181918f, 2.24228764f, 1.79975545f,
|
||||
2.14695477f, 1.40222490f, -0.29813689f, 1.94485068f, 1.99623775f, 1.53450203f, 0.28755581f, -0.67934704f,
|
||||
-0.92102510f, -1.52764773f, 1.11267352f, -3.90122724f, 0.22128634f, 0.14945325f, -4.38529491f, -1.58423281f,
|
||||
|
||||
-2.45574522f, -1.91599977f, 5.05240345f, 2.24617362f, 3.99182248f, 0.92924285f, -0.39660916f, -0.08696688f,
|
||||
0.24855530f, 0.71378094f, 0.92413902f, 1.73599064f, 1.03852975f, 2.44676781f, 0.35013664f, 0.98107171f,
|
||||
1.62946916f, 0.41239718f, -1.41385484f, 2.49293518f, 2.32976985f, 2.89612579f, 2.66875219f, 1.47379971f,
|
||||
1.31164551f, -1.82183075f, -5.15272474f, 0.28575048f, 0.16861364f, -0.47264135f, 0.22565089f, -0.37727535f,
|
||||
-1.13935280f, 0.38051969f, -2.38735437f, -2.80645251f, 0.18637873f, 2.13938355f, 2.92260599f, -0.38653925f,
|
||||
0.58366799f, -1.67636371f, -2.29396892f, -1.31527638f, 2.39795637f, 0.39815575f, -0.98530269f, -1.29227996f,
|
||||
0.14452982f, -0.38186538f, -1.71267688f, 0.18121701f, -2.26441002f, -0.94511753f, 0.27371156f, -2.44858527f,
|
||||
-0.21510160f, -2.65228534f, -2.16755104f, 0.86151361f, 0.77589297f, -1.06628847f, 0.73745233f, 1.15778029f,
|
||||
-0.73659700f, 0.74325305f, -1.97666430f, -1.07301974f, 0.17534591f, -1.66584718f, 1.21820331f, 0.67675018f,
|
||||
-1.08938253f, 1.78010321f, 0.39817584f, -0.02914053f, 1.13571596f, -0.44081455f, 1.70561552f, -2.12085509f,
|
||||
-0.69322622f, -1.87331009f, -2.15000772f, 2.08436966f, 1.70494926f, -3.69169927f, -1.22119129f, -1.60190558f,
|
||||
-2.09093666f, -1.02816033f, -1.78743768f, 2.34501553f, 2.79939008f, 1.82245076f, 1.47408092f, 1.10063124f};
|
||||
}
|
||||
|
||||
{
|
||||
data.bias_data = {
|
||||
-0.38124341f, 0.02696526f, -0.11914945f, -0.43795273f, -0.34948170f, -0.19608477f, 0.19725692f, 0.39987487f,
|
||||
0.04772711f, -0.03419551f, -0.30606642f, 0.42656231f, -0.23178342f, -0.13692456f, -0.04889601f, 0.48739988f,
|
||||
-0.25891554f, 0.13431972f, 0.22861153f, 0.06360734f, 0.48096961f, -0.47906545f, 0.43613154f, -0.23511401f,
|
||||
-0.10595283f, -0.42839217f, 0.28931111f, -0.13180739f, -0.45826656f, 0.23286396f, -0.43407962f, 0.40754890f,
|
||||
0.27079183f, 0.42074734f, -0.40314156f, -0.43726659f, 0.27376485f, -0.38174152f, -0.43700469f, 0.38040614f,
|
||||
-0.40546918f, 0.06927037f, 0.16979086f, 0.41458064f, 0.07120579f, -0.08055863f, 0.12095112f, -0.27988660f,
|
||||
0.06004709f, -0.05600315f, -0.25510073f, 0.41887105f, -0.19016314f, 0.47241372f, 0.12890404f, -0.24272856f,
|
||||
0.21106839f, -0.40523255f, 0.10336459f, -0.11084765f, 0.42408967f, -0.15285304f, -0.28945464f, -0.25714916f,
|
||||
-0.10567203f, 0.26791072f, -0.08976898f, 0.31341976f, 0.06027532f, 0.14307594f, 0.31587386f, 0.16180152f,
|
||||
0.34785229f, 0.00531715f, -0.35168743f, -0.11641458f, 0.39196932f, 0.44535065f, 0.43545735f, 0.15593112f,
|
||||
0.06171834f, -0.42181283f, -0.41170910f, 0.40969193f, -0.01510030f, 0.07973170f, -0.18156880f, 0.21522856f,
|
||||
0.03915739f, -0.20913908f, -0.47068381f, 0.35633272f, -0.35124153f, 0.36624825f, -0.05567622f, -0.35343069f};
|
||||
}
|
||||
|
||||
{
|
||||
data.fp32_output_data = {
|
||||
0.23503941f, 2.87619758f, 0.01845241f, -0.75242990f, 1.76869011f, -0.40492195f, -1.65323853f, 0.34011719f,
|
||||
-2.10573196f, 0.13281155f, 0.97480160f, 2.74546146f, -1.21957457f, -0.73649400f, 2.52938581f, 1.65599120f,
|
||||
1.83545303f, 0.85856718f, -0.48040742f, 1.86428785f, 1.29504943f, 1.38906729f, 0.06474495f, -0.51972288f,
|
||||
-0.66509569f, -1.45185244f, 0.36160457f, -2.63688278f, -0.10806514f, 0.71859169f, -3.98941422f, -1.58921516f,
|
||||
|
||||
-1.89806330f, 1.03079379f, 2.20389438f, 0.07467184f, -0.39299977f, 1.51811528f, -0.04347950f, 0.61307698f,
|
||||
1.03990030f, 0.37965038f, 0.50865448f, -1.36013806f, 1.58397710f, 0.16757873f, 1.63505113f, -0.15062472f,
|
||||
-0.41438234f, 0.12406474f, 0.90268815f, -1.09105420f, -2.84080887f, 0.03172458f, -0.18386938f, -0.85491556f,
|
||||
0.64164376f, 0.26578158f, -1.32860518f, 2.83676863f, 0.02389192f, 1.94164813f, -1.26734924f, 0.51129180f,
|
||||
|
||||
-0.84226906f, 1.01116371f, -2.06643319f, -0.75959998f, 0.23562123f, -1.52277124f, 1.53407717f, 0.83855170f,
|
||||
-0.74153024f, 1.78542042f, 0.04648840f, -0.14555511f, 1.52768528f, 0.00453609f, 2.14107275f, -1.96492398f,
|
||||
-0.63150787f, -2.29512286f, -2.56171679f, 2.49406147f, 1.68984890f, -3.61196756f, -1.40276003f, -1.38667703f,
|
||||
-2.05177927f, -1.23729944f, -2.25812149f, 2.70134830f, 2.44814849f, 2.18869901f, 1.41840470f, 0.74720055f,
|
||||
|
||||
-0.84226906f, 1.01116371f, -2.06643319f, -0.75959998f, 0.23562123f, -1.52277124f, 1.53407717f, 0.83855170f,
|
||||
-0.74153024f, 1.78542042f, 0.04648840f, -0.14555511f, 1.52768528f, 0.00453609f, 2.14107275f, -1.96492398f,
|
||||
-0.63150787f, -2.29512286f, -2.56171679f, 2.49406147f, 1.68984890f, -3.61196756f, -1.40276003f, -1.38667703f,
|
||||
-2.05177927f, -1.23729944f, -2.25812149f, 2.70134830f, 2.44814849f, 2.18869901f, 1.41840470f, 0.74720055f};
|
||||
}
|
||||
|
||||
{
|
||||
data.fp16_output_data = data.fp32_output_data;
|
||||
}
|
||||
}
|
||||
#endif
|
||||
|
||||
void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
|
||||
|
|
@ -1913,9 +2299,8 @@ void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data) {
|
|||
data.kv_sequence_length = 3;
|
||||
data.mask_type = AttentionMaskType::MASK_NONE;
|
||||
data.skip_kernel_types = {
|
||||
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernelType::AttentionKernel_TrtFusedAttention
|
||||
};
|
||||
AttentionKernelType::AttentionKernel_TrtFusedCrossAttention,
|
||||
AttentionKernelType::AttentionKernel_TrtFusedAttention};
|
||||
|
||||
{
|
||||
data.query_data = {
|
||||
|
|
|
|||
|
|
@ -28,11 +28,15 @@ struct AttentionTestData{
|
|||
std::vector<AttentionKernelType> skip_kernel_types; // skip some kernels if they do not support this test case.
|
||||
};
|
||||
|
||||
// Disable some tests in Windows since prefast build might crash with large test data.
|
||||
#ifndef _MSC_VER
|
||||
// Return packed weights and bias for input projection.
|
||||
void GetAttentionWeight(std::vector<float>& weight_data, int elements = 64 * 3 * 64, int offset = 0, int step=1);
|
||||
void GetAttentionBias(std::vector<float>& bias_data, int elements = 3 * 64, int offset = 0, int step=1);
|
||||
|
||||
void GetCrossAttentionData_HeadSize40(AttentionTestData& data);
|
||||
void GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(AttentionTestData& data, bool is_mask_1d);
|
||||
void GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(AttentionTestData& data);
|
||||
#endif
|
||||
|
||||
void GetCrossAttentionData_HeadSize16_8(AttentionTestData& data);
|
||||
|
|
|
|||
|
|
@ -74,7 +74,7 @@ static void RunMultiHeadAttentionTest(
|
|||
|
||||
constexpr float rel_error = 0.0f;
|
||||
constexpr float abs_error = 0.05f;
|
||||
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data), /*sort*/false, rel_error, abs_error);
|
||||
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data), /*sort*/ false, rel_error, abs_error);
|
||||
} else {
|
||||
tester.AddInput<float>("query", query_dims, query_data);
|
||||
tester.AddInput<float>("key", key_dims, key_data);
|
||||
|
|
@ -94,7 +94,7 @@ static void RunMultiHeadAttentionTest(
|
|||
|
||||
constexpr float rel_error = 0.0f;
|
||||
constexpr float abs_error = 0.02f;
|
||||
tester.AddOutput<float>("output", output_dims, output_data, /*sort*/false, rel_error, abs_error);
|
||||
tester.AddOutput<float>("output", output_dims, output_data, /*sort*/ false, rel_error, abs_error);
|
||||
}
|
||||
|
||||
if (enable_cuda) {
|
||||
|
|
@ -124,7 +124,7 @@ static void RunMultiHeadAttentionKernel(
|
|||
const std::vector<float>& bias_data, // bias: [hidden_size + hidden_size + v_hidden_size]
|
||||
const std::vector<int32_t>& key_padding_mask_data, // key_padding_mask: see below
|
||||
AttentionMaskType mask_type, // 1 for [batch_size], 2 for [batch_size, kv_sequence_length]
|
||||
const std::vector<float>& output_data, // output: [batch_size, sequence_length, v_hidden_size]
|
||||
const std::vector<float>& output_data, // output: [batch_size, sequence_length, v_hidden_size]
|
||||
int num_heads,
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
|
|
@ -136,13 +136,13 @@ static void RunMultiHeadAttentionKernel(
|
|||
bool disable_cpu = true, // not supported in cpu right now.
|
||||
bool disable_cuda = false,
|
||||
bool disable_rocm = true) {
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_Default) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}}};
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
|
||||
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
|
||||
|
|
@ -150,13 +150,13 @@ static void RunMultiHeadAttentionKernel(
|
|||
return;
|
||||
}
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_Unfused)
|
||||
{
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_Unfused) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}}};
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
|
||||
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
|
||||
|
|
@ -164,13 +164,13 @@ static void RunMultiHeadAttentionKernel(
|
|||
return;
|
||||
}
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention)
|
||||
{
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedCrossAttention) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"}}};
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
|
||||
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
|
||||
|
|
@ -178,13 +178,29 @@ static void RunMultiHeadAttentionKernel(
|
|||
return;
|
||||
}
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention)
|
||||
{
|
||||
#if USE_FLASH_ATTENTION
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "0"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
|
||||
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
|
||||
use_float16, disable_cpu, disable_cuda, disable_rocm);
|
||||
return;
|
||||
}
|
||||
#endif
|
||||
|
||||
if (kernel_type == AttentionKernelType::AttentionKernel_TrtFusedAttention) {
|
||||
ScopedEnvironmentVariables scoped_env_vars{
|
||||
EnvVarMap{
|
||||
{onnxruntime::contrib::attention::kDisableTrtFlashAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedAttention, "0"},
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"}}};
|
||||
{onnxruntime::contrib::attention::kDisableFusedCrossAttention, "1"},
|
||||
{onnxruntime::contrib::attention::kDisableMemoryEfficientAttention, "1"}}};
|
||||
RunMultiHeadAttentionTest(
|
||||
query_data, key_data, value_data, bias_data, key_padding_mask_data, mask_type, output_data,
|
||||
num_heads, batch_size, sequence_length, kv_sequence_length, hidden_size, v_hidden_size,
|
||||
|
|
@ -204,11 +220,21 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
|
|||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
}
|
||||
|
||||
kernel_type = AttentionKernelType::AttentionKernel_Default;
|
||||
RunMultiHeadAttentionKernel(
|
||||
#if USE_FLASH_ATTENTION
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
|
||||
if (!SkipAttentionKernel(data, kernel_type)) {
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
|
||||
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
|
||||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
}
|
||||
#endif
|
||||
|
||||
kernel_type = AttentionKernelType::AttentionKernel_Default;
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
|
||||
data.fp32_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
|
||||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
}
|
||||
|
||||
if (data.fp16_output_data.size() > 0) {
|
||||
|
|
@ -229,6 +255,16 @@ static void RunMultiHeadAttentionTests(AttentionTestData& data) {
|
|||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
}
|
||||
|
||||
#if USE_FLASH_ATTENTION
|
||||
kernel_type = AttentionKernelType::AttentionKernel_CutlassMemoryEfficientAttention;
|
||||
if (!SkipAttentionKernel(data, kernel_type)) {
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
|
||||
data.fp16_output_data, data.num_heads, data.batch_size, data.sequence_length, data.kv_sequence_length,
|
||||
data.hidden_size, data.v_hidden_size, kernel_type, use_float16);
|
||||
}
|
||||
#endif
|
||||
|
||||
kernel_type = AttentionKernelType::AttentionKernel_Default;
|
||||
RunMultiHeadAttentionKernel(
|
||||
data.query_data, data.key_data, data.value_data, data.bias_data, data.key_padding_mask_data, data.mask_type,
|
||||
|
|
@ -245,6 +281,24 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize40) {
|
|||
GetCrossAttentionData_HeadSize40(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask1D) {
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, true);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch2_HeadSize32_RightSidePadding_Mask2D) {
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch2_HeadSize32_RightSidePadding(data, false);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize32_LeftSidePadding_Mask2D) {
|
||||
AttentionTestData data;
|
||||
GetCrossAttentionData_Batch1_HeadSize32_LeftSidePadding(data);
|
||||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
#endif
|
||||
|
||||
// This tests qk_head_size != k_head_size
|
||||
|
|
@ -260,6 +314,5 @@ TEST(MultiHeadAttentionTest, CrossAttention_Batch1_HeadSize16) {
|
|||
RunMultiHeadAttentionTests(data);
|
||||
}
|
||||
|
||||
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -36,6 +36,8 @@ class Attention(nn.Module):
|
|||
self.value = nn.Linear(hidden_dim, self.v_hidden_size)
|
||||
self.is_decoder = is_decoder
|
||||
|
||||
# Do not reshape output for pretty print.
|
||||
self.reshape_output = False
|
||||
self.verbose = False
|
||||
|
||||
def transpose_for_scores(self, x: torch.Tensor, head_size) -> torch.Tensor:
|
||||
|
|
@ -43,7 +45,7 @@ class Attention(nn.Module):
|
|||
x = x.view(new_x_shape)
|
||||
return x.permute(0, 2, 1, 3)
|
||||
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor) -> Tensor:
|
||||
def get_extended_attention_mask(self, attention_mask: Tensor, dtype: torch.dtype) -> Tensor:
|
||||
assert attention_mask.dim() == 2 or attention_mask.dim() == 3
|
||||
extended_attention_mask = (
|
||||
attention_mask[:, None, :, :] if attention_mask.dim() == 3 else attention_mask[:, None, None, :]
|
||||
|
|
@ -120,7 +122,7 @@ class Attention(nn.Module):
|
|||
|
||||
if attention_mask is not None:
|
||||
# Apply the attention mask (precomputed for all layers in BertModel forward() function)
|
||||
attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask)
|
||||
attention_scores = attention_scores + self.get_extended_attention_mask(attention_mask, hidden_states.dtype)
|
||||
|
||||
# Normalize the attention scores to probabilities.
|
||||
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
|
||||
|
|
@ -131,8 +133,9 @@ class Attention(nn.Module):
|
|||
|
||||
context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
|
||||
|
||||
# new_context_layer_shape = context_layer.size()[:-2] + (self.v_hidden_size,)
|
||||
# context_layer = context_layer.view(new_context_layer_shape)
|
||||
if self.reshape_output:
|
||||
new_context_layer_shape = context_layer.size()[:-2] + (self.v_hidden_size,)
|
||||
context_layer = context_layer.view(new_context_layer_shape)
|
||||
|
||||
print("output", context_layer)
|
||||
|
||||
|
|
@ -144,7 +147,7 @@ class Attention(nn.Module):
|
|||
return outputs
|
||||
|
||||
|
||||
def generate_test_data(
|
||||
def run_cross_attention(
|
||||
hidden_dim,
|
||||
q_head_size,
|
||||
v_head_size,
|
||||
|
|
@ -161,7 +164,8 @@ def generate_test_data(
|
|||
|
||||
device = torch.device("cuda:0")
|
||||
mha = Attention(num_heads, hidden_dim, q_head_size, v_head_size, is_decoder=False).to(device).eval()
|
||||
|
||||
if key_padding_mask is not None:
|
||||
key_padding_mask = key_padding_mask.to(device)
|
||||
torch.nn.init.uniform_(mha.query.weight, -0.5, 0.5)
|
||||
torch.nn.init.uniform_(mha.key.weight, -0.5, 0.5)
|
||||
torch.nn.init.uniform_(mha.value.weight, -0.5, 0.5)
|
||||
|
|
@ -205,10 +209,9 @@ def generate_test_data(
|
|||
past_key_value=None,
|
||||
output_attentions=False,
|
||||
)
|
||||
print("output", output)
|
||||
|
||||
|
||||
def CrossAttention_Batch2_HeadSize40():
|
||||
def run_cross_batch2_headsize_40():
|
||||
hidden_dim = 80
|
||||
q_head_size = 40
|
||||
v_head_size = 40
|
||||
|
|
@ -216,10 +219,12 @@ def CrossAttention_Batch2_HeadSize40():
|
|||
batch_size = 2
|
||||
sequence_length = 3
|
||||
kv_sequence_length = 5
|
||||
generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length)
|
||||
run_cross_attention(
|
||||
hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length
|
||||
)
|
||||
|
||||
|
||||
def CrossAttention_Batch1_HeadSize16():
|
||||
def run_cross_batch1_headsize_16():
|
||||
hidden_dim = 32
|
||||
q_head_size = 16
|
||||
v_head_size = 16
|
||||
|
|
@ -227,10 +232,12 @@ def CrossAttention_Batch1_HeadSize16():
|
|||
batch_size = 1
|
||||
sequence_length = 2
|
||||
kv_sequence_length = 3
|
||||
generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length)
|
||||
run_cross_attention(
|
||||
hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length
|
||||
)
|
||||
|
||||
|
||||
def CrossAttention_Batch2_HeadSize16_8():
|
||||
def run_cross_batch2_headsize_16_8():
|
||||
hidden_dim = 32
|
||||
q_head_size = 16
|
||||
v_head_size = 8
|
||||
|
|
@ -238,15 +245,73 @@ def CrossAttention_Batch2_HeadSize16_8():
|
|||
batch_size = 2
|
||||
sequence_length = 1
|
||||
kv_sequence_length = 3
|
||||
generate_test_data(hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length)
|
||||
run_cross_attention(
|
||||
hidden_dim, q_head_size, v_head_size, num_heads, batch_size, sequence_length, kv_sequence_length
|
||||
)
|
||||
|
||||
|
||||
def run_cross_batch2_headsize_32_right_side_padding():
|
||||
hidden_dim = 64
|
||||
q_head_size = 32
|
||||
v_head_size = 32
|
||||
num_heads = 2
|
||||
batch_size = 2
|
||||
sequence_length = 2
|
||||
kv_sequence_length = 3
|
||||
key_padding_mask = torch.tensor([[1, 0, 0], [1, 1, 0]], dtype=torch.int32).cuda()
|
||||
|
||||
run_cross_attention(
|
||||
hidden_dim,
|
||||
q_head_size,
|
||||
v_head_size,
|
||||
num_heads,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
kv_sequence_length,
|
||||
key_padding_mask,
|
||||
)
|
||||
|
||||
|
||||
def run_cross_batch1_headsize_32_left_side_padding():
|
||||
hidden_dim = 32
|
||||
q_head_size = 32
|
||||
v_head_size = 32
|
||||
num_heads = 1
|
||||
batch_size = 2
|
||||
sequence_length = 2
|
||||
kv_sequence_length = 3
|
||||
key_padding_mask = torch.tensor([[0, 1, 1], [0, 0, 1]], dtype=torch.int32).cuda()
|
||||
run_cross_attention(
|
||||
hidden_dim,
|
||||
q_head_size,
|
||||
v_head_size,
|
||||
num_heads,
|
||||
batch_size,
|
||||
sequence_length,
|
||||
kv_sequence_length,
|
||||
key_padding_mask,
|
||||
)
|
||||
|
||||
|
||||
def create_cross_attention_test_data():
|
||||
"""
|
||||
Create test data used in attention_op_test_helper.cc and multihead_attention_op_test.cc
|
||||
"""
|
||||
print("CrossAttention_Batch2_HeadSize40")
|
||||
run_cross_batch2_headsize_40()
|
||||
|
||||
print("CrossAttention_Batch1_HeadSize16")
|
||||
run_cross_batch1_headsize_16()
|
||||
|
||||
print("CrossAttention_Batch2_HeadSize16_8")
|
||||
run_cross_batch2_headsize_16_8()
|
||||
|
||||
print("CrossAttention_Batch2_HeadSize32_RightSidePadding")
|
||||
run_cross_batch2_headsize_32_right_side_padding()
|
||||
|
||||
print("CrossAttention_Batch1_HeadSize32_LeftSidePadding")
|
||||
run_cross_batch1_headsize_32_left_side_padding()
|
||||
|
||||
|
||||
with torch.no_grad():
|
||||
print("CrossAttention_Batch2_HeadSize40")
|
||||
CrossAttention_Batch2_HeadSize40()
|
||||
|
||||
print("CrossAttention_Batch1_HeadSize16")
|
||||
CrossAttention_Batch1_HeadSize16()
|
||||
|
||||
print("CrossAttention_Batch2_HeadSize16_8")
|
||||
CrossAttention_Batch2_HeadSize16_8()
|
||||
create_cross_attention_test_data()
|
||||
|
|
|
|||
Loading…
Reference in a new issue