onnxruntime/onnxruntime/contrib_ops/cuda/quantization/attention_quantization.cc
petermcaughan febd5facce
Change head_size parameter dependent on qkv_hidden_size (#12933)
**Description**: Add qkv_hidden_size support in CUDA Attention Layer
implementation.

Changes include:

- Modify UT to test GPU and CPU implementation
- Add an overload of the CUDA kernel `AddBiasTransposeQKV` to support the
scenario where V_HIDDEN_SIZE != QK_HIDDEN_SIZE
- Update variable names from `head_size` to `qkv_head_sizes[0]` or
`qkv_head_sizes[2]`
- Modify function definitions to allow communication of
`qkv_hidden_sizes` or `qkv_head_sizes`

Note that this feature is not currently supported in the ROCm EP or in
quantized attention.

**Motivation and Context**
- Why is this change required? What problem does it solve? The current
CUDA implementation of attention layer doesn't support the parameter
qkv_hidden_size added in the CPU implementation in PR
[8039](https://github.com/microsoft/onnxruntime/pull/8039)
- If it fixes an open issue, please link to the issue here.

Co-authored-by: Peter Mcaughan <petermca@microsoft.com>
2022-10-11 00:25:47 -07:00

204 lines
9.7 KiB
C++

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "attention_quantization.h"
#include "attention_quantization_impl.cuh"
#include "contrib_ops/cuda/bert/attention_impl.h"
#include "core/common/safeint.h"
#include "core/providers/cuda/cuda_common.h"
#include "core/providers/cuda/shared_inc/fpgeneric.h"
#include "core/providers/cuda/shared_inc/integer_gemm.h"
#include "core/providers/cuda/tensor/quantize_linear.h"
using namespace onnxruntime::cuda;
using namespace onnxruntime::common;
namespace onnxruntime {
namespace contrib {
namespace cuda {
// Registers the QAttention operator (kMSDomain, version 1) with the CUDA execution provider
// for one (T, TQuant) type pair, binding it to the QAttention<T, TQuant> kernel class.
//
// - Inputs 3, 4, 6 and 7 are marked OrtMemTypeCPUInput. Per the input list documented in
//   ComputeInternal these are input_scale, weight_scale, input_zero_point and
//   weight_zero_point, which the kernel dereferences directly on the host.
// - Type constraints: T1 and T2 are bound to TQuant, T3 to T, and T4 to int32_t.
//
// NOTE: no comments are placed inside the macro body because a `//` comment on a
// backslash-continued line would swallow the remaining lines of the macro.
#define REGISTER_KERNEL_TYPED(T, TQuant) \
ONNX_OPERATOR_TYPED_KERNEL_EX( \
QAttention, \
kMSDomain, \
1, \
T##_##TQuant, \
kCudaExecutionProvider, \
(*KernelDefBuilder::Create()) \
.InputMemoryType(OrtMemTypeCPUInput, 3) \
.InputMemoryType(OrtMemTypeCPUInput, 4) \
.InputMemoryType(OrtMemTypeCPUInput, 6) \
.InputMemoryType(OrtMemTypeCPUInput, 7) \
.TypeConstraint("T1", DataTypeImpl::GetTensorType<TQuant>()) \
.TypeConstraint("T2", DataTypeImpl::GetTensorType<TQuant>()) \
.TypeConstraint("T3", DataTypeImpl::GetTensorType<T>()) \
.TypeConstraint("T4", DataTypeImpl::GetTensorType<int32_t>()), \
QAttention<T, TQuant>);
// Registered instantiations: float and half-precision (MLFloat16) compute types,
// both with int8_t quantized data.
REGISTER_KERNEL_TYPED(float, int8_t)
REGISTER_KERNEL_TYPED(MLFloat16, int8_t)
// Validates the inputs of the quantized attention kernel.
//
// Steps:
//   1. Delegates the input/weights/bias/mask/past validation to AttentionBase::CheckInputs,
//      passing the device's maxThreadsPerBlock.
//   2. Requires input_scale and weight_scale to be a scalar or a 1-element 1D tensor.
//   3. For each optional zero point that is provided, requires it to be a scalar or a
//      1-element 1D tensor whose value is 0 (only symmetric quantization is supported).
//
// Parameters:
//   input, weights, bias   - only their shapes are inspected here.
//   input_scale_tensor     - quantization scale of the input.
//   weight_scale_tensor    - quantization scale of the weights.
//   mask_index             - passed by reference to AttentionBase::CheckInputs
//                            (NOTE(review): it may be updated there; confirm against AttentionBase).
//   i_zp_tensor            - optional input zero point; may be nullptr.
//   w_zp_tensor            - optional weight zero point; may be nullptr.
//   past_tensor            - optional past state, forwarded to AttentionBase::CheckInputs.
//
// Returns Status::OK() when every check passes, otherwise a failure Status describing
// the first violated requirement.
//
// The zero points are dereferenced on the host; this relies on inputs 6 and 7 being
// registered as CPU inputs in the kernel definition.
template <typename T>
Status QAttention<T, int8_t>::CheckInputs(const Tensor* input,
                                          const Tensor* weights,
                                          const Tensor* bias,
                                          const Tensor* input_scale_tensor,
                                          const Tensor* weight_scale_tensor,
                                          const Tensor*& mask_index,
                                          const Tensor* i_zp_tensor,
                                          const Tensor* w_zp_tensor,
                                          const Tensor* past_tensor) const {
  auto& device_prop = GetDeviceProp();
  ORT_RETURN_IF_ERROR(AttentionBase::CheckInputs(input->Shape(), weights->Shape(), bias->Shape(),
                                                 mask_index, past_tensor, nullptr, device_prop.maxThreadsPerBlock));

  ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(input_scale_tensor),
                    "input scale must be a scalar or 1D tensor of size 1");
  // The message names the weight *scale*: the tensor checked here is weight_scale_tensor,
  // not the weights themselves.
  ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(weight_scale_tensor),
                    "weight scale must be a scalar or 1D tensor of size 1");

  // CUDA only supports symmetric quantization for Attention, so any provided zero point must be 0.
  if (i_zp_tensor != nullptr) {
    ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(i_zp_tensor),
                      "input zero point must be a scalar or 1D tensor of size 1.");
    if (0 != *(i_zp_tensor->Data<int8_t>()))
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
  }

  if (w_zp_tensor != nullptr) {
    ORT_RETURN_IF_NOT(IsScalarOr1ElementVector(w_zp_tensor),
                      "weight zero point must be a scalar or 1D tensor of size 1.");
    if (0 != *(w_zp_tensor->Data<int8_t>()))
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CUDA only support symmetric quantization for Attention");
  }

  return Status::OK();
}
// Computes quantized (int8) attention on CUDA.
//
// Pipeline:
//   1. Validate inputs (CheckInputs).
//   2. int8 GEMM of input (m x k) with weights (k x n) into an int32 scratch buffer (GemmInt8),
//      where m = batch_size * sequence_length, n = 3 * hidden_size, k = input_hidden_size.
//   3. Dequantize with scale = input_scale * weight_scale and add the bias, producing an
//      m x n buffer of T (CudaDequantizeWithBias).
//   4. Run the attention kernel on that buffer (LaunchAttentionKernel), writing output 0 and,
//      when GetPresent returns a tensor, the present state.
//
// Returns the Status of the first failing step, or the Status of LaunchAttentionKernel.
template <typename T>
Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
  // Input and output shapes:
  //   Input 0 - input             : (batch_size, sequence_length, input_hidden_size)
  //   Input 1 - weights           : (input_hidden_size, 3 * hidden_size)
  //   Input 2 - bias              : (3 * hidden_size)
  //   Input 3 - input_scale       : scalar
  //   Input 4 - weight_scale      : scalar
  //   Input 5 - mask_index        : nullptr, (batch_size), (2 * batch_size), (batch_size, 1), (1, 1)
  //                                 or (batch_size, past_sequence_length + sequence_length)
  //   Input 6 - input_zero_point  : scalar
  //   Input 7 - weight_zero_point : scalar
  //   Input 8 - past              : (2, batch_size, num_heads, past_sequence_length, head_size)
  //   Output 0 - output           : (batch_size, sequence_length, hidden_size)
  //   Output 1 - present          : (2, batch_size, num_heads, past_sequence_length + sequence_length, head_size)
  const Tensor* input = context->Input<Tensor>(0);
  const Tensor* weights = context->Input<Tensor>(1);
  const Tensor* bias = context->Input<Tensor>(2);
  const Tensor* input_scale_tensor = context->Input<Tensor>(3);
  const Tensor* weight_scale_tensor = context->Input<Tensor>(4);
  const Tensor* mask_index = context->Input<Tensor>(5);
  const Tensor* i_zp_tensor = context->Input<Tensor>(6);
  const Tensor* w_zp_tensor = context->Input<Tensor>(7);
  const Tensor* past_tensor = context->Input<Tensor>(8);

  ORT_RETURN_IF_ERROR(CheckInputs(input,
                                  weights,
                                  bias,
                                  input_scale_tensor,
                                  weight_scale_tensor,
                                  mask_index,
                                  i_zp_tensor,
                                  w_zp_tensor,
                                  past_tensor));

  const auto& shape = input->Shape();
  int batch_size = SafeInt<int>(shape[0]);
  int sequence_length = SafeInt<int>(shape[1]);
  int input_hidden_size = SafeInt<int>(shape[2]);

  // The bias holds the Q, K and V biases back to back, so one third of it is hidden_size.
  const auto& bias_shape = bias->Shape();
  const int hidden_size = SafeInt<int>(bias_shape.GetDims()[0]) / 3;

  // Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in quantization,
  // so Q, K and V all share the same head size.
  const int qkv_head_size[3] = {hidden_size / num_heads_, hidden_size / num_heads_, hidden_size / num_heads_};

  TensorShapeVector output_shape(3);
  output_shape[0] = shape[0];
  output_shape[1] = shape[1];
  output_shape[2] = SafeInt<int64_t>(hidden_size);
  Tensor* output = context->Output(0, output_shape);

  cublasHandle_t cublas = CublasHandle();
  const size_t element_size = sizeof(T);

  // Use GEMM for the fully connected (QKV projection) layer.
  // SafeInt guards the int product against overflow instead of silently wrapping.
  int m = SafeInt<int>(batch_size) * sequence_length;
  int n = 3 * hidden_size;
  int k = input_hidden_size;

  // Both scratch buffers hold m x n elements. The typed GetScratchBuffer<T> overload takes an
  // element count (as the int32_t buffer already did), so the count must NOT be multiplied by
  // element_size; doing so over-allocated gemm_buffer by a factor of sizeof(T).
  const size_t gemm_element_count = SafeInt<size_t>(batch_size) * sequence_length * 3 * hidden_size;
  auto gemm_buffer = GetScratchBuffer<T>(gemm_element_count);
  auto gemm_buffer_quantized = GetScratchBuffer<int32_t>(gemm_element_count);

  typedef typename ToCudaType<T>::MappedType CudaT;

  ORT_RETURN_IF_ERROR(GemmInt8(m, n, k,
                               1 /*alpha_matmul*/, 0 /* beta_matmul*/,
                               input->Data<int8_t>(), k,
                               weights->Data<int8_t>(), n,
                               gemm_buffer_quantized.get(), n,
                               this));

  // The scales are read on the host; inputs 3 and 4 are registered as CPU inputs.
  CudaT dequant_scale;
  CudaT input_scale = *(reinterpret_cast<const CudaT*>(input_scale_tensor->Data<T>()));
  CudaT weight_scale = *(reinterpret_cast<const CudaT*>(weight_scale_tensor->Data<T>()));
  // sizeof(T) == 2 selects the half-precision path, where the product is computed in float
  // and converted back to half.
  if (sizeof(T) == 2) {
    dequant_scale = __float2half(__half2float(input_scale) * __half2float(weight_scale));
  } else {
    dequant_scale = input_scale * weight_scale;
  }

  // scale back and bias
  // TODO(tianleiwu): fuse Dequantize with Add bias and Transpose.
  ORT_RETURN_IF_ERROR(CudaDequantizeWithBias(Stream(),
                                             gemm_buffer_quantized.get(),
                                             reinterpret_cast<const CudaT*>(bias->Data<T>()),
                                             reinterpret_cast<CudaT*>(gemm_buffer.get()),
                                             dequant_scale,
                                             m,
                                             n));

  // GetPresent receives past_sequence_length by name and returns the present tensor
  // (NOTE(review): presumably it sets past_sequence_length from past_tensor and may return
  // nullptr when output 1 is not requested; confirm against AttentionBase::GetPresent).
  int past_sequence_length = 0;
  Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, qkv_head_size[1],
                                      sequence_length, past_sequence_length);

  void* fused_runner = nullptr;  // TODO(tianleiwu): use fused kernel to speed up

  // The workspace is sized in bytes, hence the untyped (void) scratch buffer.
  size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, qkv_head_size[0],
                                                   sequence_length, past_sequence_length, fused_runner,
                                                   qkv_head_size[2]);
  auto work_space = GetScratchBuffer<void>(workSpaceSize);

  return LaunchAttentionKernel(
      GetDeviceProp(),
      Stream(),
      cublas,
      element_size,
      batch_size,
      sequence_length,
      num_heads_,
      qkv_head_size[0],
      past_sequence_length,
      is_unidirectional_,
      reinterpret_cast<const void*>(gemm_buffer.get()),
      nullptr,  // bias has been added
      nullptr == mask_index ? nullptr : mask_index->Data<int>(),
      nullptr == mask_index ? gsl::span<const int64_t>() : mask_index->Shape().GetDims(),
      nullptr == past_tensor ? nullptr : past_tensor->Data<T>(),
      nullptr,  // TODO: support add_qk in quantized attention
      work_space.get(),
      output->MutableData<T>(),
      nullptr == present_tensor ? nullptr : present_tensor->MutableData<T>(),
      fused_runner,
      qkv_head_size[2]);
}
} // namespace cuda
} // namespace contrib
} // namespace onnxruntime