mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
**Description**: Add `qkv_hidden_sizes` support to the CUDA Attention layer implementation. Changes include: (1) modify the unit tests to test both the GPU and CPU implementations; (2) add an overload of the CUDA kernel `AddBiasTransposeQKV` to support the scenario where V_HIDDEN_SIZE != QK_HIDDEN_SIZE; (3) rename variables from `head_size` to `qkv_head_sizes[0]` or `qkv_head_sizes[2]`; (4) modify function definitions to allow passing `qkv_hidden_sizes` or `qkv_head_sizes`. Note that this feature is not yet supported in the ROCm EP or in quantized attention. **Motivation and Context** - Why is this change required? What problem does it solve? The current CUDA implementation of the attention layer does not support the `qkv_hidden_sizes` parameter that was added to the CPU implementation in PR [8039](https://github.com/microsoft/onnxruntime/pull/8039). - If it fixes an open issue, please link to the issue here. Co-authored-by: Peter Mcaughan <petermca@microsoft.com>
204 lines
9.7 KiB
C++
204 lines
9.7 KiB
C++
// Copyright (c) Microsoft Corporation. All rights reserved.
|
|
// Licensed under the MIT License.
|
|
|
|
#include "attention_quantization.h"
|
|
#include "attention_quantization_impl.cuh"
|
|
#include "contrib_ops/cuda/bert/attention_impl.h"
|
|
#include "core/common/safeint.h"
|
|
#include "core/providers/cuda/cuda_common.h"
|
|
#include "core/providers/cuda/shared_inc/fpgeneric.h"
|
|
#include "core/providers/cuda/shared_inc/integer_gemm.h"
|
|
#include "core/providers/cuda/tensor/quantize_linear.h"
|
|
|
|
using namespace onnxruntime::cuda;
|
|
using namespace onnxruntime::common;
|
|
|
|
namespace onnxruntime {
|
|
namespace contrib {
|
|
namespace cuda {
|
|
|
|
#define REGISTER_KERNEL_TYPED(T, TQuant) \
|
|
ONNX_OPERATOR_TYPED_KERNEL_EX( \
|
|
QAttention, \
|
|
kMSDomain, \
|
|
1, \
|
|
T##_##TQuant, \
|
|
kCudaExecutionProvider, \
|
|
(*KernelDefBuilder::Create()) \
|
|
.InputMemoryType(OrtMemTypeCPUInput, 3) \
|
|
.InputMemoryType(OrtMemTypeCPUInput, 4) \
|
|
.InputMemoryType(OrtMemTypeCPUInput, 6) \
|
|
.InputMemoryType(OrtMemTypeCPUInput, 7) \
|
|
.TypeConstraint("T1", DataTypeImpl::GetTensorType<TQuant>()) \
|
|
.TypeConstraint("T2", DataTypeImpl::GetTensorType<TQuant>()) \
|
|
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T>()) \
|
|
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()), \
|
|
QAttention<T, TQuant>);
|
|
|
|
REGISTER_KERNEL_TYPED(float, int8_t)
|
|
REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
|
|
|
|
template <typename T>
|
|
Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
|
|
const Tensor* weights,
|
|
const Tensor* bias,
|
|
const Tensor* input_scale_tensor,
|
|
const Tensor* weight_scale_tensor,
|
|
const Tensor*& mask_index,
|
|
const Tensor* i_zp_tensor,
|
|
const Tensor* w_zp_tensor,
|
|
const Tensor* past_tensor) const {
|
|
auto& device_prop = GetDeviceProp();
|
|
ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(),
|
|
mask_index, past_tensor, nullptr, device_prop.maxThreadsPerBlock));
|
|
|
|
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
|
|
"input scale must be a scalar or 1D tensor of size 1");
|
|
|
|
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(weight_scale_tensor),
|
|
"weight must be a scalar or 1D tensor of size 1");
|
|
|
|
if (i_zp_tensor != nullptr) {
|
|
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(i_zp_tensor),
|
|
"input zero point must be a scalar or 1D tensor of size 1.");
|
|
if (0 != *(i_zp_tensor->Data<int8_t>()))
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
|
|
}
|
|
|
|
if (w_zp_tensor != nullptr) {
|
|
// CUDA only support symmetric quantization for Attention
|
|
ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(w_zp_tensor),
|
|
"weight zero point must be a scalar or 1D tensor of size 1.");
|
|
if (0 != *(w_zp_tensor->Data<int8_t>()))
|
|
return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
|
|
}
|
|
|
|
return Status::OK();
|
|
}
|
|
|
|
template <typename T>
|
|
Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|
// Input and output shapes:
|
|
// Input 0 - input : (batch_size, sequence_length, input_hidden_size)
|
|
// Input 1 - weights : (input_hidden_size, 3 * hidden_size)
|
|
// Input 2 - bias : (3 * hidden_size)
|
|
// Input 3 - input_scale : scalar
|
|
// Input 4 - weight_scale : scalar
|
|
// Input 5 - mask_index : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1)
|
|
// or (batch_size, past_sequence_length + sequence_length)
|
|
// Input 6 - input_zero_point : scalar
|
|
// Input 7 - weight_zero_point : scalar
|
|
// Input 8 - past : (2, batch_size, num_heads, past_sequence_length, head_size)
|
|
// Output 0 - output : (batch_size, sequence_length, hidden_size)
|
|
// Output 1 - present : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
|
|
// ORT_RETURN_IF_ERROR(CheckInputs(context));
|
|
const Tensor* input = context->Input<Tensor>(0);
|
|
const Tensor* weights = context->Input<Tensor>(1);
|
|
const Tensor* bias = context->Input<Tensor>(2);
|
|
const Tensor* input_scale_tensor = context->Input<Tensor>(3);
|
|
const Tensor* weight_scale_tensor = context->Input<Tensor>(4);
|
|
const Tensor* mask_index = context->Input<Tensor>(5);
|
|
const Tensor* i_zp_tensor = context->Input<Tensor>(6);
|
|
const Tensor* w_zp_tensor = context->Input<Tensor>(7);
|
|
const Tensor* past_tensor = context->Input<Tensor>(8);
|
|
|
|
ORT_RETURN_IF_ERROR(CheckInputs(input,
|
|
weights,
|
|
bias,
|
|
input_scale_tensor,
|
|
weight_scale_tensor,
|
|
mask_index,
|
|
i_zp_tensor,
|
|
w_zp_tensor,
|
|
past_tensor));
|
|
|
|
const auto& shape = input->Shape();
|
|
int batch_size = SafeInt<int>(shape[0]);
|
|
int sequence_length = SafeInt<int>(shape[1]);
|
|
int input_hidden_size = SafeInt<int>(shape[2]);
|
|
|
|
const auto& bias_shape = bias->Shape();
|
|
const int hidden_size = SafeInt<int>(bias_shape.GetDims()[0]) / 3;
|
|
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in quantization
|
|
const int qkv_head_size[3] = {hidden_size / num_heads_, hidden_size / num_heads_, hidden_size / num_heads_};
|
|
|
|
TensorShapeVector output_shape(3);
|
|
output_shape[0] = shape[0];
|
|
output_shape[1] = shape[1];
|
|
output_shape[2] = SafeInt<int64_t>(hidden_size);
|
|
Tensor* output = context->Output(0, output_shape);
|
|
|
|
cublasHandle_t cublas = CublasHandle();
|
|
const size_t element_size = sizeof(T);
|
|
|
|
// Use GEMM for fully connection.
|
|
int m = batch_size * sequence_length;
|
|
int n = 3 * hidden_size;
|
|
int k = input_hidden_size;
|
|
auto gemm_buffer = GetScratchBuffer<T>(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size);
|
|
auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size);
|
|
|
|
typedef typename ToCudaType<T>::MappedType CudaT;
|
|
|
|
ORT_RETURN_IF_ERROR(GemmInt8(m, n, k,
|
|
1 /*alpha_matmul*/, 0 /* beta_matmul*/,
|
|
input->Data<int8_t>(), k,
|
|
weights->Data<int8_t>(), n,
|
|
gemm_buffer_quantized.get(), n,
|
|
this));
|
|
|
|
CudaT dequant_scale;
|
|
CudaT input_scale = *(reinterpret_cast<const CudaT*>(input_scale_tensor->Data<T>()));
|
|
CudaT weight_scale = *(reinterpret_cast<const CudaT*>(weight_scale_tensor->Data<T>()));
|
|
if (sizeof(T) == 2) {
|
|
dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale));
|
|
} else {
|
|
dequant_scale = input_scale * weight_scale;
|
|
}
|
|
|
|
// scale back and bias
|
|
// TODO(tianleiwu): fuse Dequantize with Add bias and Transpose.
|
|
ORT_RETURN_IF_ERROR(CudaDequantizeWithBias(Stream(),
|
|
gemm_buffer_quantized.get(),
|
|
reinterpret_cast<const CudaT*>(bias->Data<T>()),
|
|
reinterpret_cast<CudaT*>(gemm_buffer.get()),
|
|
dequant_scale,
|
|
m,
|
|
n));
|
|
|
|
int past_sequence_length = 0;
|
|
Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, qkv_head_size[1],
|
|
sequence_length, past_sequence_length);
|
|
|
|
void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up
|
|
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, qkv_head_size[0],
|
|
sequence_length, past_sequence_length, fused_runner, qkv_head_size[2]);
|
|
|
|
auto work_space = GetScratchBuffer<void>(workSpaceSize);
|
|
return LaunchAttentionKernel(
|
|
GetDeviceProp(),
|
|
Stream(),
|
|
cublas,
|
|
element_size,
|
|
batch_size,
|
|
sequence_length,
|
|
num_heads_,
|
|
qkv_head_size[0],
|
|
past_sequence_length,
|
|
is_unidirectional_,
|
|
reinterpret_cast<const void*>(gemm_buffer.get()),
|
|
nullptr, // bias has been added
|
|
nullptr == mask_index ? nullptr : mask_index->Data<int>(),
|
|
nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
|
|
nullptr == past_tensor ? nullptr : past_tensor->Data<T>(),
|
|
nullptr, // TODO: support add_qk in quantized attention
|
|
work_space.get(),
|
|
output->MutableData<T>(),
|
|
nullptr == present_tensor ? nullptr : present_tensor->MutableData<T>(),
|
|
fused_runner,
|
|
qkv_head_size[2]);
|
|
}
|
|
|
|
} // namespace cuda
|
|
} // namespace contrib
|
|
} // namespace onnxruntime
|