Quantize attention with Cuda (#3693)

* Add definition of QAttention
* implemention of QAttention on GPU
This commit is contained in:
Yufeng Li 2020-05-04 14:20:38 -07:00 committed by GitHub
parent 49f0610447
commit 156368b67f
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 753 additions and 122 deletions

View file

@ -28,18 +28,19 @@ AttentionBase::AttentionBase(const OpKernelInfo& info) {
int64_t num_heads = 0;
ORT_ENFORCE(info.GetAttr("num_heads", &num_heads).IsOK() && num_heads > 0);
num_heads_ = static_cast<int>(num_heads);
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 1) == 1;
is_unidirectional_ = info.GetAttrOrDefault<int64_t>("unidirectional", 0) == 1;
}
Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
Status AttentionBase::CheckInputs(const Tensor* input,
const Tensor* weights,
const Tensor* bias,
const Tensor* mask_index) const {
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - mask_index : (batch_size) if presented
// Output : (batch_size, sequence_length, hidden_size)
// input : (batch_size, sequence_length, hidden_size)
// weights : (hidden_size, 3 * hidden_size)
// bias : (3 * hidden_size)
// mask_index : (batch_size) if presented
const Tensor* input = context->Input<Tensor>(0);
const auto dims = input->Shape().GetDims();
if (dims.size() != 3) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 0 is expected to have 3 dimensions, got ",
@ -52,7 +53,6 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
"Input 0 dimension 2 should be divisiable by value of the num_heads attribute.");
}
const Tensor* weights = context->Input<Tensor>(1);
const auto weights_dims = weights->Shape().GetDims();
if (weights_dims.size() != 2) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 is expected to have 2 dimensions, got ",
@ -66,7 +66,6 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 1 dimension 1 should be 3 times of dimension 0");
}
const Tensor* bias = context->Input<Tensor>(2);
const auto bias_dims = bias->Shape().GetDims();
if (bias_dims.size() != 1) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 2 is expected to have 1 dimension, got ",
@ -77,8 +76,7 @@ Status AttentionBase::CheckInputs(const OpKernelContext* context) const {
"Input 2 dimension 0 should have same length as dimension 1 of input 1");
}
const Tensor* mask_index = context->Input<Tensor>(3);
if (mask_index != nullptr) {
if (mask_index != nullptr) { // mask_index is optional
// unidirectional (like GPT2) does not need mask input. Here we do not allowed the input for unidirectional.
if (is_unidirectional_) {
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Input 3 (mask_index) is not allowed for unidirectional");
@ -103,13 +101,11 @@ Attention<T>::Attention(const OpKernelInfo& info) : OpKernel(info), AttentionBas
template <typename T>
Status Attention<T>::Compute(OpKernelContext* context) const {
auto* tp = context->GetOperatorThreadPool();
ORT_RETURN_IF_ERROR(CheckInputs(context));
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, mask_index));
const auto dims = input->Shape().GetDims();
const int batch_size = static_cast<int>(dims[0]);
@ -125,6 +121,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
AllocatorPtr allocator;
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
auto* tp = context->GetOperatorThreadPool();
// STEP.1: gemm_data(BS, 3NH) = input(BS, NH) x weights(NH, 3NH) + bias(3NH)
auto gemm_data = allocator->Alloc(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
BufferUniquePtr gemm_buffer(gemm_data, BufferDeleter(allocator));
@ -180,7 +177,7 @@ Status Attention<T>::Compute(OpKernelContext* context) const {
qkv_dest + qkv_offset, // C
head_size, // ldc
nullptr // use single-thread
);
);
}
});
}

View file

@ -12,7 +12,10 @@ namespace contrib {
class AttentionBase {
protected:
AttentionBase(const OpKernelInfo& info);
Status CheckInputs(const OpKernelContext* context) const;
Status CheckInputs(const Tensor* input,
const Tensor* weights,
const Tensor* bias,
const Tensor* mask_index) const;
int num_heads_; // number of attention heads
bool is_unidirectional_; // whether every token can only attend to previous tokens.

View file

@ -34,8 +34,6 @@ Attention<T>::Attention(const OpKernelInfo& info) : CudaKernel(info), AttentionB
template <typename T>
Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
ORT_RETURN_IF_ERROR(CheckInputs(context));
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
@ -46,6 +44,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* mask_index = context->Input<Tensor>(3);
ORT_RETURN_IF_ERROR(CheckInputs(input, weights, bias, mask_index));
const auto dims = input->Shape().GetDims();
int batch_size = static_cast<int>(dims[0]);

View file

@ -68,6 +68,8 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention);
Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
static const BuildKernelCreateInfoFn function_table[] = {
@ -130,7 +132,9 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, QuantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, int8_t_MLFloat16, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>};
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, uint8_t_MLFloat16, DequantizeLinear)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float_int8_t, QAttention)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16_int8_t, QAttention)>};
for (auto& function_table_entry : function_table) {
ORT_RETURN_IF_ERROR(kernel_registry.Register(function_table_entry()));

View file

@ -0,0 +1,189 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "attention_quantization.h"
#include "attention_quantization_impl.cuh"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "core/framework/tensorprotoutils.h"
#include "core/providers/common.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cuda/shared_inc/integer_gemm.h"
#include "core/providers/cuda/tensor/quantize_linear.h"
using namespace onnxruntime::cuda;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace cuda {
#define REGISTER_KERNEL_TYPED(T, TQuant) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
QAttention, \
kMSDomain, \
1, \
T##_##TQuant, \
kCudaExecutionProvider, \
KernelDefBuilder() \
.InputMemoryType<OrtMemTypeCPUInput>(3) \
.InputMemoryType<OrtMemTypeCPUInput>(4) \
.InputMemoryType<OrtMemTypeCPUInput>(6) \
.InputMemoryType<OrtMemTypeCPUInput>(7) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<TQuant>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<TQuant>()) \
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()), \
QAttention<T, TQuant>);
REGISTER_KERNEL_TYPED(float, int8_t)
REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
template <typename T>
Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
const Tensor* weights,
const Tensor* bias,
const Tensor* input_scale_tensor,
const Tensor* weight_scale_tensor,
const Tensor* mask_index,
const Tensor* i_zp_tensor,
const Tensor* w_zp_tensor) const {
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - input_scale : scalar
// Input 4 - weight_scale : scalar
// Input 5 - mask_index : (batch_size)
// Input 6 - input_zero_point : scalar
// Input 7 - weight_zero_point : scalar
// Output : (batch_size, sequence_length, hidden_size)
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input, weights, bias, mask_index));
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
"input scale must be a scalar or 1D tensor of size 1");
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(weight_scale_tensor),
"weight must be a scalar or 1D tensor of size 1");
if (i_zp_tensor != nullptr) {
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(i_zp_tensor),
"input zero point must be a scalar or 1D tensor of size 1.");
if (0 != *(i_zp_tensor->template Data<int8_t>()))
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
}
if (w_zp_tensor != nullptr) {
// CUDA only support symmetric quantization for Attention
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(w_zp_tensor),
"weight zero point must be a scalar or 1D tensor of size 1.");
if (0 != *(w_zp_tensor->template Data<int8_t>()))
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
}
return Status::OK();
}
template <typename T>
Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
// Input and output shapes:
// Input 0 - input : (batch_size, sequence_length, hidden_size)
// Input 1 - weights : (hidden_size, 3 * hidden_size)
// Input 2 - bias : (3 * hidden_size)
// Input 3 - input_scale : scalar
// Input 4 - weight_scale : scalar
// Input 5 - mask_index : (batch_size)
// Input 6 - input_zero_point : scalar
// Input 7 - weight_zero_point : scalar
// Output : (batch_size, sequence_length, hidden_size)
// ORT_RETURN_IF_ERROR(CheckInputs(context));
const Tensor* input = context->Input<Tensor>(0);
const Tensor* weights = context->Input<Tensor>(1);
const Tensor* bias = context->Input<Tensor>(2);
const Tensor* input_scale_tensor = context->Input<Tensor>(3);
const Tensor* weight_scale_tensor = context->Input<Tensor>(4);
const Tensor* mask_index = context->Input<Tensor>(5);
const Tensor* i_zp_tensor = context->Input<Tensor>(6);
const Tensor* w_zp_tensor = context->Input<Tensor>(7);
ORT_RETURN_IF_ERROR(CheckInputs(input,
weights,
bias,
input_scale_tensor,
weight_scale_tensor,
mask_index,
i_zp_tensor,
w_zp_tensor));
const auto dims = input->Shape().GetDims();
/*int input_size = static_cast<int>(input->Shape().Size());*/
int batch_size = static_cast<int>(dims[0]);
int sequence_length = static_cast<int>(dims[1]);
int hidden_size = static_cast<int>(dims[2]);
int head_size = hidden_size / num_heads_;
TensorShape output_shape(dims);
Tensor* output = context->Output(0, output_shape);
cublasHandle_t cublas = CublasHandle();
const size_t element_size = sizeof(T);
// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int k = hidden_size;
auto gemm_buffer = GetScratchBuffer<T>(batch_size * sequence_length * 3 * hidden_size * element_size);
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(batch_size * sequence_length * 3 * hidden_size);
typedef typename ToCudaType<T>::MappedType CudaT;
GemmInt8(m, n, k,
1 /*alpha_matmul*/, 0 /* beta_matmul*/,
input->template Data<int8_t>(), k,
weights->template Data<int8_t>(), n,
gemm_buffer_quantized.get(), n,
this);
CudaT dequant_scale;
CudaT input_scale = *(reinterpret_cast<const CudaT*>(input_scale_tensor->template Data<T>()));
CudaT weight_scale = *(reinterpret_cast<const CudaT*>(weight_scale_tensor->template Data<T>()));
if (sizeof(T) == 2) {
dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale));
} else {
dequant_scale = input_scale * weight_scale;
}
// scale back and bias
CudaDequantizeWithBias(
gemm_buffer_quantized.get(),
reinterpret_cast<const CudaT*>(bias->template Data<T>()),
reinterpret_cast<CudaT*>(gemm_buffer.get()),
dequant_scale,
m,
n);
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size, sequence_length);
auto temp_buffer = GetScratchBuffer<void>(workSpaceSize);
if (!LaunchAttentionKernel(
reinterpret_cast<const CudaT*>(gemm_buffer.get()),
nullptr == mask_index ? nullptr : mask_index->template Data<int>(),
output->template MutableData<T>(),
batch_size,
sequence_length,
num_heads_,
head_size,
temp_buffer.get(),
cublas,
element_size,
false)) {
// Get last error to reset it to cudaSuccess.
CUDA_CALL(cudaGetLastError());
return Status(common::ONNXRUNTIME, common::FAIL);
}
return Status::OK();
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,44 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/providers/cuda/cuda_common.h"
#include "contrib_ops/cpu/bert/attention.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;
template <typename T, typename TQuant>
class QAttention;
template <typename T>
class QAttention<T, int8_t> final : public CudaKernel, public AttentionBase {
using Base = CudaKernel;
public:
QAttention(const OpKernelInfo& info) : CudaKernel(info),
AttentionBase(info) {
}
Status ComputeInternal(OpKernelContext* context) const override;
private:
Status CheckInputs(const Tensor* input,
const Tensor* weights,
const Tensor* bias,
const Tensor* input_scale_tensor,
const Tensor* weight_scale_tensor,
const Tensor* mask_index,
const Tensor* i_zp_tensor,
const Tensor* w_zp_tensor) const;
};
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,52 @@
// Modifications: scaling is moved from masked softmax to the gemm before that.
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include <cub/cub.cuh>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <math_constants.h>
#include "core/providers/cuda/cu_inc/common.cuh"
#include "core/providers/cuda/cuda_common.h"
#include "attention_quantization_impl.cuh"
using namespace onnxruntime::cuda;
using namespace cub;
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class T, int NumThreadsPerBlock, int NumElementsPerThread>
__global__ void DequantizeLinearKernel(const int32_t* quantize, const T* bias, T* output, T scale, int bias_len, CUDA_LONG N) {
CUDA_LONG id = NumElementsPerThread * NumThreadsPerBlock * blockIdx.x + threadIdx.x;
#pragma unroll
for (int i = 0; i < NumElementsPerThread; i++) {
if (id < N) {
output[id] = (static_cast<T>(quantize[id]) * scale) + bias[id % bias_len];
id += NumThreadsPerBlock;
}
}
}
template <class T>
Status CudaDequantizeWithBias(const int32_t* quantize, const T* bias, T* output, T scale, int m, int n) {
int blocksPerGrid = static_cast<int>(CeilDiv(m * n, GridDim::maxThreadsPerBlock * GridDim::maxElementsPerThread));
CUDA_LONG N = static_cast<CUDA_LONG>(m * n);
DequantizeLinearKernel<T, GridDim::maxThreadsPerBlock, GridDim::maxElementsPerThread><<<blocksPerGrid, GridDim::maxThreadsPerBlock, 0>>>(
quantize,
bias,
output,
scale,
n,
N);
return Status::OK();
}
template Status CudaDequantizeWithBias<float>(const int32_t* quantize, const float* bias, float* output, float scale, int m, int n);
template Status CudaDequantizeWithBias<half>(const int32_t* quantize, const half* bias, half* output, half scale, int m, int n);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -0,0 +1,15 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/shared_inc/cuda_utils.h"
namespace onnxruntime {
namespace contrib {
namespace cuda {
template <class Tin>
Status CudaDequantizeWithBias(const int32_t* quantize, const Tin* bias, Tin* output, Tin scale, int m, int n);
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime

View file

@ -314,6 +314,79 @@ mask_index shall not be provided.)DOC";
.TypeConstraint("M", {"tensor(int32)"}, "Constrain mask index to integer types")
.TypeAndShapeInferenceFunction(ONNX_NAMESPACE::propagateShapeAndTypeFromFirstInput);
ONNX_CONTRIB_OPERATOR_SCHEMA(QAttention)
.SetDomain(kMSDomain)
.SinceVersion(1)
.SetSupportLevel(OpSchema::SupportType::EXPERIMENTAL)
.SetDoc("Quantization of Multi-Head Self Attention.")
.Attr("num_heads", "Number of attention heads", AttributeProto::INT)
.Input(
0,
"input",
"3D input tensor with shape (batch_size, sequence_length, hidden_size), hidden_size = num_heads * head_size",
"T1")
.Input(
1,
"weight",
"2D input tensor with shape (hidden_size, 3 * hidden_size)",
"T2")
.Input(
2,
"bias",
"1D input tensor with shape (3 * hidden_size)",
"T3")
.Input(
3,
"input_scale",
"scale of quantized input tensor. It's a scalar, which means a per-tensor/layer quantization.",
"T3")
.Input(
4,
"weight_scale",
"scale of weight scale. It's a scalar, which means a per-tensor/layer quantization.",
"T3")
.Input(
5,
"mask_index",
"Attention mask index with shape (batch_size)",
"T4",
OpSchema::Optional)
.Input(
6,
"input_zero_point",
"zero point of quantized input tensor.It's a scalar, which means a per-tensor/layer quantization.",
"T1",
OpSchema::Optional)
.Input(
7,
"weight_zero_point",
"zero point of quantized weight tensor. It's a scalar, which means a per-tensor/layer quantization.",
"T2",
OpSchema::Optional)
.Output(
0,
"output",
"3D output tensor with shape (batch_size, sequence_length, hidden_size)",
"T3")
.TypeConstraint("T1", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
.TypeConstraint("T2", {"tensor(int8)", "tensor(uint8)"}, "Constrain input and output types to int8 tensors.")
.TypeConstraint("T3", {"tensor(float)", "tensor(float16)"}, "Constrain input and output types to float tensors.")
.TypeConstraint("T4", {"tensor(int32)"}, "Constrain mask index to integer types")
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
// Type inference
ONNX_NAMESPACE::propagateElemTypeFromInputToOutput(ctx, 2, 0);
// Shape inference
// if the input shape doesn't exist, further shape inference is not possible
if (!hasNInputShapes(ctx, 1)) {
return;
}
ONNX_NAMESPACE::propagateShapeFromInputToOutput(ctx, 0, 0);
return;
});
static const char* EmbedLayerNormalization_ver1_doc = R"DOC(
EmbedLayerNormalization is the fusion of embedding layer in BERT model, with optional mask processing.
The embedding layer takes input_ids (word IDs) and segment_ids (sentence IDs) to look up word_embedding, position_embedding,

View file

@ -161,7 +161,6 @@ class CudaKernel : public OpKernel {
const CudaKernel* op_kernel_;
};
protected:
inline cublasHandle_t CublasHandle() const {
return provider_->PerThreadCublasHandle();
}
@ -169,6 +168,8 @@ class CudaKernel : public OpKernel {
inline cudnnHandle_t CudnnHandle() const {
return provider_->PerThreadCudnnHandle();
}
protected:
inline curandGenerator_t CurandGenerator() const {
return provider_->PerThreadCurandGenerator();
}

View file

@ -0,0 +1,58 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "core/providers/cuda/shared_inc/integer_gemm.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
namespace onnxruntime {
namespace cuda {
inline int roundoff(int v, int d) {
return (v + d - 1) / d * d;
}
Status GemmInt8(int m, int n, int k,
int32_t alpha, int32_t beta,
const int8_t* a, int lda, const int8_t* b, int ldb, int32_t* c, int ldc,
const CudaKernel* cuda_kernel) {
ORT_ENFORCE(a != nullptr && b != nullptr && c != nullptr, "input matrix should not be null");
ORT_ENFORCE(cuda_kernel != nullptr, "kernel is null");
// pad A and B to make their leading dimension be multiples of 32
// because cublasGemmEx requires:
// 1. leading dimension is multiples of 4
// 2. A, B is 32-bit aligned
const int mask = 0x1F;
int lda_aligned = lda;
IAllocatorUniquePtr<int8_t> a_padded;
if ((mask & lda_aligned) != 0) {
lda_aligned = roundoff(lda, 32);
a_padded = cuda_kernel->GetScratchBuffer<int8_t>(m * lda_aligned);
cudaMemcpy2DAsync(a_padded.get(), lda_aligned, a, lda, k, m, cudaMemcpyDeviceToDevice, 0);
}
int ldb_aligned = ldb;
IAllocatorUniquePtr<int8_t> b_padded;
if ((mask & ldb_aligned) != 0) {
ldb_aligned = roundoff(ldb, 32);
b_padded = cuda_kernel->GetScratchBuffer<int8_t>(k * ldb_aligned);
cudaMemcpy2DAsync(b_padded.get(), ldb_aligned, b, ldb, n, k, cudaMemcpyDeviceToDevice, 0);
}
CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
cuda_kernel->CublasHandle(),
CUBLAS_OP_N, CUBLAS_OP_N,
n, m, k,
&alpha,
ldb_aligned == ldb ? b : b_padded.get(), CUDA_R_8I, ldb_aligned,
lda_aligned == lda ? a : a_padded.get(), CUDA_R_8I, lda_aligned,
&beta,
c, CUDA_R_32I, ldc, CUDA_R_32I,
CUBLAS_GEMM_DFALT));
return Status::OK();
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -5,6 +5,7 @@
#include "matmul_integer.cuh"
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cuda/shared_inc/integer_gemm.h"
#include "core/providers/cuda/cuda_allocator.h"
#include "core/providers/common.h"
@ -25,26 +26,6 @@ ONNX_OPERATOR_TYPED_KERNEL_EX(
.TypeConstraint("T3", DataTypeImpl::GetTensorType<int32_t>()),
MatMulInteger<int8_t, int8_t>);
template <>
Status MatMulInteger<int8_t, int8_t>::PadMatrix(
int64_t row,
int64_t col,
int64_t align_size,
const int8_t*& src,
int64_t& pad_size,
IAllocatorUniquePtr<int8_t>& temp_mem_holder) const {
pad_size = align_size - col % align_size;
if (pad_size != align_size) {
temp_mem_holder = GetScratchBuffer<int8_t>(row * (col + pad_size));
ORT_RETURN_IF_ERROR(PadMatrixInLeadingDimension(src, temp_mem_holder.get(), row, col, pad_size));
src = temp_mem_holder.get();
} else {
pad_size = 0;
}
return Status::OK();
}
template <>
Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) const {
auto a = ctx->Input<Tensor>(0);
@ -106,49 +87,19 @@ Status MatMulInteger<int8_t, int8_t>::ComputeInternal(OpKernelContext* ctx) cons
beta = 1;
}
// pad A and B to make their leading dimension be multiples of 32
// because cublasGemmEx requires:
// 1. leading dimension is multiples of 4
// 2. A, B is 32-bit aligned
const int64_t align_size = 32;
int64_t a_pad_size = 0;
int64_t b_pad_size = 0;
IAllocatorUniquePtr<int8_t> a_padded;
IAllocatorUniquePtr<int8_t> b_padded;
ORT_RETURN_IF_ERROR(PadMatrix(a->Shape().Size() / helper.K(),
helper.K(),
align_size,
a_ptr,
a_pad_size,
a_padded));
ORT_RETURN_IF_ERROR(PadMatrix(b->Shape().Size() / helper.N(),
helper.N(),
align_size,
b_ptr,
b_pad_size,
b_padded));
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
CUBLAS_RETURN_IF_ERROR(cublasGemmEx(
Base::CublasHandle(),
CUBLAS_OP_N,
CUBLAS_OP_N,
static_cast<int>(helper.N()),
static_cast<int>(helper.M()),
static_cast<int>(helper.K()),
&alpha,
b_ptr + helper.RightOffsets()[batch] + helper.RightOffsets()[batch] / helper.N() * b_pad_size,
CUDA_R_8I,
static_cast<int>(helper.N() + b_pad_size),
a_ptr + helper.LeftOffsets()[batch] + helper.LeftOffsets()[batch] / helper.K() * a_pad_size,
CUDA_R_8I,
static_cast<int>(helper.K() + a_pad_size),
&beta,
output_ptr + helper.OutputOffsets()[batch],
CUDA_R_32I,
static_cast<int>(helper.N()),
CUDA_R_32I,
CUBLAS_GEMM_DFALT));
GemmInt8(static_cast<int>(helper.M()),
static_cast<int>(helper.N()),
static_cast<int>(helper.K()),
alpha,
beta,
a_ptr + helper.LeftOffsets()[batch],
static_cast<int>(helper.K()),
b_ptr + helper.RightOffsets()[batch],
static_cast<int>(helper.N()),
output_ptr + helper.OutputOffsets()[batch],
static_cast<int>(helper.N()),
this);
}
return Status::OK();

View file

@ -28,11 +28,10 @@ __global__ void ReduceRowSumOnMatrixAKernel(const int8_t* matrix, int32_t* row_s
Status ReduceRowSumOnMatrixA(const int8_t* matrix, int32_t* row_sum, const int8_t offset, const MatMulComputeHelper& helper) {
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
ReduceRowSumOnMatrixAKernel<static_cast<int>(GridDim::maxThreadsPerBlock)>
<<<static_cast<int>(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch],
row_sum + batch * helper.M(),
offset,
static_cast<int>(helper.K()));
ReduceRowSumOnMatrixAKernel<static_cast<int>(GridDim::maxThreadsPerBlock)><<<static_cast<int>(helper.M()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.LeftOffsets()[batch],
row_sum + batch * helper.M(),
offset,
static_cast<int>(helper.K()));
}
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
@ -57,12 +56,11 @@ __global__ void ReduceColSumOnMatrixBKernel(const int8_t* matrix, int32_t* col_s
Status ReduceColSumOnMatrixB(const int8_t* matrix, int32_t* col_sum, const int8_t offset, const MatMulComputeHelper& helper) {
for (size_t batch = 0; batch < helper.OutputOffsets().size(); batch++) {
ReduceColSumOnMatrixBKernel<static_cast<int>(GridDim::maxThreadsPerBlock)>
<<<static_cast<int>(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch],
col_sum + batch * helper.N(),
offset,
static_cast<int32_t>(helper.K()),
static_cast<int32_t>(helper.N()));
ReduceColSumOnMatrixBKernel<static_cast<int>(GridDim::maxThreadsPerBlock)><<<static_cast<int>(helper.N()), GridDim::maxThreadsPerBlock, 0>>>(matrix + helper.RightOffsets()[batch],
col_sum + batch * helper.N(),
offset,
static_cast<int32_t>(helper.K()),
static_cast<int32_t>(helper.N()));
}
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
@ -128,21 +126,5 @@ Status OffsetOutput(const int32_t* row_sum,
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
}
__global__ void PadMatrixInLeadingDimensionKernel(const int8_t* src, int8_t* dst, int col_src, int col_dst) {
for (int32_t i = threadIdx.x; i < col_src; i += blockDim.x) {
*(dst + blockIdx.x * col_dst + i) = *(src + blockIdx.x * col_src + i);
}
}
Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size) {
PadMatrixInLeadingDimensionKernel<<<static_cast<int>(row), GridDim::maxThreadsPerBlock, 0>>>(
src,
dst,
static_cast<int>(col),
static_cast<int>(col + pad_size));
return CUDA_CALL(cudaPeekAtLastError()) ? Status::OK() : Status(common::ONNXRUNTIME, common::FAIL);
;
}
} // namespace cuda
} // namespace onnxruntime

View file

@ -20,7 +20,5 @@ Status OffsetOutput(const int32_t* row_sum,
const int8_t b_offset,
const MatMulComputeHelper& helper);
Status PadMatrixInLeadingDimension(const int8_t* src, int8_t* dst, int64_t row, int64_t col, int64_t pad_size);
} // namespace cuda
} // namespace onnxruntime

View file

@ -25,16 +25,6 @@ class MatMulInteger final : public CudaKernel {
Status ComputeInternal(OpKernelContext* context) const override;
private:
// pad matrix and B to make their leading dimension be multiples of *align_size*
Status PadMatrix(
int64_t row,
int64_t col,
int64_t align_size,
const int8_t*& src,
int64_t& pad_size,
IAllocatorUniquePtr<int8_t>& temp_mem_holder) const;
private:
bool has_a_zero_point_;
bool has_b_zero_point_;

View file

@ -0,0 +1,23 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/providers/cuda/cuda_common.h"
namespace onnxruntime {
namespace cuda {
Status GemmInt8(int m,
int n,
int k,
int32_t alpha_matmul,
int32_t beta_matmul,
const int8_t* a,
int lda,
const int8_t* b,
int ldb,
int32_t* c,
int ldc,
const CudaKernel* cuda_kernel);
}
}

View file

@ -0,0 +1,252 @@
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "gtest/gtest.h"
#include "test/common/tensor_op_test_utils.h"
#include "test/common/cuda_op_test_utils.h"
#include "test/providers/provider_test_utils.h"
namespace onnxruntime {
namespace test {
template <typename Integer, typename = typename std::enable_if<std::is_integral<Integer>::value, Integer>::type>
inline std::vector<Integer> ToInteger(const std::vector<float>& data, float scale) {
std::vector<Integer> result;
result.reserve(data.size());
for (size_t i = 0; i < data.size(); i++) {
result.push_back(static_cast<Integer>(std::round(data[i] / scale)));
}
return result;
}
static void RunAttentionTest(
const std::vector<float>& input_data, // input: [batch_size, sequence_length, hidden_size]
const std::vector<float>& weights_data, // weights: [hidden_size, 3 * hidden_size]
const std::vector<float>& bias_data, // bias: [3 * hidden_size]
const std::vector<int32_t>& mask_index_data, // mask_index: [batch_size]
const std::vector<float>& output_data, // output: [batch_size, sequence_length, hidden_size]
int batch_size,
int sequence_length,
int hidden_size,
int number_of_heads,
bool use_float16 = false) {
int min_cuda_architecture = 530;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture);
if (enable_cuda) {
OpTester tester("QAttention", 1, onnxruntime::kMSDomain);
tester.AddAttribute<int64_t>("num_heads", static_cast<int64_t>(number_of_heads));
std::vector<int64_t> input_dims = {batch_size, sequence_length, hidden_size};
std::vector<int64_t> weights_dims = {hidden_size, 3 * hidden_size};
std::vector<int64_t> bias_dims = {3 * hidden_size};
std::vector<int64_t> mask_index_dims = {batch_size};
std::vector<int64_t> output_dims = input_dims;
float input_scale = 0.1f;
float weight_scale = 0.1f;
tester.AddInput<int8_t>("input", input_dims, ToInteger<int8_t>(input_data, input_scale));
tester.AddInput<int8_t>("weight", weights_dims, ToInteger<int8_t>(weights_data, weight_scale));
if (use_float16) {
tester.AddInput<MLFloat16>("bias", bias_dims, ToFloat16(bias_data));
tester.AddInput<MLFloat16>("input_scale", {1}, ToFloat16({input_scale}));
tester.AddInput<MLFloat16>("weight_scale", {1}, ToFloat16({weight_scale}));
tester.AddOutput<MLFloat16>("output", output_dims, ToFloat16(output_data));
} else {
tester.AddInput<float>("bias", bias_dims, bias_data);
tester.AddInput<float>("input_scale", {1}, {input_scale});
tester.AddInput<float>("weight_scale", {1}, {weight_scale});
tester.AddOutput<float>("output", output_dims, output_data);
}
if (mask_index_data.size() > 0) { // mask index is optional.
tester.AddInput<int32_t>("mask_index", mask_index_dims, mask_index_data);
}
tester.Run();
}
}
TEST(QAttentionTest, AttentionBatch1) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
std::vector<int32_t> mask_index_data = {2L};
std::vector<float> output_data = {
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);
}
TEST(QAttentionTest, AttentionBatch1_Float16) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
std::vector<int32_t> mask_index_data = {2L};
std::vector<float> output_data = {
3.15039f, 0.1082763671875f, 4.24609375f, 5.6484375f,
3.96679f, 0.072998046875f, 4.24609f, 5.6484375f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads, true);
}
TEST(QAttentionTest, AttentionBatch2) {
int batch_size = 2;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f,
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
std::vector<int32_t> mask_index_data = {2L, 2L};
std::vector<float> output_data = {
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f,
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);
}
TEST(QAttentionTest, AttentionMaskPartialSequence) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// Test mask_index < sequence_length
std::vector<int32_t> mask_index_data = {1L};
std::vector<float> output_data = {
8.6899995803833008f, -0.13000002503395081f, 4.25f, 5.6499996185302734f,
8.6899995803833008f, -0.13000002503395081f, 4.2499995231628418f, 5.6499991416931152f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);
}
TEST(QAttentionTest, AttentionMaskExceedSequence) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// Test mask_index > sequence_length
std::vector<int32_t> mask_index_data = {3L};
std::vector<float> output_data = {
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);
}
TEST(QAttentionTest, AttentionNoMaskIndex) {
int batch_size = 1;
int sequence_length = 2;
int hidden_size = 4;
int number_of_heads = 2;
std::vector<float> input_data = {
0.8f, -0.5f, 0.0f, 1.f,
0.5f, 0.2f, 0.3f, -0.6f};
std::vector<float> weight_data = {
0.1f, -0.2f, 0.3f, 1.0f, 1.1f, 0.3f, 0.5f, 0.2f, 0.3f, -0.6f, 1.5f, 2.0f,
0.5f, 0.1f, 0.4f, 1.6f, 1.0f, 2.0f, 0.4f, 0.8f, 0.9f, 0.1f, -1.3f, 0.7f,
0.3f, 0.2f, 4.0f, 2.2f, 1.6f, 1.1f, 0.7f, 0.2f, 0.4f, 1.0f, 1.2f, 0.5f,
0.2f, 0.1f, 0.4f, 1.6f, 2.4f, 3.3f, 2.1f, 4.2f, 8.4f, 0.0f, 2.1f, 3.2f};
std::vector<float> bias_data = {
-0.5f, 0.6f, 1.2f, 2.1f, 0.5f, 0.7f, 0.2f, 1.2f, 0.5f, 0.4f, 0.3f, 1.2f};
// No mask_index
std::vector<int32_t> mask_index_data = {};
std::vector<float> output_data = {
3.1495983600616455f, 0.10843668878078461f, 4.25f, 5.6499996185302734f,
3.9696791172027588f, 0.073143675923347473f, 4.2499995231628418f, 5.6499991416931152f};
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads);
}
} // namespace test
} // namespace onnxruntime