Change head_size parameter dependent on qkv_hidden_size (#12933)

**Description**: Add qkv_hidden_size support in CUDA Attention Layer
implementation.

Changes include:

- Modify the unit tests to cover both the GPU and CPU implementations
- Add overload for CUDA kernel `AddBiasTransposeQKV` to support scenario
where V_HIDDEN_SIZE != QK_HIDDEN_SIZE
- Update variable names from `head_size` to `qkv_head_sizes[0]` or
`qkv_head_sizes[2]`
- Modify function definitions to allow communication of
`qkv_hidden_sizes` or `qkv_head_sizes`

Note that this feature is not yet supported in the ROCm EP or in quantized
attention.

**Motivation and Context**
- Why is this change required? What problem does it solve? The current
CUDA implementation of the attention layer doesn't support the
`qkv_hidden_sizes` parameter that was added to the CPU implementation in PR
[8039](https://github.com/microsoft/onnxruntime/pull/8039)
- If it fixes an open issue, please link to the issue here.

Co-authored-by: Peter Mcaughan <petermca@microsoft.com>
This commit is contained in:
petermcaughan 2022-10-11 00:25:47 -07:00 committed by GitHub
parent b9e23bd086
commit febd5facce
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
9 changed files with 194 additions and 98 deletions

View file

@ -118,6 +118,52 @@ __global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output)
}
}
// Adds the bias to packed QKV (format 1) and transposes it, for the case where the V head size
// differs from the Q/K head size.
//   Input : B x S x (N*H_qk + N*H_qk + N*H_v)  -- Q, K and V slices packed per token
//   Output: Q (B x N x S x H_qk), K (B x N x S x H_qk), V (B x N x S x H_v), stored back to back
//   B is batch_size, S is sequence_length, N is num_heads, H_qk / H_v are the per-head sizes.
// Launch layout (read from the built-in variables below):
//   grid  = (sequence_length, batch_size, num_matrices)
//   block = (qk_head_size, num_heads, 1)
// Parameters:
//   input       packed QKV activations
//   biases      packed Q, K, V biases (N*H_qk, N*H_qk, N*H_v elements)
//   output      destination buffer for the transposed Q, K, V
//   v_head_size per-head size of V, in units of T
template <typename T>
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
  const int n = threadIdx.y;  // head_num_id
  const int s = blockIdx.x;   // sequence_id
  const int b = blockIdx.y;   // batch_id
  const int m = blockIdx.z;   // matrix id (Q=0, K=1, V=2)

  const int qk_head_size = blockDim.x;
  const int num_heads = blockDim.y;
  const int sequence_length = gridDim.x;
  const int batch_size = gridDim.y;

  // Only V (m == 2) has its own head size. A select avoids dynamically indexing a local array
  // and cannot read out of range for any other matrix id.
  const int head_size = (m == 2) ? v_head_size : qk_head_size;
  const int total_head_size = num_heads * (2 * qk_head_size + v_head_size);

  // Q and K both precede V and both use qk_head_size, so the per-matrix (M) term is always
  // expressed in qk_head_size units.
  const int in_base = b * (total_head_size * sequence_length) +  // B
                      s * total_head_size +                      // S
                      m * (qk_head_size * num_heads) +           // M
                      n * head_size;                             // N

  const int out_base = m * (num_heads * qk_head_size * sequence_length * batch_size) +  // M
                       b * (num_heads * head_size * sequence_length) +                  // B
                       n * (sequence_length * head_size) +                              // N
                       s * head_size;                                                   // S

  const int bias_base = m * (num_heads * qk_head_size) +  // QKV
                        n * head_size;                    // N

  // The block only provides qk_head_size threads along x. Striding by blockDim.x keeps the
  // one-element-per-thread behaviour when head_size <= qk_head_size and still covers every
  // element of V when v_head_size > qk_head_size.
  for (int h = threadIdx.x; h < head_size; h += blockDim.x) {
    output[out_base + h] = input[in_base + h] + biases[bias_base + h];
  }
}
template <typename T>
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
int n = threadIdx.y;
@ -203,26 +249,30 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const
// Host-side dispatcher: selects the add-bias/transpose kernel for `format` and launches it on
// `stream` with grid = (sequence_length, batch_size, num_matrices).
// NOTE(review): this hunk is rendered with the removed and the added diff lines interleaved, so the
// parameter list and every changed statement appear twice (old `head_size` form first, then the
// new `qk_head_size` / `v_head_size` form).
template <typename T>
void InvokeAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* input, const T* biases, T* output) {
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, const int v_head_size) {
const dim3 grid(sequence_length, batch_size, num_matrices);
// Small case: a block of (head size x num_heads) threads fits within max_threads_per_block.
if (head_size * num_heads <= max_threads_per_block) {
const dim3 block(head_size, num_heads, 1);
if (qk_head_size * num_heads <= max_threads_per_block) {
const dim3 block(qk_head_size, num_heads, 1);
if (format == 2) {
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(input, biases, output);
} else if (format == 1) {
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
// v_head_size == -1 or v_head_size == qk_head_size keeps the three-argument kernel;
// any other value selects the overload that takes v_head_size.
if ((v_head_size == -1) || (qk_head_size == v_head_size)) {
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
} else {
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
}
} else {
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
}
} else {
// NOTE(review): the *Large kernels below receive only qk_head_size; v_head_size is not forwarded
// on this path -- confirm callers never reach it with v_head_size != qk_head_size.
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
if (format == 2) {
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
} else if (format == 1) {
AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
} else {
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
}
}
}
@ -230,53 +280,55 @@ void InvokeAddBiasTranspose(
// half specialization of the add-bias/transpose launcher.
// Reinterprets the buffers as Half4 or half2 when BOTH head sizes are divisible by the vector
// width, and passes the head sizes on in units of the vector type; otherwise it falls back to
// scalar half.
//   qk_head_size  per-head size of Q and K, in half elements
//   v_head_size   per-head size of V, in half elements
//   enable_half4  allows the Half4 path when both head sizes are multiples of 4
template <>
void LaunchAddBiasTranspose(
    cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const half* input, const half* biases, half* output,
    bool enable_half4, const int v_head_size) {
  if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
    const int H_q = qk_head_size / 4;
    const int H_v = v_head_size / 4;
    const Half4* input2 = reinterpret_cast<const Half4*>(input);
    const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
    Half4* output2 = reinterpret_cast<Half4*>(output);
    InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
    // The parity test must be `& 1`: `% 1` is always 0, which let an odd v_head_size through
    // and then truncated it in the division below.
    const int H_q = qk_head_size / 2;
    const int H_v = v_head_size / 2;
    const half2* input2 = reinterpret_cast<const half2*>(input);
    const half2* biases2 = reinterpret_cast<const half2*>(biases);
    half2* output2 = reinterpret_cast<half2*>(output);
    InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
  } else {
    InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
                                 batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
  }
}
// float specialization of the add-bias/transpose launcher.
// Reinterprets the buffers as float4 or float2 when BOTH head sizes are divisible by the vector
// width, and passes the head sizes on in units of the vector type; otherwise it falls back to
// scalar float. The enable_half4 flag is ignored for float.
//   qk_head_size  per-head size of Q and K, in float elements
//   v_head_size   per-head size of V, in float elements
template <>
void LaunchAddBiasTranspose(
    cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
    const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
    const float* input, const float* biases, float* output,
    bool /*enable_half4*/, const int v_head_size) {
  // v_head_size must be checked as well: dividing a value that is not a multiple of the vector
  // width would truncate and describe the wrong V layout to the kernel.
  if (0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
    const int H_q = qk_head_size / 4;
    const int H_v = v_head_size / 4;
    const float4* input2 = reinterpret_cast<const float4*>(input);
    const float4* biases2 = reinterpret_cast<const float4*>(biases);
    float4* output2 = reinterpret_cast<float4*>(output);
    InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
                                   batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
  } else if (0 == (qk_head_size & 1) && 0 == (v_head_size & 1)) {
    const int H_q = qk_head_size / 2;
    const int H_v = v_head_size / 2;
    const float2* input2 = reinterpret_cast<const float2*>(input);
    const float2* biases2 = reinterpret_cast<const float2*>(biases);
    float2* output2 = reinterpret_cast<float2*>(output);
    InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
                                   batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
  } else {
    InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
                                  batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
  }
}

View file

@ -24,8 +24,8 @@ namespace cuda {
// Declaration of the add-bias/transpose launcher.
// qk_head_size is the per-head size used for Q and K; v_head_size is the per-head size used for V.
// NOTE(review): this hunk shows the removed and the added parameter lines interleaved.
template <typename T>
void LaunchAddBiasTranspose(
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
const T* input, const T* biases, T* output, bool enable_half4);
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);
} // namespace cuda
} // namespace contrib

View file

@ -82,18 +82,31 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
int q_hidden_size;
int k_hidden_size;
int v_hidden_size;
int head_size = hidden_size / num_heads_;
if (qkv_hidden_sizes_.size() == 0) {
q_hidden_size = static_cast<int>(bias_shape[0]) / 3;
k_hidden_size = static_cast<int>(bias_shape[0]) / 3;
v_hidden_size = static_cast<int>(bias_shape[0]) / 3;
} else {
q_hidden_size = static_cast<int>(qkv_hidden_sizes_[0]);
k_hidden_size = static_cast<int>(qkv_hidden_sizes_[1]);
v_hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
}
const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
output_shape[1] = shape[1];
output_shape[2] = static_cast<int64_t>(hidden_size);
output_shape[2] = static_cast<int64_t>(v_hidden_size);
Tensor* output = context->Output(0, output_shape);
int past_sequence_length = 0;
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
Tensor* present = GetPresent(context, past, batch_size, qkv_head_size[1], sequence_length, past_sequence_length);
// Check whether we can use fused kernel
int sm = device_prop.major * 10 + device_prop.minor;
@ -103,12 +116,14 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
nullptr == present &&
nullptr == extra_add_qk &&
!is_unidirectional_ &&
HasFusedFp16Kernel(sm, head_size, sequence_length));
qkv_head_size[0] == qkv_head_size[1] &&
qkv_head_size[1] == qkv_head_size[2] &&
HasFusedFp16Kernel(sm, qkv_head_size[0], sequence_length));
MHARunner* fused_runner = nullptr;
if (use_fused_runner) {
if (nullptr == fused_fp16_runner_.get()) {
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, head_size, sm));
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, qkv_head_size[0], sm));
}
// In case some kernel not loaded due to shared memory limit, we need to double check here.
if (fused_fp16_runner_->isValid(sequence_length)) {
@ -121,9 +136,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
// Use GEMM for fully connection.
int m = batch_size * sequence_length;
int n = 3 * hidden_size;
int n = (q_hidden_size + k_hidden_size + v_hidden_size);
int k = input_hidden_size;
size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size;
size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * n * element_size;
auto gemm_buffer = GetScratchBuffer<T>(gemm_buffer_size);
typedef typename ToCudaType<T>::MappedType CudaT;
@ -140,10 +155,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
batch_size,
num_heads_,
head_size,
qkv_head_size[0],
sequence_length,
past_sequence_length,
fused_runner);
fused_runner,
qkv_head_size[2]);
auto work_space = GetScratchBuffer<void>(workSpaceSize);
ORT_RETURN_IF_ERROR(LaunchAttentionKernel(
@ -154,7 +170,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
batch_size,
sequence_length,
num_heads_,
head_size,
qkv_head_size[0],
past_sequence_length,
is_unidirectional_,
reinterpret_cast<const void*>(gemm_buffer.get()),
@ -166,7 +182,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
work_space.get(),
output->MutableData<T>(),
nullptr == present ? nullptr : present->MutableData<T>(),
fused_runner));
fused_runner,
qkv_head_size[2]));
return Status::OK();
}

View file

@ -64,11 +64,13 @@ size_t GetAttentionWorkspaceSize(
size_t element_size,
size_t batch_size,
size_t num_heads,
size_t head_size,
size_t qk_head_size,
size_t sequence_length,
size_t past_sequence_length,
void* fused_runner) {
size_t q_size = element_size * batch_size * sequence_length * num_heads * head_size;
void* fused_runner,
size_t v_head_size) {
size_t q_size = element_size * batch_size * sequence_length * num_heads * qk_head_size;
size_t v_size = element_size * batch_size * sequence_length * num_heads * v_head_size;
if (fused_runner != nullptr) {
// Offsets without padding is B + 1. When we add padding, the size need to increase to 2B + 1.
@ -76,7 +78,7 @@ size_t GetAttentionWorkspaceSize(
return 4 * q_size + reinterpret_cast<MHARunner*>(fused_runner)->getWorkspaceSize() + sequenceOffsetBytes;
}
return 3 * q_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
return (2 * q_size + v_size) + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
past_sequence_length + sequence_length);
}
@ -88,7 +90,7 @@ Status QkvToContext(
const int batch_size,
const int sequence_length,
const int num_heads,
const int head_size,
const int qk_head_size,
const size_t element_size,
const T* input,
const T* bias,
@ -102,34 +104,40 @@ Status QkvToContext(
const T* extra_add_qk,
T* present,
bool use_persistent_softmax,
MHARunner* fused_runner) {
MHARunner* fused_runner,
const int v_head_size) {
const int max_threads_per_block = prop.maxThreadsPerBlock;
// input should be BxSx3xNxH => qkv: 3xBxNxSxH
T* qkv = workspace;
if (bias == nullptr) {
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, false, input, qkv));
} else {
// For fused TRT attention, qkv need transpose to BxSxNx3xH
const int format = (nullptr == fused_runner ? 1 : 2);
const bool enable_half4 = true;
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
sequence_length, num_heads, qk_head_size,
input, bias, qkv,
enable_half4);
enable_half4, v_head_size);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
}
// Q, K, V has size BxNxSxH
const int batches = batch_size * num_heads;
const int size_per_batch = sequence_length * head_size;
const int total_size = batches * size_per_batch;
const int size_per_batch_qk = sequence_length * qk_head_size;
const int size_per_batch_v = sequence_length * v_head_size;
const int total_size_qk = batches * size_per_batch_qk;
T* scratch1;
scratch1 = qkv + (batches * sequence_length * (qk_head_size + qk_head_size + v_head_size));
T* scratch1 = qkv + 3 * total_size;
T* temp_output = scratch1;
if (nullptr != fused_runner && bias != nullptr) {
int* sequence_offset = reinterpret_cast<int*>(qkv + 4 * total_size);
int* sequence_offset = reinterpret_cast<int*>(qkv + 4 * total_size_qk);
LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, stream);
CUDA_RETURN_IF_ERROR(cudaGetLastError());
@ -147,44 +155,45 @@ Status QkvToContext(
T* scratch2 = scratch1 + (bytes / element_size);
const T* q = qkv;
const T* k = q + total_size;
const T* v = k + total_size;
const T* k = q + (batches * sequence_length * qk_head_size);
const T* v = k + (batches * sequence_length * qk_head_size);
cublasSetStream(cublas, stream);
// Concat past (2xBxNxS'xH) to present (2xBxNxS*xH):
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH)
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH)
const int present_size_per_batch = all_sequence_length * head_size;
const int present_size_per_batch_k = all_sequence_length * qk_head_size;
const int present_size_per_batch_v = all_sequence_length * v_head_size;
if (nullptr != present) {
ORT_RETURN_IF_ERROR(
LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads,
LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
max_threads_per_block, past, k, present));
// update pointers to present_k and present_v.
k = present;
v = present + batches * present_size_per_batch;
v = present + batches * present_size_per_batch_k;
}
// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H_qk) and store in scratch1: BxNxSxS*
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(qk_head_size));
const int temp_matrix_size = sequence_length * all_sequence_length;
float one = 1.0f;
float zero = 0.f;
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
// For raw attention mask, the scalar 1/sqrt(H_qk) is moved to combine with softmax computation.
float alpha = use_raw_attention_mask ? one : rsqrt_head_size;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
all_sequence_length, sequence_length, head_size,
&alpha, k, head_size, present_size_per_batch,
q, head_size, size_per_batch,
&zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop));
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
all_sequence_length, sequence_length, qk_head_size,
&alpha, k, qk_head_size, present_size_per_batch_k,
q, qk_head_size, size_per_batch_qk,
&zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop));
// apply softmax and store result P to scratch2: BxNxSxS*
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
@ -214,13 +223,13 @@ Status QkvToContext(
temp_output = qkv;
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
head_size, sequence_length, all_sequence_length,
&one, v, head_size, present_size_per_batch,
v_head_size, sequence_length, all_sequence_length,
&one, v, v_head_size, present_size_per_batch_v,
scratch2, all_sequence_length, temp_matrix_size,
&zero, temp_output, head_size, size_per_batch, batches, prop));
&zero, temp_output, v_head_size, size_per_batch_v, batches, prop));
// temp_output is BxNxSxH, transpose to output BxSxNxH
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
return LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
max_threads_per_block, false, temp_output, output);
}
@ -232,7 +241,7 @@ Status LaunchAttentionKernel(
int batch_size,
int sequence_length,
int num_heads,
int head_size,
const int qk_head_size,
int past_sequence_length,
bool is_unidirectional,
const void* input,
@ -244,13 +253,14 @@ Status LaunchAttentionKernel(
void* workspace,
void* output,
void* present,
void* fused_runner) {
void* fused_runner,
const int v_head_size) {
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax used in Torch.
const TransformerOptions* options = TransformerOptions::GetInstance();
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
if (element_size == 2) {
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, qk_head_size, element_size,
reinterpret_cast<const half*>(input),
reinterpret_cast<const half*>(bias),
reinterpret_cast<half*>(output),
@ -263,9 +273,10 @@ Status LaunchAttentionKernel(
reinterpret_cast<const half*>(extra_add_qk),
reinterpret_cast<half*>(present),
use_persistent_softmax,
reinterpret_cast<MHARunner*>(fused_runner));
reinterpret_cast<MHARunner*>(fused_runner),
v_head_size);
} else {
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, qk_head_size, element_size,
reinterpret_cast<const float*>(input),
reinterpret_cast<const float*>(bias),
reinterpret_cast<float*>(output),
@ -278,7 +289,8 @@ Status LaunchAttentionKernel(
reinterpret_cast<const float*>(extra_add_qk),
reinterpret_cast<float*>(present),
use_persistent_softmax,
nullptr);
nullptr,
v_head_size);
}
}

View file

@ -21,10 +21,11 @@ size_t GetAttentionWorkspaceSize(
size_t element_size,
size_t batchsize,
size_t num_heads,
size_t head_size,
size_t qk_head_size,
size_t sequence_length,
size_t past_sequence_length,
void* fused_runner);
void* fused_runner,
size_t v_head_size);
Status LaunchAttentionKernel(
const cudaDeviceProp& prop, // Device Properties
@ -34,7 +35,7 @@ Status LaunchAttentionKernel(
int batch_size, // Batch size (B)
int sequence_length, // Sequence length (S)
int num_heads, // Number of attention heads (N)
int head_size, // Hidden layer size per head (H)
const int qk_head_size, // Hidden layer size per head for q and k (H_qk)
int past_sequence_length, // Sequence length in past state
bool is_unidirectional, // Whether there is unidirecitonal mask.
const void* input, // Input tensor
@ -46,8 +47,8 @@ Status LaunchAttentionKernel(
void* workspace, // Temporary buffer
void* output, // Output tensor
void* present, // Present state output
void* fused_runner // Fused multi-head attention
);
void* fused_runner, // Fused multi-head attention
const int v_head_size); // Hidden layer size per head for v (H_v)
Status LaunchDecoderAttentionKernel(
const cudaDeviceProp& prop, // Device Properties

View file

@ -868,6 +868,7 @@ Status LongformerQkvToContext(
// The order of qkv space:
// Q, K, V, Global_K, Global_V, Global_Q (format 0)
// Q, K, V, Global_Q, Global_K, Global_V (format 1)
// Assume H_q == H_k == H_v
if (format == 1 || max_num_global == 0 || nullptr == global_input) {
if (bias == nullptr) {
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
@ -876,7 +877,7 @@ Status LongformerQkvToContext(
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
input, bias, qkv,
use_half4);
use_half4, head_size);
}
if (max_num_global > 0 && nullptr != global_input) {
@ -887,20 +888,20 @@ Status LongformerQkvToContext(
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
global_input, global_bias, qkv + 3 * elements,
use_half4);
use_half4, head_size);
}
}
} else {
LaunchAddBiasTranspose(stream, 5, format, max_threads_per_block, batch_size,
sequence_length, num_heads, head_size,
input, bias, qkv,
use_half4);
use_half4, head_size);
compact_global_q = (disable_compact_memory == false);
LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, batch_size,
compact_global_q ? max_num_global : sequence_length, num_heads, head_size,
global_input + 2 * elements, global_bias, qkv + 5 * elements,
use_half4);
use_half4, head_size);
}
CUDA_RETURN_IF_ERROR(cudaGetLastError());

View file

@ -119,7 +119,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
const auto& bias_shape = bias->Shape();
const int hidden_size = SafeInt<int>(bias_shape.GetDims()[0]) / 3;
const int head_size = hidden_size / num_heads_;
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in quantization
const int qkv_head_size[3] = {hidden_size / num_heads_, hidden_size / num_heads_, hidden_size / num_heads_};
TensorShapeVector output_shape(3);
output_shape[0] = shape[0];
@ -166,12 +167,12 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
n));
int past_sequence_length = 0;
Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, head_size,
Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, qkv_head_size[1],
sequence_length, past_sequence_length);
void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size,
sequence_length, past_sequence_length, fused_runner);
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, qkv_head_size[0],
sequence_length, past_sequence_length, fused_runner, qkv_head_size[2]);
auto work_space = GetScratchBuffer<void>(workSpaceSize);
return LaunchAttentionKernel(
@ -182,7 +183,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
batch_size,
sequence_length,
num_heads_,
head_size,
qkv_head_size[0],
past_sequence_length,
is_unidirectional_,
reinterpret_cast<const void*>(gemm_buffer.get()),
@ -194,7 +195,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
work_space.get(),
output->MutableData<T>(),
nullptr == present_tensor ? nullptr : present_tensor->MutableData<T>(),
fused_runner);
fused_runner,
qkv_head_size[2]);
}
} // namespace cuda

View file

@ -50,6 +50,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
int sequence_length = static_cast<int>(shape[1]);
int input_hidden_size = static_cast<int>(shape[2]);
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP
// bias shape (3 * hidden_size)
const auto& bias_shape = bias->Shape();
int hidden_size = static_cast<int>(bias_shape[0]) / 3;

View file

@ -41,12 +41,13 @@ static void RunAttentionTest(
bool only_enable_cuda = false,
bool only_enable_cpu = false,
std::vector<int32_t> qkv_sizes = {},
const std::vector<float>& extra_add_data = {}) {
const std::vector<float>& extra_add_data = {},
const bool disable_rocm = false) {
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
int min_cuda_architecture = use_float16 ? 530 : 0;
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu;
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu;
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu && !disable_rocm;
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
int head_size = hidden_size / number_of_heads;
@ -188,17 +189,18 @@ static void RunAttentionTest(
bool only_enable_cuda = false,
bool only_enable_cpu = false,
const std::vector<int32_t> qkv_sizes = {},
const std::vector<float>& extra_add_data = {}) {
const std::vector<float>& extra_add_data = {},
const bool disable_rocm = false) {
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data, disable_rocm);
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
use_float16, is_unidirectional, use_past_state, past_sequence_length,
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data, disable_rocm);
}
TEST(AttentionTest, AttentionBatch1) {
@ -267,7 +269,11 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr1) {
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
0, false, true, qkv_sizes);
0, true, false, qkv_sizes, {}, true);
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
0, false, true, qkv_sizes, {}, true);
}
TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
@ -304,7 +310,11 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
0, false, true, qkv_sizes);
0, true, false, qkv_sizes, {}, true);
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
batch_size, sequence_length, hidden_size, number_of_heads,
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
0, false, true, qkv_sizes, {}, true);
}
TEST(AttentionTest, AttentionBatch1ExtraAdd) {