From 156368b67f46fc6ab12b59ea39fcc93e51730ae7 Mon Sep 17 00:00:00 2001 From: Yufeng Li Date: Mon, 4 May 2020 14:20:38 -0700 Subject: [PATCH] Quantize attention with Cuda (#3693) * Add definition of QAttention * implementation of QAttention on GPU --- onnxruntime/contrib_ops/cpu/bert/attention.cc | 29 +- onnxruntime/contrib_ops/cpu/bert/attention.h | 5 +- .../contrib_ops/cuda/bert/attention.cc | 3 +- .../contrib_ops/cuda/cuda_contrib_kernels.cc | 6 +- .../quantization/attention_quantization.cc | 189 +++++++++++++ .../quantization/attention_quantization.h | 44 +++ .../attention_quantization_impl.cu | 52 ++++ .../attention_quantization_impl.cuh | 15 ++ .../quantize_dequantize_linear.cc} | 0 .../core/graph/contrib_ops/contrib_defs.cc | 73 +++++ onnxruntime/core/providers/cuda/cuda_common.h | 3 +- .../core/providers/cuda/integer_gemm.cc | 58 ++++ .../providers/cuda/math/matmul_integer.cc | 75 +----- .../providers/cuda/math/matmul_integer.cu | 36 +-- .../providers/cuda/math/matmul_integer.cuh | 2 - .../core/providers/cuda/math/matmul_integer.h | 10 - .../providers/cuda/shared_inc/integer_gemm.h | 23 ++ .../contrib_ops/quantize_attention_op_test.cc | 252 ++++++++++++++++++ 18 files changed, 753 insertions(+), 122 deletions(-) create mode 100644 onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc create mode 100644 onnxruntime/contrib_ops/cuda/quantization/attention_quantization.h create mode 100644 onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu create mode 100644 onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh rename onnxruntime/contrib_ops/cuda/{quantize_ops.cc => quantization/quantize_dequantize_linear.cc} (100%) create mode 100644 onnxruntime/core/providers/cuda/integer_gemm.cc create mode 100644 onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h create mode 100644 onnxruntime/test/contrib_ops/quantize_attention_op_test.cc diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.cc 
b/onnxruntime/contrib_ops/cpu/bert/attention.cc index ce1715d202..67dd73aee4 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.cc +++ b/onnxruntime/contrib_ops/cpu/bert/attention.cc @@ -28,18 +28,19 @@ AttentionBase::AttentionBase(const OpKernelInfo& info) { int64_t num_heads = 0; ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0); num_heads_ = static_cast(num_heads); - is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 1) == 1; + is_unidirectional_ = info.GetAttrOrDefault("unidirectional", 0) == 1; } -Status AttentionBase::CheckInputs(const OpKernelContext* context) const { +Status AttentionBase::CheckInputs(const Tensor* input, + const Tensor* weights, + const Tensor* bias, + const Tensor* mask_index) const { // Input and output shapes: - // Input 0 - input : (batch_size, sequence_length, hidden_size) - // Input 1 - weights : (hidden_size, 3 * hidden_size) - // Input 2 - bias : (3 * hidden_size) - // Input 3 - mask_index : (batch_size) if presented - // Output : (batch_size, sequence_length, hidden_size) + // input : (batch_size, sequence_length, hidden_size) + // weights : (hidden_size, 3 * hidden_size) + // bias : (3 * hidden_size) + // mask_index : (batch_size) if presented - const Tensor* input = context->Input(0); const auto dims = input->Shape().GetDims(); if (dims.size() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 0 is expected to have 3 dimensions, got ", @@ -52,7 +53,6 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { "Input 0 dimension 2 should be divisiable by value of the num_heads attribute."); } - const Tensor* weights = context->Input(1); const auto weights_dims = weights->Shape().GetDims(); if (weights_dims.size() != 2) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 is expected to have 2 dimensions, got ", @@ -66,7 +66,6 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { return ORT_MAKE_STATUS(ONNXRUNTIME, 
INVALID_ARGUMENT, "Input 1 dimension 1 should be 3 times of dimension 0"); } - const Tensor* bias = context->Input(2); const auto bias_dims = bias->Shape().GetDims(); if (bias_dims.size() != 1) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 2 is expected to have 1 dimension, got ", @@ -77,8 +76,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const { "Input 2 dimension 0 should have same length as dimension 1 of input 1"); } - const Tensor* mask_index = context->Input(3); - if (mask_index != nullptr) { + if (mask_index != nullptr) { // mask_index is optional // unidirectional (like GPT2) does not need mask input. Here we do not allowed the input for unidirectional. if (is_unidirectional_) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 (mask_index) is not allowed for unidirectional"); @@ -103,13 +101,11 @@ Attention::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionBas template Status Attention::Compute(OpKernelContext* context) const { - auto* tp = context->GetOperatorThreadPool(); - ORT_RETURN_IF_ERROR(CheckInputs(context)); - const Tensor* input = context->Input(0); const Tensor* weights = context->Input(1); const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, mask_index)); const auto dims = input->Shape().GetDims(); const int batch_size = static_cast(dims[0]); @@ -125,6 +121,7 @@ Status Attention::Compute(OpKernelContext* context) const { AllocatorPtr allocator; ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator)); + auto* tp = context->GetOperatorThreadPool(); // STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH) auto gemm_data = allocator->Alloc(SafeInt(batch_size) * sequence_length * 3 * hidden_size * element_size); BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator)); @@ -180,7 +177,7 @@ Status Attention::Compute(OpKernelContext* context) const { qkv_dest + 
qkv_offset, // C head_size, // ldc nullptr // use single-thread - ); + ); } }); } diff --git a/onnxruntime/contrib_ops/cpu/bert/attention.h b/onnxruntime/contrib_ops/cpu/bert/attention.h index ba2684c910..e82d758cdb 100644 --- a/onnxruntime/contrib_ops/cpu/bert/attention.h +++ b/onnxruntime/contrib_ops/cpu/bert/attention.h @@ -12,7 +12,10 @@ namespace contrib { class AttentionBase { protected: AttentionBase(const OpKernelInfo& info); - Status CheckInputs(const OpKernelContext* context) const; + Status CheckInputs(const Tensor* input, + const Tensor* weights, + const Tensor* bias, + const Tensor* mask_index) const; int num_heads_; // number of attention heads bool is_unidirectional_; // whether every token can only attend to previous tokens. diff --git a/onnxruntime/contrib_ops/cuda/bert/attention.cc b/onnxruntime/contrib_ops/cuda/bert/attention.cc index 9be5096a05..5d679fd335 100644 --- a/onnxruntime/contrib_ops/cuda/bert/attention.cc +++ b/onnxruntime/contrib_ops/cuda/bert/attention.cc @@ -34,8 +34,6 @@ Attention::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB template Status Attention::ComputeInternal(OpKernelContext* context) const { - ORT_RETURN_IF_ERROR(CheckInputs(context)); - // Input and output shapes: // Input 0 - input : (batch_size, sequence_length, hidden_size) // Input 1 - weights : (hidden_size, 3 * hidden_size) @@ -46,6 +44,7 @@ Status Attention::ComputeInternal(OpKernelContext* context) const { const Tensor* weights = context->Input(1); const Tensor* bias = context->Input(2); const Tensor* mask_index = context->Input(3); + ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, mask_index)); const auto dims = input->Shape().GetDims(); int batch_size = static_cast(dims[0]); diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index a081f18847..685ae59fc6 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ 
b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -68,6 +68,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention); +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention); Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { static const BuildKernelCreateInfoFn function_table[] = { @@ -130,7 +132,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, - BuildKernelCreateInfo}; + BuildKernelCreateInfo, + BuildKernelCreateInfo, + BuildKernelCreateInfo}; for (auto& function_table_entry : function_table) { ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry())); diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc new file mode 100644 index 0000000000..4765726ddd --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc @@ -0,0 +1,189 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#include "attention_quantization.h" +#include "attention_quantization_impl.cuh" +#include "contrib_ops/cuda/bert/attention_impl.h" +#include "core/framework/tensorprotoutils.h" +#include "core/providers/common.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/shared_inc/integer_gemm.h" +#include "core/providers/cuda/tensor/quantize_linear.h" + +using namespace onnxruntime::cuda; +using namespace onnxruntime::common; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL_TYPED(T, TQuant) \ + ONNX_OPERATOR_TYPED_KERNEL_EX( \ + QAttention, \ + kMSDomain, \ + 1, \ + T##_##TQuant, \ + kCudaExecutionProvider, \ + KernelDefBuilder() \ + .InputMemoryType(3) \ + .InputMemoryType(4) \ + .InputMemoryType(6) \ + .InputMemoryType(7) \ + .TypeConstraint("T1", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T2", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T3", DataTypeImpl::GetTensorType()) \ + .TypeConstraint("T4", DataTypeImpl::GetTensorType()), \ + QAttention); + +REGISTER_KERNEL_TYPED(float, int8_t) +REGISTER_KERNEL_TYPED(MLFloat16, int8_t) + +template +Status QAttention::CheckInputs(const Tensor* input, + const Tensor* weights, + const Tensor* bias, + const Tensor* input_scale_tensor, + const Tensor* weight_scale_tensor, + const Tensor* mask_index, + const Tensor* i_zp_tensor, + const Tensor* w_zp_tensor) const { + // Input and output shapes: + // Input 0 - input : (batch_size, sequence_length, hidden_size) + // Input 1 - weights : (hidden_size, 3 * hidden_size) + // Input 2 - bias : (3 * hidden_size) + // Input 3 - input_scale : scalar + // Input 4 - weight_scale : scalar + // Input 5 - mask_index : (batch_size) + // Input 6 - input_zero_point : scalar + // Input 7 - weight_zero_point : scalar + // Output : (batch_size, sequence_length, hidden_size) + + ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input, weights, bias, mask_index)); + + 
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor), + "input scale must be a scalar or 1D tensor of size 1"); + + ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(weight_scale_tensor), + "weight must be a scalar or 1D tensor of size 1"); + + if (i_zp_tensor != nullptr) { + ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(i_zp_tensor), + "input zero point must be a scalar or 1D tensor of size 1."); + if (0 != *(i_zp_tensor->template Data())) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention"); + } + + if (w_zp_tensor != nullptr) { + // CUDA only support symmetric quantization for Attention + ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(w_zp_tensor), + "weight zero point must be a scalar or 1D tensor of size 1."); + if (0 != *(w_zp_tensor->template Data())) + return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention"); + } + + return Status::OK(); +} + +template +Status QAttention::ComputeInternal(OpKernelContext* context) const { + // Input and output shapes: + // Input 0 - input : (batch_size, sequence_length, hidden_size) + // Input 1 - weights : (hidden_size, 3 * hidden_size) + // Input 2 - bias : (3 * hidden_size) + // Input 3 - input_scale : scalar + // Input 4 - weight_scale : scalar + // Input 5 - mask_index : (batch_size) + // Input 6 - input_zero_point : scalar + // Input 7 - weight_zero_point : scalar + // Output : (batch_size, sequence_length, hidden_size) + // ORT_RETURN_IF_ERROR(CheckInputs(context)); + const Tensor* input = context->Input(0); + const Tensor* weights = context->Input(1); + const Tensor* bias = context->Input(2); + const Tensor* input_scale_tensor = context->Input(3); + const Tensor* weight_scale_tensor = context->Input(4); + const Tensor* mask_index = context->Input(5); + const Tensor* i_zp_tensor = context->Input(6); + const Tensor* w_zp_tensor = context->Input(7); + + ORT_RETURN_IF_ERROR(CheckInputs(input, + weights, + bias, + input_scale_tensor, + 
weight_scale_tensor, + mask_index, + i_zp_tensor, + w_zp_tensor)); + + const auto dims = input->Shape().GetDims(); + /*int input_size = static_cast(input->Shape().Size());*/ + int batch_size = static_cast(dims[0]); + int sequence_length = static_cast(dims[1]); + int hidden_size = static_cast(dims[2]); + int head_size = hidden_size / num_heads_; + + TensorShape output_shape(dims); + Tensor* output = context->Output(0, output_shape); + + cublasHandle_t cublas = CublasHandle(); + const size_t element_size = sizeof(T); + + // Use GEMM for the fully connected layer. + int m = batch_size * sequence_length; + int n = 3 * hidden_size; + int k = hidden_size; + auto gemm_buffer = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size * element_size); + auto gemm_buffer_quantized = GetScratchBuffer(batch_size * sequence_length * 3 * hidden_size); + + typedef typename ToCudaType::MappedType CudaT; + + GemmInt8(m, n, k, + 1 /*alpha_matmul*/, 0 /* beta_matmul*/, + input->template Data(), k, + weights->template Data(), n, + gemm_buffer_quantized.get(), n, + this); + + CudaT dequant_scale; + CudaT input_scale = *(reinterpret_cast(input_scale_tensor->template Data())); + CudaT weight_scale = *(reinterpret_cast(weight_scale_tensor->template Data())); + if (sizeof(T) == 2) { + dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale)); + } else { + dequant_scale = input_scale * weight_scale; + } + // scale back and bias + CudaDequantizeWithBias( + gemm_buffer_quantized.get(), + reinterpret_cast(bias->template Data()), + reinterpret_cast(gemm_buffer.get()), + dequant_scale, + m, + n); + + size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length); + auto temp_buffer = GetScratchBuffer(workSpaceSize); + if (!LaunchAttentionKernel( + reinterpret_cast(gemm_buffer.get()), + nullptr == mask_index ? 
nullptr : mask_index->template Data(), + output->template MutableData(), + batch_size, + sequence_length, + num_heads_, + head_size, + temp_buffer.get(), + cublas, + element_size, + false)) { + // Get last error to reset it to cudaSuccess. + CUDA_CALL(cudaGetLastError()); + return Status(common::ONNXRUNTIME, common::FAIL); + } + + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.h b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.h new file mode 100644 index 0000000000..ba636d5635 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.h @@ -0,0 +1,44 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "core/common/common.h" +#include "core/framework/op_kernel.h" +#include "core/providers/cuda/cuda_common.h" +#include "contrib_ops/cpu/bert/attention.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +using namespace onnxruntime::cuda; + +template +class QAttention; + +template +class QAttention final : public CudaKernel, public AttentionBase { + using Base = CudaKernel; + + public: + QAttention(const OpKernelInfo& info) : CudaKernel(info), + AttentionBase(info) { + } + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + Status CheckInputs(const Tensor* input, + const Tensor* weights, + const Tensor* bias, + const Tensor* input_scale_tensor, + const Tensor* weight_scale_tensor, + const Tensor* mask_index, + const Tensor* i_zp_tensor, + const Tensor* w_zp_tensor) const; +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu new file mode 100644 index 0000000000..42791ae795 --- /dev/null +++ 
b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cu @@ -0,0 +1,52 @@ +// Modifications: scaling is moved from masked softmax to the gemm before that. +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include +#include +#include +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/cuda_common.h" +#include "attention_quantization_impl.cuh" + +using namespace onnxruntime::cuda; +using namespace cub; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +template +__global__ void DequantizeLinearKernel(const int32_t* quantize, const T* bias, T* output, T scale, int bias_len, CUDA_LONG N) { + CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x; + +#pragma unroll + for (int i = 0; i < NumElementsPerThread; i++) { + if (id < N) { + output[id] = (static_cast(quantize[id]) * scale) + bias[id % bias_len]; + id += NumThreadsPerBlock; + } + } +} + +template +Status CudaDequantizeWithBias(const int32_t* quantize, const T* bias, T* output, T scale, int m, int n) { + int blocksPerGrid = static_cast(CeilDiv(m * n, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread)); + CUDA_LONG N = static_cast(m * n); + DequantizeLinearKernel<<>>( + quantize, + bias, + output, + scale, + n, + N); + return Status::OK(); +} + +template Status CudaDequantizeWithBias(const int32_t* quantize, const float* bias, float* output, float scale, int m, int n); +template Status CudaDequantizeWithBias(const int32_t* quantize, const half* bias, half* output, half scale, int m, int n); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh new file mode 100644 index 0000000000..dc0ba262fa --- /dev/null +++ 
b/onnxruntime/contrib_ops/cuda/quantization/attention_quantization_impl.cuh @@ -0,0 +1,15 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { +template +Status CudaDequantizeWithBias(const int32_t* quantize, const Tin* bias, Tin* output, Tin scale, int m, int n); + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/quantize_ops.cc b/onnxruntime/contrib_ops/cuda/quantization/quantize_dequantize_linear.cc similarity index 100% rename from onnxruntime/contrib_ops/cuda/quantize_ops.cc rename to onnxruntime/contrib_ops/cuda/quantization/quantize_dequantize_linear.cc diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index aaa2258fe6..6be24b9ae9 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -314,6 +314,79 @@ mask_index shall not be provided.)DOC"; .TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types") .TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput); + ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention) + .SetDomain(kMSDomain) + .SinceVersion(1) + .SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL) + .SetDoc("Quantization of Multi-Head Self Attention.") + .Attr("num_heads", "Number of attention heads", AttributeProto::INT) + .Input( + 0, + "input", + "3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size", + "T1") + .Input( + 1, + "weight", + "2D input tensor with shape (hidden_size, 3 * hidden_size)", + "T2") + .Input( + 2, + "bias", + "1D input tensor with shape (3 * hidden_size)", + "T3") + .Input( + 3, + "input_scale", + "scale of quantized input tensor. 
It's a scalar, which means a per-tensor/layer quantization.", + "T3") + .Input( + 4, + "weight_scale", + "scale of weight scale. It's a scalar, which means a per-tensor/layer quantization.", + "T3") + .Input( + 5, + "mask_index", + "Attention mask index with shape (batch_size)", + "T4", + OpSchema::Optional) + .Input( + 6, + "input_zero_point", + "zero point of quantized input tensor.It's a scalar, which means a per-tensor/layer quantization.", + "T1", + OpSchema::Optional) + .Input( + 7, + "weight_zero_point", + "zero point of quantized weight tensor. It's a scalar, which means a per-tensor/layer quantization.", + "T2", + OpSchema::Optional) + .Output( + 0, + "output", + "3D output tensor with shape (batch_size, sequence_length, hidden_size)", + "T3") + .TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") + .TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.") + .TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.") + .TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + // Type inference + ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0); + + // Shape inference + // if the input shape doesn't exist, further shape inference is not possible + if (!hasNInputShapes(ctx, 1)) { + return; + } + + ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0); + + return; + }); + static const char* EmbedLayerNormalization_ver1_doc = R"DOC( EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing. 
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding, diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index b513110ebe..b5087e01e5 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -161,7 +161,6 @@ class CudaKernel : public OpKernel { const CudaKernel* op_kernel_; }; - protected: inline cublasHandle_t CublasHandle() const { return provider_->PerThreadCublasHandle(); } @@ -169,6 +168,8 @@ class CudaKernel : public OpKernel { inline cudnnHandle_t CudnnHandle() const { return provider_->PerThreadCudnnHandle(); } + + protected: inline curandGenerator_t CurandGenerator() const { return provider_->PerThreadCurandGenerator(); } diff --git a/onnxruntime/core/providers/cuda/integer_gemm.cc b/onnxruntime/core/providers/cuda/integer_gemm.cc new file mode 100644 index 0000000000..3bc3b3962e --- /dev/null +++ b/onnxruntime/core/providers/cuda/integer_gemm.cc @@ -0,0 +1,58 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "core/providers/cuda/shared_inc/integer_gemm.h" + +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/cuda_call.h" + +namespace onnxruntime { +namespace cuda { + +inline int roundoff(int v, int d) { + return (v + d - 1) / d * d; +} + +Status GemmInt8(int m, int n, int k, + int32_t alpha, int32_t beta, + const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc, + const CudaKernel* cuda_kernel) { + ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null"); + ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null"); + + // pad A and B to make their leading dimension be multiples of 32 + // because cublasGemmEx requires: + // 1. leading dimension is multiples of 4 + // 2. 
A, B is 32-bit aligned + + const int mask = 0x1F; + int lda_aligned = lda; + IAllocatorUniquePtr a_padded; + if ((mask & lda_aligned) != 0) { + lda_aligned = roundoff(lda, 32); + a_padded = cuda_kernel->GetScratchBuffer(m * lda_aligned); + cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, 0); + } + + int ldb_aligned = ldb; + IAllocatorUniquePtr b_padded; + if ((mask & ldb_aligned) != 0) { + ldb_aligned = roundoff(ldb, 32); + b_padded = cuda_kernel->GetScratchBuffer(k * ldb_aligned); + cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, 0); + } + + CUBLAS_RETURN_IF_ERROR(cublasGemmEx( + cuda_kernel->CublasHandle(), + CUBLAS_OP_N, CUBLAS_OP_N, + n, m, k, + &alpha, + ldb_aligned == ldb ? b : b_padded.get(), CUDA_R_8I, ldb_aligned, + lda_aligned == lda ? a : a_padded.get(), CUDA_R_8I, lda_aligned, + &beta, + c, CUDA_R_32I, ldc, CUDA_R_32I, + CUBLAS_GEMM_DFALT)); + return Status::OK(); +} +} // namespace cuda +} // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cc b/onnxruntime/core/providers/cuda/math/matmul_integer.cc index e4547da039..be0cb953af 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cc +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cc @@ -5,6 +5,7 @@ #include "matmul_integer.cuh" #include "core/providers/cpu/math/matmul_helper.h" #include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cuda/shared_inc/integer_gemm.h" #include "core/providers/cuda/cuda_allocator.h" #include "core/providers/common.h" @@ -25,26 +26,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX( .TypeConstraint("T3", DataTypeImpl::GetTensorType()), MatMulInteger); -template <> -Status MatMulInteger::PadMatrix( - int64_t row, - int64_t col, - int64_t align_size, - const int8_t*& src, - int64_t& pad_size, - IAllocatorUniquePtr& temp_mem_holder) const { - pad_size = align_size - col % align_size; - if (pad_size != align_size) { - temp_mem_holder = 
GetScratchBuffer(row * (col + pad_size)); - ORT_RETURN_IF_ERROR(PadMatrixInLeadingDimension(src, temp_mem_holder.get(), row, col, pad_size)); - src = temp_mem_holder.get(); - } else { - pad_size = 0; - } - - return Status::OK(); -} - template <> Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) const { auto a = ctx->Input(0); @@ -106,49 +87,19 @@ Status MatMulInteger::ComputeInternal(OpKernelContext* ctx) cons beta = 1; } - // pad A and B to make their leading dimension be multiples of 32 - // because cublasGemmEx requires: - // 1. leading dimension is multiples of 4 - // 2. A, B is 32-bit aligned - const int64_t align_size = 32; - int64_t a_pad_size = 0; - int64_t b_pad_size = 0; - IAllocatorUniquePtr a_padded; - IAllocatorUniquePtr b_padded; - ORT_RETURN_IF_ERROR(PadMatrix(a->Shape().Size() / helper.K(), - helper.K(), - align_size, - a_ptr, - a_pad_size, - a_padded)); - ORT_RETURN_IF_ERROR(PadMatrix(b->Shape().Size() / helper.N(), - helper.N(), - align_size, - b_ptr, - b_pad_size, - b_padded)); - for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) { - CUBLAS_RETURN_IF_ERROR(cublasGemmEx( - Base::CublasHandle(), - CUBLAS_OP_N, - CUBLAS_OP_N, - static_cast(helper.N()), - static_cast(helper.M()), - static_cast(helper.K()), - &alpha, - b_ptr + helper.RightOffsets()[batch] + helper.RightOffsets()[batch] / helper.N() * b_pad_size, - CUDA_R_8I, - static_cast(helper.N() + b_pad_size), - a_ptr + helper.LeftOffsets()[batch] + helper.LeftOffsets()[batch] / helper.K() * a_pad_size, - CUDA_R_8I, - static_cast(helper.K() + a_pad_size), - &beta, - output_ptr + helper.OutputOffsets()[batch], - CUDA_R_32I, - static_cast(helper.N()), - CUDA_R_32I, - CUBLAS_GEMM_DFALT)); + GemmInt8(static_cast(helper.M()), + static_cast(helper.N()), + static_cast(helper.K()), + alpha, + beta, + a_ptr + helper.LeftOffsets()[batch], + static_cast(helper.K()), + b_ptr + helper.RightOffsets()[batch], + static_cast(helper.N()), + output_ptr + helper.OutputOffsets()[batch], + 
static_cast(helper.N()), + this); } return Status::OK(); diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cu b/onnxruntime/core/providers/cuda/math/matmul_integer.cu index 7f7bf1c41e..267cf198c9 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cu +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cu @@ -28,11 +28,10 @@ __global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_s Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t offset, const MatMulComputeHelper& helper) { for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) { - ReduceRowSumOnMatrixAKernel(GridDim::maxThreadsPerBlock)> - <<(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch], - row_sum + batch * helper.M(), - offset, - static_cast(helper.K())); + ReduceRowSumOnMatrixAKernel(GridDim::maxThreadsPerBlock)><<(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch], + row_sum + batch * helper.M(), + offset, + static_cast(helper.K())); } return CUDA_CALL(cudaPeekAtLastError()) ? 
Status::OK() : Status(common::ONNXRUNTIME, common::FAIL); @@ -57,12 +56,11 @@ __global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_s Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t offset, const MatMulComputeHelper& helper) { for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) { - ReduceColSumOnMatrixBKernel(GridDim::maxThreadsPerBlock)> - <<(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch], - col_sum + batch * helper.N(), - offset, - static_cast(helper.K()), - static_cast(helper.N())); + ReduceColSumOnMatrixBKernel(GridDim::maxThreadsPerBlock)><<(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch], + col_sum + batch * helper.N(), + offset, + static_cast(helper.K()), + static_cast(helper.N())); } return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL); @@ -128,21 +126,5 @@ Status OffsetOutput(const int32_t* row_sum, return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL); } -__global__ void PadMatrixInLeadingDimensionKernel(const int8_t* src, int8_t* dst, int col_src, int col_dst) { - for (int32_t i = threadIdx.x; i < col_src; i += blockDim.x) { - *(dst + blockIdx.x * col_dst + i) = *(src + blockIdx.x * col_src + i); - } -} - -Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size) { - PadMatrixInLeadingDimensionKernel<<(row), GridDim::maxThreadsPerBlock, 0>>>( - src, - dst, - static_cast(col), - static_cast(col + pad_size)); - - return CUDA_CALL(cudaPeekAtLastError()) ? 
Status::OK() : Status(common::ONNXRUNTIME, common::FAIL); - ; -} } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.cuh b/onnxruntime/core/providers/cuda/math/matmul_integer.cuh index 7be00db358..e22bbf4d24 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.cuh +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.cuh @@ -20,7 +20,5 @@ Status OffsetOutput(const int32_t* row_sum, const int8_t b_offset, const MatMulComputeHelper& helper); -Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size); - } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/math/matmul_integer.h b/onnxruntime/core/providers/cuda/math/matmul_integer.h index a8c7e8b708..89eaebb0de 100644 --- a/onnxruntime/core/providers/cuda/math/matmul_integer.h +++ b/onnxruntime/core/providers/cuda/math/matmul_integer.h @@ -25,16 +25,6 @@ class MatMulInteger final : public CudaKernel { Status ComputeInternal(OpKernelContext* context) const override; - private: - // pad matrix and B to make their leading dimension be multiples of *align_size* - Status PadMatrix( - int64_t row, - int64_t col, - int64_t align_size, - const int8_t*& src, - int64_t& pad_size, - IAllocatorUniquePtr& temp_mem_holder) const; - private: bool has_a_zero_point_; bool has_b_zero_point_; diff --git a/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h b/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h new file mode 100644 index 0000000000..4953d9e50d --- /dev/null +++ b/onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +#pragma once + +#include "core/providers/cuda/cuda_common.h" + +namespace onnxruntime { +namespace cuda { +Status GemmInt8(int m, + int n, + int k, + int32_t alpha_matmul, + int32_t beta_matmul, + const int8_t* a, + int lda, + const int8_t* b, + int ldb, + int32_t* c, + int ldc, + const CudaKernel* cuda_kernel); +} +} \ No newline at end of file diff --git a/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc new file mode 100644 index 0000000000..be37ca81cd --- /dev/null +++ b/onnxruntime/test/contrib_ops/quantize_attention_op_test.cc @@ -0,0 +1,252 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include "gtest/gtest.h" +#include "test/common/tensor_op_test_utils.h" +#include "test/common/cuda_op_test_utils.h" +#include "test/providers/provider_test_utils.h" + +namespace onnxruntime { +namespace test { + +template ::value, Integer>::type> +inline std::vector ToInteger(const std::vector& data, float scale) { + std::vector result; + result.reserve(data.size()); + for (size_t i = 0; i < data.size(); i++) { + result.push_back(static_cast(std::round(data[i] / scale))); + } + return result; +} + +static void RunAttentionTest( + const std::vector& input_data, // input: [batch_size, sequence_length, hidden_size] + const std::vector& weights_data, // weights: [hidden_size, 3 * hidden_size] + const std::vector& bias_data, // bias: [3 * hidden_size] + const std::vector& mask_index_data, // mask_index: [batch_size] + const std::vector& output_data, // output: [batch_size, sequence_length, hidden_size] + int batch_size, + int sequence_length, + int hidden_size, + int number_of_heads, + bool use_float16 = false) { + int min_cuda_architecture = 530; + + bool enable_cuda = HasCudaEnvironment(min_cuda_architecture); + + if (enable_cuda) { + OpTester tester("QAttention", 1, onnxruntime::kMSDomain); + tester.AddAttribute("num_heads", 
static_cast(number_of_heads)); + + std::vector input_dims = {batch_size, sequence_length, hidden_size}; + std::vector weights_dims = {hidden_size, 3 * hidden_size}; + std::vector bias_dims = {3 * hidden_size}; + std::vector mask_index_dims = {batch_size}; + std::vector output_dims = input_dims; + + float input_scale = 0.1f; + float weight_scale = 0.1f; + tester.AddInput("input", input_dims, ToInteger(input_data, input_scale)); + tester.AddInput("weight", weights_dims, ToInteger(weights_data, weight_scale)); + if (use_float16) { + tester.AddInput("bias", bias_dims, ToFloat16(bias_data)); + tester.AddInput("input_scale", {1}, ToFloat16({input_scale})); + tester.AddInput("weight_scale", {1}, ToFloat16({weight_scale})); + tester.AddOutput("output", output_dims, ToFloat16(output_data)); + } else { + tester.AddInput("bias", bias_dims, bias_data); + tester.AddInput("input_scale", {1}, {input_scale}); + tester.AddInput("weight_scale", {1}, {weight_scale}); + tester.AddOutput("output", output_dims, output_data); + } + + if (mask_index_data.size() > 0) { // mask index is optional. 
+ tester.AddInput("mask_index", mask_index_dims, mask_index_data); + } + + tester.Run(); + } +} + +TEST(QAttentionTest, AttentionBatch1) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + std::vector mask_index_data = {2L}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} + +TEST(QAttentionTest, AttentionBatch1_Float16) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + std::vector mask_index_data = {2L}; + + std::vector output_data = { + 3.15039f, 0.1082763671875f, 4.24609375f, 5.6484375f, + 3.96679f, 0.072998046875f, 4.24609f, 5.6484375f}; 
+ + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads, true); +} + +TEST(QAttentionTest, AttentionBatch2) { + int batch_size = 2; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f, + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + std::vector mask_index_data = {2L, 2L}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f, + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} + +TEST(QAttentionTest, AttentionMaskPartialSequence) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; 
+ + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask_index < sequence_length + std::vector mask_index_data = {1L}; + + std::vector output_data = { + 8.6899995803833008f, -0.13000002503395081f, 4.25f, 5.6499996185302734f, + 8.6899995803833008f, -0.13000002503395081f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} + +TEST(QAttentionTest, AttentionMaskExceedSequence) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // Test mask_index > sequence_length + std::vector mask_index_data = {3L}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} + +TEST(QAttentionTest, AttentionNoMaskIndex) { + int batch_size = 1; + int sequence_length = 2; + int hidden_size = 4; + int number_of_heads = 2; + + std::vector input_data = { + 0.8f, -0.5f, 0.0f, 1.f, + 0.5f, 0.2f, 0.3f, -0.6f}; + + std::vector weight_data = { + 0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f, + 0.5f, 0.1f, 0.4f, 1.6f, 
1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f, + 0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f, + 0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f}; + + std::vector bias_data = { + -0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f}; + + // No mask_index + std::vector mask_index_data = {}; + + std::vector output_data = { + 3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f, + 3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f}; + + RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data, + batch_size, sequence_length, hidden_size, number_of_heads); +} +} // namespace test +} // namespace onnxruntime