mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-24 02:47:54 +00:00
Quantize attention with Cuda (#3693)
* Add definition of QAttention * implemention of QAttention on GPU
This commit is contained in:
parent
49f0610447
commit
156368b67f
18 changed files with 753 additions and 122 deletions
|
|
@ -28,18 +28,19 @@ AttentionBase::AttentionBase(const OpKernelInfo& info) {
|
|||
int64_t num_heads = 0;
|
||||
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
|
||||
num_heads_ = static_cast<int>(num_heads);
|
||||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 1) == 1;
|
||||
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
|
||||
}
|
||||
|
||||
Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
|
||||
Status AttentionBase::CheckInputs(const Tensor* input,
|
||||
const Tensor* weights,
|
||||
const Tensor* bias,
|
||||
const Tensor* mask_index) const {
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - mask_index : (batch_size) if presented
|
||||
// Output : (batch_size, sequence_length, hidden_size)
|
||||
// input : (batch_size, sequence_length, hidden_size)
|
||||
// weights : (hidden_size, 3 * hidden_size)
|
||||
// bias : (3 * hidden_size)
|
||||
// mask_index : (batch_size) if presented
|
||||
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const auto dims = input->Shape().GetDims();
|
||||
if (dims.size() != 3) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 0 is expected to have 3 dimensions, got ",
|
||||
|
|
@ -52,7 +53,6 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
|
|||
"Input 0 dimension 2 should be divisiable by value of the num_heads attribute.");
|
||||
}
|
||||
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const auto weights_dims = weights->Shape().GetDims();
|
||||
if (weights_dims.size() != 2) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 is expected to have 2 dimensions, got ",
|
||||
|
|
@ -66,7 +66,6 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
|
|||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 dimension 1 should be 3 times of dimension 0");
|
||||
}
|
||||
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const auto bias_dims = bias->Shape().GetDims();
|
||||
if (bias_dims.size() != 1) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 2 is expected to have 1 dimension, got ",
|
||||
|
|
@ -77,8 +76,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
|
|||
"Input 2 dimension 0 should have same length as dimension 1 of input 1");
|
||||
}
|
||||
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
if (mask_index != nullptr) {
|
||||
if (mask_index != nullptr) { // mask_index is optional
|
||||
// unidirectional (like GPT2) does not need mask input. Here we do not allowed the input for unidirectional.
|
||||
if (is_unidirectional_) {
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 (mask_index) is not allowed for unidirectional");
|
||||
|
|
@ -103,13 +101,11 @@ Attention<T>::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionBas
|
|||
|
||||
template <typename T>
|
||||
Status Attention<T>::Compute(OpKernelContext* context) const {
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(context));
|
||||
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, mask_index));
|
||||
|
||||
const auto dims = input->Shape().GetDims();
|
||||
const int batch_size = static_cast<int>(dims[0]);
|
||||
|
|
@ -125,6 +121,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
AllocatorPtr allocator;
|
||||
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
|
||||
|
||||
auto* tp = context->GetOperatorThreadPool();
|
||||
// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
|
||||
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
||||
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
|
||||
|
|
@ -180,7 +177,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
|
|||
qkv_dest + qkv_offset, // C
|
||||
head_size, // ldc
|
||||
nullptr // use single-thread
|
||||
);
|
||||
);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
|
|
|||
|
|
@ -12,7 +12,10 @@ namespace contrib {
|
|||
class AttentionBase {
|
||||
protected:
|
||||
AttentionBase(const OpKernelInfo& info);
|
||||
Status CheckInputs(const OpKernelContext* context) const;
|
||||
Status CheckInputs(const Tensor* input,
|
||||
const Tensor* weights,
|
||||
const Tensor* bias,
|
||||
const Tensor* mask_index) const;
|
||||
|
||||
int num_heads_; // number of attention heads
|
||||
bool is_unidirectional_; // whether every token can only attend to previous tokens.
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@ Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB
|
|||
|
||||
template <typename T>
|
||||
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(context));
|
||||
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
|
|
@ -46,6 +44,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const Tensor* mask_index = context->Input<Tensor>(3);
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, mask_index));
|
||||
|
||||
const auto dims = input->Shape().GetDims();
|
||||
int batch_size = static_cast<int>(dims[0]);
|
||||
|
|
|
|||
|
|
@ -68,6 +68,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
|
|||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
|
||||
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
|
||||
|
||||
Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
||||
static const BuildKernelCreateInfoFn function_table[] = {
|
||||
|
|
@ -130,7 +132,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
|
|||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>};
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
|
||||
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>};
|
||||
|
||||
for (auto& function_table_entry : function_table) {
|
||||
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));
|
||||
|
|
|
|||
|
|
@ -0,0 +1,189 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "attention_quantization.h"
|
||||
#include "attention_quantization_impl.cuh"
|
||||
#include "contrib_ops/cuda/bert/attention_impl.h"
|
||||
#include "core/framework/tensorprotoutils.h"
|
||||
#include "core/providers/common.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cuda/shared_inc/integer_gemm.h"
|
||||
#include "core/providers/cuda/tensor/quantize_linear.h"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace onnxruntime::common;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
#define REGISTER_KERNEL_TYPED(T, TQuant) \
|
||||
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
||||
QAttention, \
|
||||
kMSDomain, \
|
||||
1, \
|
||||
T##_##TQuant, \
|
||||
kCudaExecutionProvider, \
|
||||
KernelDefBuilder() \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(3) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(4) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(6) \
|
||||
.InputMemoryType<OrtMemTypeCPUInput>(7) \
|
||||
.TypeConstraint("T1", DataTypeImpl::GetTensorType<TQuant>()) \
|
||||
.TypeConstraint("T2", DataTypeImpl::GetTensorType<TQuant>()) \
|
||||
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T>()) \
|
||||
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()), \
|
||||
QAttention<T, TQuant>);
|
||||
|
||||
REGISTER_KERNEL_TYPED(float, int8_t)
|
||||
REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
|
||||
|
||||
template <typename T>
|
||||
Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
|
||||
const Tensor* weights,
|
||||
const Tensor* bias,
|
||||
const Tensor* input_scale_tensor,
|
||||
const Tensor* weight_scale_tensor,
|
||||
const Tensor* mask_index,
|
||||
const Tensor* i_zp_tensor,
|
||||
const Tensor* w_zp_tensor) const {
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - input_scale : scalar
|
||||
// Input 4 - weight_scale : scalar
|
||||
// Input 5 - mask_index : (batch_size)
|
||||
// Input 6 - input_zero_point : scalar
|
||||
// Input 7 - weight_zero_point : scalar
|
||||
// Output : (batch_size, sequence_length, hidden_size)
|
||||
|
||||
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input, weights, bias, mask_index));
|
||||
|
||||
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
|
||||
"input scale must be a scalar or 1D tensor of size 1");
|
||||
|
||||
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(weight_scale_tensor),
|
||||
"weight must be a scalar or 1D tensor of size 1");
|
||||
|
||||
if (i_zp_tensor != nullptr) {
|
||||
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(i_zp_tensor),
|
||||
"input zero point must be a scalar or 1D tensor of size 1.");
|
||||
if (0 != *(i_zp_tensor->template Data<int8_t>()))
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
|
||||
}
|
||||
|
||||
if (w_zp_tensor != nullptr) {
|
||||
// CUDA only support symmetric quantization for Attention
|
||||
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(w_zp_tensor),
|
||||
"weight zero point must be a scalar or 1D tensor of size 1.");
|
||||
if (0 != *(w_zp_tensor->template Data<int8_t>()))
|
||||
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
||||
// Input and output shapes:
|
||||
// Input 0 - input : (batch_size, sequence_length, hidden_size)
|
||||
// Input 1 - weights : (hidden_size, 3 * hidden_size)
|
||||
// Input 2 - bias : (3 * hidden_size)
|
||||
// Input 3 - input_scale : scalar
|
||||
// Input 4 - weight_scale : scalar
|
||||
// Input 5 - mask_index : (batch_size)
|
||||
// Input 6 - input_zero_point : scalar
|
||||
// Input 7 - weight_zero_point : scalar
|
||||
// Output : (batch_size, sequence_length, hidden_size)
|
||||
// ORT_RETURN_IF_ERROR(CheckInputs(context));
|
||||
const Tensor* input = context->Input<Tensor>(0);
|
||||
const Tensor* weights = context->Input<Tensor>(1);
|
||||
const Tensor* bias = context->Input<Tensor>(2);
|
||||
const Tensor* input_scale_tensor = context->Input<Tensor>(3);
|
||||
const Tensor* weight_scale_tensor = context->Input<Tensor>(4);
|
||||
const Tensor* mask_index = context->Input<Tensor>(5);
|
||||
const Tensor* i_zp_tensor = context->Input<Tensor>(6);
|
||||
const Tensor* w_zp_tensor = context->Input<Tensor>(7);
|
||||
|
||||
ORT_RETURN_IF_ERROR(CheckInputs(input,
|
||||
weights,
|
||||
bias,
|
||||
input_scale_tensor,
|
||||
weight_scale_tensor,
|
||||
mask_index,
|
||||
i_zp_tensor,
|
||||
w_zp_tensor));
|
||||
|
||||
const auto dims = input->Shape().GetDims();
|
||||
/*int input_size = static_cast<int>(input->Shape().Size());*/
|
||||
int batch_size = static_cast<int>(dims[0]);
|
||||
int sequence_length = static_cast<int>(dims[1]);
|
||||
int hidden_size = static_cast<int>(dims[2]);
|
||||
int head_size = hidden_size / num_heads_;
|
||||
|
||||
TensorShape output_shape(dims);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
cublasHandle_t cublas = CublasHandle();
|
||||
const size_t element_size = sizeof(T);
|
||||
|
||||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int k = hidden_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
|
||||
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(batch_size * sequence_length * 3 * hidden_size);
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
||||
GemmInt8(m, n, k,
|
||||
1 /*alpha_matmul*/, 0 /* beta_matmul*/,
|
||||
input->template Data<int8_t>(), k,
|
||||
weights->template Data<int8_t>(), n,
|
||||
gemm_buffer_quantized.get(), n,
|
||||
this);
|
||||
|
||||
CudaT dequant_scale;
|
||||
CudaT input_scale = *(reinterpret_cast<const CudaT*>(input_scale_tensor->template Data<T>()));
|
||||
CudaT weight_scale = *(reinterpret_cast<const CudaT*>(weight_scale_tensor->template Data<T>()));
|
||||
if (sizeof(T) == 2) {
|
||||
dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale));
|
||||
} else {
|
||||
dequant_scale = input_scale * weight_scale;
|
||||
}
|
||||
// scale back and bias
|
||||
CudaDequantizeWithBias(
|
||||
gemm_buffer_quantized.get(),
|
||||
reinterpret_cast<const CudaT*>(bias->template Data<T>()),
|
||||
reinterpret_cast<CudaT*>(gemm_buffer.get()),
|
||||
dequant_scale,
|
||||
m,
|
||||
n);
|
||||
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length);
|
||||
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
|
||||
if (!LaunchAttentionKernel(
|
||||
reinterpret_cast<const CudaT*>(gemm_buffer.get()),
|
||||
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
|
||||
output->template MutableData<T>(),
|
||||
batch_size,
|
||||
sequence_length,
|
||||
num_heads_,
|
||||
head_size,
|
||||
temp_buffer.get(),
|
||||
cublas,
|
||||
element_size,
|
||||
false)) {
|
||||
// Get last error to reset it to cudaSuccess.
|
||||
CUDA_CALL(cudaGetLastError());
|
||||
return Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/common/common.h"
|
||||
#include "core/framework/op_kernel.h"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "contrib_ops/cpu/bert/attention.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
|
||||
template <typename T, typename TQuant>
|
||||
class QAttention;
|
||||
|
||||
template <typename T>
|
||||
class QAttention<T, int8_t> final : public CudaKernel, public AttentionBase {
|
||||
using Base = CudaKernel;
|
||||
|
||||
public:
|
||||
QAttention(const OpKernelInfo& info) : CudaKernel(info),
|
||||
AttentionBase(info) {
|
||||
}
|
||||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
Status CheckInputs(const Tensor* input,
|
||||
const Tensor* weights,
|
||||
const Tensor* bias,
|
||||
const Tensor* input_scale_tensor,
|
||||
const Tensor* weight_scale_tensor,
|
||||
const Tensor* mask_index,
|
||||
const Tensor* i_zp_tensor,
|
||||
const Tensor* w_zp_tensor) const;
|
||||
};
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
// Modifications: scaling is moved from masked softmax to the gemm before that.
|
||||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include <cub/cub.cuh>
|
||||
#include <cublas_v2.h>
|
||||
#include <cuda_fp16.h>
|
||||
#include <math_constants.h>
|
||||
#include "core/providers/cuda/cu_inc/common.cuh"
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "attention_quantization_impl.cuh"
|
||||
|
||||
using namespace onnxruntime::cuda;
|
||||
using namespace cub;
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
|
||||
template <class T, int NumThreadsPerBlock, int NumElementsPerThread>
|
||||
__global__ void DequantizeLinearKernel(const int32_t* quantize, const T* bias, T* output, T scale, int bias_len, CUDA_LONG N) {
|
||||
CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
|
||||
|
||||
#pragma unroll
|
||||
for (int i = 0; i < NumElementsPerThread; i++) {
|
||||
if (id < N) {
|
||||
output[id] = (static_cast<T>(quantize[id]) * scale) + bias[id % bias_len];
|
||||
id += NumThreadsPerBlock;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template <class T>
|
||||
Status CudaDequantizeWithBias(const int32_t* quantize, const T* bias, T* output, T scale, int m, int n) {
|
||||
int blocksPerGrid = static_cast<int>(CeilDiv(m * n, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
|
||||
CUDA_LONG N = static_cast<CUDA_LONG>(m * n);
|
||||
DequantizeLinearKernel<T, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
|
||||
quantize,
|
||||
bias,
|
||||
output,
|
||||
scale,
|
||||
n,
|
||||
N);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template Status CudaDequantizeWithBias<float>(const int32_t* quantize, const float* bias, float* output, float scale, int m, int n);
|
||||
template Status CudaDequantizeWithBias<half>(const int32_t* quantize, const half* bias, half* output, half scale, int m, int n);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
#include "core/providers/cuda/shared_inc/cuda_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace contrib {
|
||||
namespace cuda {
|
||||
template <class Tin>
|
||||
Status CudaDequantizeWithBias(const int32_t* quantize, const Tin* bias, Tin* output, Tin scale, int m, int n);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -314,6 +314,79 @@ mask_index shall not be provided.)DOC";
|
|||
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
|
||||
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
|
||||
|
||||
ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention)
|
||||
.SetDomain(kMSDomain)
|
||||
.SinceVersion(1)
|
||||
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
|
||||
.SetDoc("Quantization of Multi-Head Self Attention.")
|
||||
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
|
||||
.Input(
|
||||
0,
|
||||
"input",
|
||||
"3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size",
|
||||
"T1")
|
||||
.Input(
|
||||
1,
|
||||
"weight",
|
||||
"2D input tensor with shape (hidden_size, 3 * hidden_size)",
|
||||
"T2")
|
||||
.Input(
|
||||
2,
|
||||
"bias",
|
||||
"1D input tensor with shape (3 * hidden_size)",
|
||||
"T3")
|
||||
.Input(
|
||||
3,
|
||||
"input_scale",
|
||||
"scale of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.",
|
||||
"T3")
|
||||
.Input(
|
||||
4,
|
||||
"weight_scale",
|
||||
"scale of weight scale. It's a scalar, which means a per-tensor/layer quantization.",
|
||||
"T3")
|
||||
.Input(
|
||||
5,
|
||||
"mask_index",
|
||||
"Attention mask index with shape (batch_size)",
|
||||
"T4",
|
||||
OpSchema::Optional)
|
||||
.Input(
|
||||
6,
|
||||
"input_zero_point",
|
||||
"zero point of quantized input tensor.It's a scalar, which means a per-tensor/layer quantization.",
|
||||
"T1",
|
||||
OpSchema::Optional)
|
||||
.Input(
|
||||
7,
|
||||
"weight_zero_point",
|
||||
"zero point of quantized weight tensor. It's a scalar, which means a per-tensor/layer quantization.",
|
||||
"T2",
|
||||
OpSchema::Optional)
|
||||
.Output(
|
||||
0,
|
||||
"output",
|
||||
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
|
||||
"T3")
|
||||
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
|
||||
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
|
||||
.TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
|
||||
.TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
|
||||
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
|
||||
// Type inference
|
||||
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
|
||||
|
||||
// Shape inference
|
||||
// if the input shape doesn't exist, further shape inference is not possible
|
||||
if (!hasNInputShapes(ctx, 1)) {
|
||||
return;
|
||||
}
|
||||
|
||||
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
|
||||
|
||||
return;
|
||||
});
|
||||
|
||||
static const char* EmbedLayerNormalization_ver1_doc = R"DOC(
|
||||
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
|
||||
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,
|
||||
|
|
|
|||
|
|
@ -161,7 +161,6 @@ class CudaKernel : public OpKernel {
|
|||
const CudaKernel* op_kernel_;
|
||||
};
|
||||
|
||||
protected:
|
||||
inline cublasHandle_t CublasHandle() const {
|
||||
return provider_->PerThreadCublasHandle();
|
||||
}
|
||||
|
|
@ -169,6 +168,8 @@ class CudaKernel : public OpKernel {
|
|||
inline cudnnHandle_t CudnnHandle() const {
|
||||
return provider_->PerThreadCudnnHandle();
|
||||
}
|
||||
|
||||
protected:
|
||||
inline curandGenerator_t CurandGenerator() const {
|
||||
return provider_->PerThreadCurandGenerator();
|
||||
}
|
||||
|
|
|
|||
58
onnxruntime/core/providers/cuda/integer_gemm.cc
Normal file
58
onnxruntime/core/providers/cuda/integer_gemm.cc
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "core/providers/cuda/shared_inc/integer_gemm.h"
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
#include "core/providers/cuda/shared_inc/cuda_call.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
|
||||
inline int roundoff(int v, int d) {
|
||||
return (v + d - 1) / d * d;
|
||||
}
|
||||
|
||||
Status GemmInt8(int m, int n, int k,
|
||||
int32_t alpha, int32_t beta,
|
||||
const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc,
|
||||
const CudaKernel* cuda_kernel) {
|
||||
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
|
||||
ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null");
|
||||
|
||||
// pad A and B to make their leading dimension be multiples of 32
|
||||
// because cublasGemmEx requires:
|
||||
// 1. leading dimension is multiples of 4
|
||||
// 2. A, B is 32-bit aligned
|
||||
|
||||
const int mask = 0x1F;
|
||||
int lda_aligned = lda;
|
||||
IAllocatorUniquePtr<int8_t> a_padded;
|
||||
if ((mask & lda_aligned) != 0) {
|
||||
lda_aligned = roundoff(lda, 32);
|
||||
a_padded = cuda_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
|
||||
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, 0);
|
||||
}
|
||||
|
||||
int ldb_aligned = ldb;
|
||||
IAllocatorUniquePtr<int8_t> b_padded;
|
||||
if ((mask & ldb_aligned) != 0) {
|
||||
ldb_aligned = roundoff(ldb, 32);
|
||||
b_padded = cuda_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
|
||||
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, 0);
|
||||
}
|
||||
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
|
||||
cuda_kernel->CublasHandle(),
|
||||
CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
n, m, k,
|
||||
&alpha,
|
||||
ldb_aligned == ldb ? b : b_padded.get(), CUDA_R_8I, ldb_aligned,
|
||||
lda_aligned == lda ? a : a_padded.get(), CUDA_R_8I, lda_aligned,
|
||||
&beta,
|
||||
c, CUDA_R_32I, ldc, CUDA_R_32I,
|
||||
CUBLAS_GEMM_DFALT));
|
||||
return Status::OK();
|
||||
}
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
@ -5,6 +5,7 @@
|
|||
#include "matmul_integer.cuh"
|
||||
#include "core/providers/cpu/math/matmul_helper.h"
|
||||
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
||||
#include "core/providers/cuda/shared_inc/integer_gemm.h"
|
||||
#include "core/providers/cuda/cuda_allocator.h"
|
||||
#include "core/providers/common.h"
|
||||
|
||||
|
|
@ -25,26 +26,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
|
|||
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
|
||||
MatMulInteger<int8_t, int8_t>);
|
||||
|
||||
template <>
|
||||
Status MatMulInteger<int8_t, int8_t>::PadMatrix(
|
||||
int64_t row,
|
||||
int64_t col,
|
||||
int64_t align_size,
|
||||
const int8_t*& src,
|
||||
int64_t& pad_size,
|
||||
IAllocatorUniquePtr<int8_t>& temp_mem_holder) const {
|
||||
pad_size = align_size - col % align_size;
|
||||
if (pad_size != align_size) {
|
||||
temp_mem_holder = GetScratchBuffer<int8_t>(row * (col + pad_size));
|
||||
ORT_RETURN_IF_ERROR(PadMatrixInLeadingDimension(src, temp_mem_holder.get(), row, col, pad_size));
|
||||
src = temp_mem_holder.get();
|
||||
} else {
|
||||
pad_size = 0;
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
template <>
|
||||
Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) const {
|
||||
auto a = ctx->Input<Tensor>(0);
|
||||
|
|
@ -106,49 +87,19 @@ Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) cons
|
|||
beta = 1;
|
||||
}
|
||||
|
||||
// pad A and B to make their leading dimension be multiples of 32
|
||||
// because cublasGemmEx requires:
|
||||
// 1. leading dimension is multiples of 4
|
||||
// 2. A, B is 32-bit aligned
|
||||
const int64_t align_size = 32;
|
||||
int64_t a_pad_size = 0;
|
||||
int64_t b_pad_size = 0;
|
||||
IAllocatorUniquePtr<int8_t> a_padded;
|
||||
IAllocatorUniquePtr<int8_t> b_padded;
|
||||
ORT_RETURN_IF_ERROR(PadMatrix(a->Shape().Size() / helper.K(),
|
||||
helper.K(),
|
||||
align_size,
|
||||
a_ptr,
|
||||
a_pad_size,
|
||||
a_padded));
|
||||
ORT_RETURN_IF_ERROR(PadMatrix(b->Shape().Size() / helper.N(),
|
||||
helper.N(),
|
||||
align_size,
|
||||
b_ptr,
|
||||
b_pad_size,
|
||||
b_padded));
|
||||
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
|
||||
Base::CublasHandle(),
|
||||
CUBLAS_OP_N,
|
||||
CUBLAS_OP_N,
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.K()),
|
||||
&alpha,
|
||||
b_ptr + helper.RightOffsets()[batch] + helper.RightOffsets()[batch] / helper.N() * b_pad_size,
|
||||
CUDA_R_8I,
|
||||
static_cast<int>(helper.N() + b_pad_size),
|
||||
a_ptr + helper.LeftOffsets()[batch] + helper.LeftOffsets()[batch] / helper.K() * a_pad_size,
|
||||
CUDA_R_8I,
|
||||
static_cast<int>(helper.K() + a_pad_size),
|
||||
&beta,
|
||||
output_ptr + helper.OutputOffsets()[batch],
|
||||
CUDA_R_32I,
|
||||
static_cast<int>(helper.N()),
|
||||
CUDA_R_32I,
|
||||
CUBLAS_GEMM_DFALT));
|
||||
GemmInt8(static_cast<int>(helper.M()),
|
||||
static_cast<int>(helper.N()),
|
||||
static_cast<int>(helper.K()),
|
||||
alpha,
|
||||
beta,
|
||||
a_ptr + helper.LeftOffsets()[batch],
|
||||
static_cast<int>(helper.K()),
|
||||
b_ptr + helper.RightOffsets()[batch],
|
||||
static_cast<int>(helper.N()),
|
||||
output_ptr + helper.OutputOffsets()[batch],
|
||||
static_cast<int>(helper.N()),
|
||||
this);
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
|
|
|
|||
|
|
@ -28,11 +28,10 @@ __global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_s
|
|||
|
||||
Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t offset, const MatMulComputeHelper& helper) {
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
ReduceRowSumOnMatrixAKernel<static_cast<int>(GridDim::maxThreadsPerBlock)>
|
||||
<<<static_cast<int>(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch],
|
||||
row_sum + batch * helper.M(),
|
||||
offset,
|
||||
static_cast<int>(helper.K()));
|
||||
ReduceRowSumOnMatrixAKernel<static_cast<int>(GridDim::maxThreadsPerBlock)><<<static_cast<int>(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch],
|
||||
row_sum + batch * helper.M(),
|
||||
offset,
|
||||
static_cast<int>(helper.K()));
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
|
|
@ -57,12 +56,11 @@ __global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_s
|
|||
|
||||
Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t offset, const MatMulComputeHelper& helper) {
|
||||
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
|
||||
ReduceColSumOnMatrixBKernel<static_cast<int>(GridDim::maxThreadsPerBlock)>
|
||||
<<<static_cast<int>(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch],
|
||||
col_sum + batch * helper.N(),
|
||||
offset,
|
||||
static_cast<int32_t>(helper.K()),
|
||||
static_cast<int32_t>(helper.N()));
|
||||
ReduceColSumOnMatrixBKernel<static_cast<int>(GridDim::maxThreadsPerBlock)><<<static_cast<int>(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch],
|
||||
col_sum + batch * helper.N(),
|
||||
offset,
|
||||
static_cast<int32_t>(helper.K()),
|
||||
static_cast<int32_t>(helper.N()));
|
||||
}
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
|
|
@ -128,21 +126,5 @@ Status OffsetOutput(const int32_t* row_sum,
|
|||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
}
|
||||
|
||||
__global__ void PadMatrixInLeadingDimensionKernel(const int8_t* src, int8_t* dst, int col_src, int col_dst) {
|
||||
for (int32_t i = threadIdx.x; i < col_src; i += blockDim.x) {
|
||||
*(dst + blockIdx.x * col_dst + i) = *(src + blockIdx.x * col_src + i);
|
||||
}
|
||||
}
|
||||
|
||||
Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size) {
|
||||
PadMatrixInLeadingDimensionKernel<<<static_cast<int>(row), GridDim::maxThreadsPerBlock, 0>>>(
|
||||
src,
|
||||
dst,
|
||||
static_cast<int>(col),
|
||||
static_cast<int>(col + pad_size));
|
||||
|
||||
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
|
||||
;
|
||||
}
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -20,7 +20,5 @@ Status OffsetOutput(const int32_t* row_sum,
|
|||
const int8_t b_offset,
|
||||
const MatMulComputeHelper& helper);
|
||||
|
||||
Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace onnxruntime
|
||||
|
|
|
|||
|
|
@ -25,16 +25,6 @@ class MatMulInteger final : public CudaKernel {
|
|||
|
||||
Status ComputeInternal(OpKernelContext* context) const override;
|
||||
|
||||
private:
|
||||
// pad matrix and B to make their leading dimension be multiples of *align_size*
|
||||
Status PadMatrix(
|
||||
int64_t row,
|
||||
int64_t col,
|
||||
int64_t align_size,
|
||||
const int8_t*& src,
|
||||
int64_t& pad_size,
|
||||
IAllocatorUniquePtr<int8_t>& temp_mem_holder) const;
|
||||
|
||||
private:
|
||||
bool has_a_zero_point_;
|
||||
bool has_b_zero_point_;
|
||||
|
|
|
|||
23
onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h
Normal file
23
onnxruntime/core/providers/cuda/shared_inc/integer_gemm.h
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "core/providers/cuda/cuda_common.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace cuda {
|
||||
Status GemmInt8(int m,
|
||||
int n,
|
||||
int k,
|
||||
int32_t alpha_matmul,
|
||||
int32_t beta_matmul,
|
||||
const int8_t* a,
|
||||
int lda,
|
||||
const int8_t* b,
|
||||
int ldb,
|
||||
int32_t* c,
|
||||
int ldc,
|
||||
const CudaKernel* cuda_kernel);
|
||||
}
|
||||
}
|
||||
252
onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
Normal file
252
onnxruntime/test/contrib_ops/quantize_attention_op_test.cc
Normal file
|
|
@ -0,0 +1,252 @@
|
|||
// Copyright (c) Microsoft Corporation. All rights reserved.
|
||||
// Licensed under the MIT License.
|
||||
|
||||
#include "gtest/gtest.h"
|
||||
#include "test/common/tensor_op_test_utils.h"
|
||||
#include "test/common/cuda_op_test_utils.h"
|
||||
#include "test/providers/provider_test_utils.h"
|
||||
|
||||
namespace onnxruntime {
|
||||
namespace test {
|
||||
|
||||
template <typename Integer, typename = typename std::enable_if<std::is_integral<Integer>::value, Integer>::type>
|
||||
inline std::vector<Integer> ToInteger(const std::vector<float>& data, float scale) {
|
||||
std::vector<Integer> result;
|
||||
result.reserve(data.size());
|
||||
for (size_t i = 0; i < data.size(); i++) {
|
||||
result.push_back(static_cast<Integer>(std::round(data[i] / scale)));
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
static void RunAttentionTest(
|
||||
const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size]
|
||||
const std::vector<float>& weights_data, // weights: [hidden_size, 3 * hidden_size]
|
||||
const std::vector<float>& bias_data, // bias: [3 * hidden_size]
|
||||
const std::vector<int32_t>& mask_index_data, // mask_index: [batch_size]
|
||||
const std::vector<float>& output_data, // output: [batch_size, sequence_length, hidden_size]
|
||||
int batch_size,
|
||||
int sequence_length,
|
||||
int hidden_size,
|
||||
int number_of_heads,
|
||||
bool use_float16 = false) {
|
||||
int min_cuda_architecture = 530;
|
||||
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
|
||||
|
||||
if (enable_cuda) {
|
||||
OpTester tester("QAttention", 1, onnxruntime::kMSDomain);
|
||||
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
|
||||
|
||||
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
|
||||
std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
|
||||
std::vector<int64_t> bias_dims = {3 * hidden_size};
|
||||
std::vector<int64_t> mask_index_dims = {batch_size};
|
||||
std::vector<int64_t> output_dims = input_dims;
|
||||
|
||||
float input_scale = 0.1f;
|
||||
float weight_scale = 0.1f;
|
||||
tester.AddInput<int8_t>("input", input_dims, ToInteger<int8_t>(input_data, input_scale));
|
||||
tester.AddInput<int8_t>("weight", weights_dims, ToInteger<int8_t>(weights_data, weight_scale));
|
||||
if (use_float16) {
|
||||
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias_data));
|
||||
tester.AddInput<MLFloat16>("input_scale", {1}, ToFloat16({input_scale}));
|
||||
tester.AddInput<MLFloat16>("weight_scale", {1}, ToFloat16({weight_scale}));
|
||||
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
|
||||
} else {
|
||||
tester.AddInput<float>("bias", bias_dims, bias_data);
|
||||
tester.AddInput<float>("input_scale", {1}, {input_scale});
|
||||
tester.AddInput<float>("weight_scale", {1}, {weight_scale});
|
||||
tester.AddOutput<float>("output", output_dims, output_data);
|
||||
}
|
||||
|
||||
if (mask_index_data.size() > 0) { // mask index is optional.
|
||||
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
|
||||
}
|
||||
|
||||
tester.Run();
|
||||
}
|
||||
}
|
||||
|
||||
TEST(QAttentionTest, AttentionBatch1) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
|
||||
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
|
||||
TEST(QAttentionTest, AttentionBatch1_Float16) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.15039f, 0.1082763671875f, 4.24609375f, 5.6484375f,
|
||||
3.96679f, 0.072998046875f, 4.24609f, 5.6484375f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads, true);
|
||||
}
|
||||
|
||||
TEST(QAttentionTest, AttentionBatch2) {
|
||||
int batch_size = 2;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f,
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
std::vector<int32_t> mask_index_data = {2L, 2L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
|
||||
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f,
|
||||
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
|
||||
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
|
||||
TEST(QAttentionTest, AttentionMaskPartialSequence) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
// Test mask_index < sequence_length
|
||||
std::vector<int32_t> mask_index_data = {1L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
8.6899995803833008f, -0.13000002503395081f, 4.25f, 5.6499996185302734f,
|
||||
8.6899995803833008f, -0.13000002503395081f, 4.2499995231628418f, 5.6499991416931152f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
|
||||
TEST(QAttentionTest, AttentionMaskExceedSequence) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
// Test mask_index > sequence_length
|
||||
std::vector<int32_t> mask_index_data = {3L};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
|
||||
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
|
||||
TEST(QAttentionTest, AttentionNoMaskIndex) {
|
||||
int batch_size = 1;
|
||||
int sequence_length = 2;
|
||||
int hidden_size = 4;
|
||||
int number_of_heads = 2;
|
||||
|
||||
std::vector<float> input_data = {
|
||||
0.8f, -0.5f, 0.0f, 1.f,
|
||||
0.5f, 0.2f, 0.3f, -0.6f};
|
||||
|
||||
std::vector<float> weight_data = {
|
||||
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
|
||||
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
|
||||
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
|
||||
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
|
||||
|
||||
std::vector<float> bias_data = {
|
||||
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
|
||||
|
||||
// No mask_index
|
||||
std::vector<int32_t> mask_index_data = {};
|
||||
|
||||
std::vector<float> output_data = {
|
||||
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
|
||||
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
|
||||
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads);
|
||||
}
|
||||
} // namespace test
|
||||
} // namespace onnxruntime
|
||||
Loading…
Reference in a new issue