mirror of
https://github.com/saymrwulf/onnxruntime.git
synced 2026-06-23 02:38:28 +00:00
Change head_size parameter dependent on qkv_hidden_size (#12933)
**Description**: Add qkv_hidden_size support in CUDA Attention Layer implementation. Changes include: - Modify UT to test GPU and CPU implementation - Add overload for CUDA kernel `AddBiasTransposeQKV` to support scenario where V_HIDDEN_SIZE != QK_HIDDEN_SIZE - Update variable names from `head_size` to `qkv_head_sizes[0]` or `qkv_head_sizes[2]` - Modify function definitions to allow communication of `qkv_hidden_sizes` or `qkv_head_sizes` Note that this feature is not supported in Rocm EP or quantized attention right now. **Motivation and Context** - Why is this change required? What problem does it solve? The current CUDA implementation of attention layer doesn't support the parameter qkv_hidden_size added in the CPU implementation in PR [8039](https://github.com/microsoft/onnxruntime/pull/8039) - If it fixes an open issue, please link to the issue here. Co-authored-by: Peter Mcaughan <petermca@microsoft.com>
This commit is contained in:
parent
b9e23bd086
commit
febd5facce
9 changed files with 194 additions and 98 deletions
|
|
@ -118,6 +118,52 @@ __global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output)
|
|||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void AddBiasTransposeQKV(const T* input, const T* biases, T* output, int v_head_size) {
|
||||
// Input: BxSxMxNxH (Format 1)
|
||||
// Output: MxBxNxSxH
|
||||
// B is batch_size, S is sequence_length, M is number of matrices, N is num_heads, H is head_size
|
||||
|
||||
int n = threadIdx.y; // head_num_id
|
||||
int s = blockIdx.x; // sequence_id
|
||||
int b = blockIdx.y; // batch_id
|
||||
int m = blockIdx.z; // matrix id (Q=0, K=1, V=2)
|
||||
const int h = threadIdx.x; // head_element_id
|
||||
|
||||
const int qk_head_size = blockDim.x;
|
||||
const int num_heads = blockDim.y;
|
||||
|
||||
const int sequence_length = gridDim.x;
|
||||
const int batch_size = gridDim.y;
|
||||
|
||||
const int qkv_head_sizes[3] = {qk_head_size, qk_head_size, v_head_size};
|
||||
|
||||
const int total_head_size = num_heads * (qkv_head_sizes[0] + qkv_head_sizes[1] + qkv_head_sizes[2]);
|
||||
|
||||
int in_offset;
|
||||
int out_offset;
|
||||
int bias_offset;
|
||||
in_offset = b * (total_head_size * sequence_length) + // B
|
||||
s * (total_head_size) + // S
|
||||
m * (qk_head_size * num_heads) + // M
|
||||
n * qkv_head_sizes[m] + // N
|
||||
h; // H
|
||||
|
||||
out_offset = m * (num_heads * qk_head_size * sequence_length * batch_size) + // M
|
||||
b * (num_heads * qkv_head_sizes[m] * sequence_length) + // B
|
||||
n * (sequence_length * qkv_head_sizes[m]) + // N
|
||||
s * (qkv_head_sizes[m]) + // S
|
||||
h; // H
|
||||
|
||||
bias_offset = m * (num_heads * qk_head_size)+ // QKV
|
||||
n * (qkv_head_sizes[m]) + // N
|
||||
h; // H
|
||||
|
||||
if (h < qkv_head_sizes[m]) {
|
||||
output[out_offset] = input[in_offset] + biases[bias_offset];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
__global__ void AddBiasTransposeQKVLarge(const int head_size, const T* input, const T* biases, T* output) {
|
||||
int n = threadIdx.y;
|
||||
|
|
@ -203,26 +249,30 @@ __global__ void AddBiasTransposeLarge(const int head_size, const T* input, const
|
|||
template <typename T>
|
||||
void InvokeAddBiasTranspose(
|
||||
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
|
||||
const T* input, const T* biases, T* output) {
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
|
||||
const T* input, const T* biases, T* output, const int v_head_size) {
|
||||
const dim3 grid(sequence_length, batch_size, num_matrices);
|
||||
if (head_size * num_heads <= max_threads_per_block) {
|
||||
const dim3 block(head_size, num_heads, 1);
|
||||
if (qk_head_size * num_heads <= max_threads_per_block) {
|
||||
const dim3 block(qk_head_size, num_heads, 1);
|
||||
if (format == 2) {
|
||||
AddBiasTransposeTrt<T><<<grid, block, 0, stream>>>(input, biases, output);
|
||||
} else if (format == 1) {
|
||||
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
|
||||
if ((v_head_size == -1) || (qk_head_size == v_head_size)) {
|
||||
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output);
|
||||
} else {
|
||||
AddBiasTransposeQKV<T><<<grid, block, 0, stream>>>(input, biases, output, v_head_size);
|
||||
}
|
||||
} else {
|
||||
AddBiasTranspose<T><<<grid, block, 0, stream>>>(input, biases, output);
|
||||
}
|
||||
} else {
|
||||
const dim3 block(CeilDiv(max_threads_per_block, num_heads), num_heads, 1);
|
||||
if (format == 2) {
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
|
||||
AddBiasTransposeTrtLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
|
||||
} else if (format == 1) {
|
||||
AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
|
||||
AddBiasTransposeQKVLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
|
||||
} else {
|
||||
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(head_size, input, biases, output);
|
||||
AddBiasTransposeLarge<T><<<grid, block, 0, stream>>>(qk_head_size, input, biases, output);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -230,53 +280,55 @@ void InvokeAddBiasTranspose(
|
|||
template <>
|
||||
void LaunchAddBiasTranspose(
|
||||
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
|
||||
const half* input, const half* biases, half* output,
|
||||
bool enable_half4) {
|
||||
if (enable_half4 && 0 == (head_size % 4)) {
|
||||
const int H = head_size / 4;
|
||||
bool enable_half4, const int v_head_size) {
|
||||
if (enable_half4 && 0 == (qk_head_size % 4) && 0 == (v_head_size % 4)) {
|
||||
const int H_q = qk_head_size / 4;
|
||||
const int H_v = v_head_size / 4;
|
||||
const Half4* input2 = reinterpret_cast<const Half4*>(input);
|
||||
const Half4* biases2 = reinterpret_cast<const Half4*>(biases);
|
||||
Half4* output2 = reinterpret_cast<Half4*>(output);
|
||||
InvokeAddBiasTranspose<Half4>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
|
||||
} else if (0 == (head_size & 1)) {
|
||||
const int H = head_size / 2;
|
||||
batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
|
||||
} else if (0 == (qk_head_size & 1) && 0 == (v_head_size % 1)) {
|
||||
const int H_q = qk_head_size / 2;
|
||||
const int H_v = v_head_size / 2;
|
||||
const half2* input2 = reinterpret_cast<const half2*>(input);
|
||||
const half2* biases2 = reinterpret_cast<const half2*>(biases);
|
||||
half2* output2 = reinterpret_cast<half2*>(output);
|
||||
InvokeAddBiasTranspose<half2>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
|
||||
batch_size, sequence_length, num_heads, H_q, input2, biases2, output2, H_v);
|
||||
} else {
|
||||
InvokeAddBiasTranspose<half>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, head_size, input, biases, output);
|
||||
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
|
||||
}
|
||||
}
|
||||
|
||||
template <>
|
||||
void LaunchAddBiasTranspose(
|
||||
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
|
||||
const float* input, const float* biases, float* output,
|
||||
bool /*enable_half4*/) {
|
||||
if (0 == (head_size % 4)) {
|
||||
const int H = head_size / 4;
|
||||
bool /*enable_half4*/, const int v_head_size) {
|
||||
if (0 == (qk_head_size % 4)) {
|
||||
const int H = qk_head_size / 4;
|
||||
const float4* input2 = reinterpret_cast<const float4*>(input);
|
||||
const float4* biases2 = reinterpret_cast<const float4*>(biases);
|
||||
float4* output2 = reinterpret_cast<float4*>(output);
|
||||
InvokeAddBiasTranspose<float4>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
|
||||
} else if (0 == (head_size & 1)) {
|
||||
const int H = head_size / 2;
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 4);
|
||||
} else if (0 == (qk_head_size & 1)) {
|
||||
const int H = qk_head_size / 2;
|
||||
const float2* input2 = reinterpret_cast<const float2*>(input);
|
||||
const float2* biases2 = reinterpret_cast<const float2*>(biases);
|
||||
float2* output2 = reinterpret_cast<float2*>(output);
|
||||
|
||||
InvokeAddBiasTranspose<float2>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2);
|
||||
batch_size, sequence_length, num_heads, H, input2, biases2, output2, v_head_size / 2);
|
||||
} else {
|
||||
InvokeAddBiasTranspose<float>(stream, num_matrices, format, max_threads_per_block,
|
||||
batch_size, sequence_length, num_heads, head_size, input, biases, output);
|
||||
batch_size, sequence_length, num_heads, qk_head_size, input, biases, output, v_head_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ namespace cuda {
|
|||
template <typename T>
|
||||
void LaunchAddBiasTranspose(
|
||||
cudaStream_t stream, const int num_matrices, const int format, const int max_threads_per_block,
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int head_size,
|
||||
const T* input, const T* biases, T* output, bool enable_half4);
|
||||
const int batch_size, const int sequence_length, const int num_heads, const int qk_head_size,
|
||||
const T* input, const T* biases, T* output, bool enable_half4, const int v_head_size);
|
||||
|
||||
} // namespace cuda
|
||||
} // namespace contrib
|
||||
|
|
|
|||
|
|
@ -82,18 +82,31 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
// bias shape (3 * hidden_size)
|
||||
const auto& bias_shape = bias->Shape();
|
||||
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
int q_hidden_size;
|
||||
int k_hidden_size;
|
||||
int v_hidden_size;
|
||||
|
||||
int head_size = hidden_size / num_heads_;
|
||||
|
||||
if (qkv_hidden_sizes_.size() == 0) {
|
||||
q_hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
k_hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
v_hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
} else {
|
||||
q_hidden_size = static_cast<int>(qkv_hidden_sizes_[0]);
|
||||
k_hidden_size = static_cast<int>(qkv_hidden_sizes_[1]);
|
||||
v_hidden_size = static_cast<int>(qkv_hidden_sizes_[2]);
|
||||
}
|
||||
|
||||
const int qkv_head_size[3] = {q_hidden_size / num_heads_, k_hidden_size / num_heads_, v_hidden_size / num_heads_};
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
output_shape[1] = shape[1];
|
||||
output_shape[2] = static_cast<int64_t>(hidden_size);
|
||||
output_shape[2] = static_cast<int64_t>(v_hidden_size);
|
||||
Tensor* output = context->Output(0, output_shape);
|
||||
|
||||
int past_sequence_length = 0;
|
||||
Tensor* present = GetPresent(context, past, batch_size, head_size, sequence_length, past_sequence_length);
|
||||
Tensor* present = GetPresent(context, past, batch_size, qkv_head_size[1], sequence_length, past_sequence_length);
|
||||
|
||||
// Check whether we can use fused kernel
|
||||
int sm = device_prop.major * 10 + device_prop.minor;
|
||||
|
|
@ -103,12 +116,14 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
nullptr == present &&
|
||||
nullptr == extra_add_qk &&
|
||||
!is_unidirectional_ &&
|
||||
HasFusedFp16Kernel(sm, head_size, sequence_length));
|
||||
qkv_head_size[0] == qkv_head_size[1] &&
|
||||
qkv_head_size[1] == qkv_head_size[2] &&
|
||||
HasFusedFp16Kernel(sm, qkv_head_size[0], sequence_length));
|
||||
|
||||
MHARunner* fused_runner = nullptr;
|
||||
if (use_fused_runner) {
|
||||
if (nullptr == fused_fp16_runner_.get()) {
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, head_size, sm));
|
||||
fused_fp16_runner_.reset(new FusedMHARunnerFP16v2(num_heads_, qkv_head_size[0], sm));
|
||||
}
|
||||
// In case some kernel not loaded due to shared memory limit, we need to double check here.
|
||||
if (fused_fp16_runner_->isValid(sequence_length)) {
|
||||
|
|
@ -121,9 +136,9 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
// Use GEMM for fully connection.
|
||||
int m = batch_size * sequence_length;
|
||||
int n = 3 * hidden_size;
|
||||
int n = (q_hidden_size + k_hidden_size + v_hidden_size);
|
||||
int k = input_hidden_size;
|
||||
size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * 3 * hidden_size * element_size;
|
||||
size_t gemm_buffer_size = static_cast<size_t>(batch_size) * sequence_length * n * element_size;
|
||||
auto gemm_buffer = GetScratchBuffer<T>(gemm_buffer_size);
|
||||
|
||||
typedef typename ToCudaType<T>::MappedType CudaT;
|
||||
|
|
@ -140,10 +155,11 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size,
|
||||
batch_size,
|
||||
num_heads_,
|
||||
head_size,
|
||||
qkv_head_size[0],
|
||||
sequence_length,
|
||||
past_sequence_length,
|
||||
fused_runner);
|
||||
fused_runner,
|
||||
qkv_head_size[2]);
|
||||
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize);
|
||||
ORT_RETURN_IF_ERROR(LaunchAttentionKernel(
|
||||
|
|
@ -154,7 +170,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
batch_size,
|
||||
sequence_length,
|
||||
num_heads_,
|
||||
head_size,
|
||||
qkv_head_size[0],
|
||||
past_sequence_length,
|
||||
is_unidirectional_,
|
||||
reinterpret_cast<const void*>(gemm_buffer.get()),
|
||||
|
|
@ -166,7 +182,8 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
work_space.get(),
|
||||
output->MutableData<T>(),
|
||||
nullptr == present ? nullptr : present->MutableData<T>(),
|
||||
fused_runner));
|
||||
fused_runner,
|
||||
qkv_head_size[2]));
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -64,11 +64,13 @@ size_t GetAttentionWorkspaceSize(
|
|||
size_t element_size,
|
||||
size_t batch_size,
|
||||
size_t num_heads,
|
||||
size_t head_size,
|
||||
size_t qk_head_size,
|
||||
size_t sequence_length,
|
||||
size_t past_sequence_length,
|
||||
void* fused_runner) {
|
||||
size_t q_size = element_size * batch_size * sequence_length * num_heads * head_size;
|
||||
void* fused_runner,
|
||||
size_t v_head_size) {
|
||||
size_t q_size = element_size * batch_size * sequence_length * num_heads * qk_head_size;
|
||||
size_t v_size = element_size * batch_size * sequence_length * num_heads * v_head_size;
|
||||
|
||||
if (fused_runner != nullptr) {
|
||||
// Offsets without padding is B + 1. When we add padding, the size need to increase to 2B + 1.
|
||||
|
|
@ -76,7 +78,7 @@ size_t GetAttentionWorkspaceSize(
|
|||
return 4 * q_size + reinterpret_cast<MHARunner*>(fused_runner)->getWorkspaceSize() + sequenceOffsetBytes;
|
||||
}
|
||||
|
||||
return 3 * q_size + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
|
||||
return (2 * q_size + v_size) + 2 * GetAttentionScratchSize(element_size, batch_size, num_heads, sequence_length,
|
||||
past_sequence_length + sequence_length);
|
||||
}
|
||||
|
||||
|
|
@ -88,7 +90,7 @@ Status QkvToContext(
|
|||
const int batch_size,
|
||||
const int sequence_length,
|
||||
const int num_heads,
|
||||
const int head_size,
|
||||
const int qk_head_size,
|
||||
const size_t element_size,
|
||||
const T* input,
|
||||
const T* bias,
|
||||
|
|
@ -102,34 +104,40 @@ Status QkvToContext(
|
|||
const T* extra_add_qk,
|
||||
T* present,
|
||||
bool use_persistent_softmax,
|
||||
MHARunner* fused_runner) {
|
||||
MHARunner* fused_runner,
|
||||
const int v_head_size) {
|
||||
|
||||
const int max_threads_per_block = prop.maxThreadsPerBlock;
|
||||
|
||||
// input should be BxSx3xNxH => qkv: 3xBxNxSxH
|
||||
T* qkv = workspace;
|
||||
if (bias == nullptr) {
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, false, input, qkv));
|
||||
|
||||
} else {
|
||||
// For fused TRT attention, qkv need transpose to BxSxNx3xH
|
||||
const int format = (nullptr == fused_runner ? 1 : 2);
|
||||
const bool enable_half4 = true;
|
||||
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
|
||||
sequence_length, num_heads, head_size,
|
||||
sequence_length, num_heads, qk_head_size,
|
||||
input, bias, qkv,
|
||||
enable_half4);
|
||||
enable_half4, v_head_size);
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
}
|
||||
|
||||
// Q, K, V has size BxNxSxH
|
||||
const int batches = batch_size * num_heads;
|
||||
const int size_per_batch = sequence_length * head_size;
|
||||
const int total_size = batches * size_per_batch;
|
||||
const int size_per_batch_qk = sequence_length * qk_head_size;
|
||||
const int size_per_batch_v = sequence_length * v_head_size;
|
||||
const int total_size_qk = batches * size_per_batch_qk;
|
||||
|
||||
T* scratch1;
|
||||
scratch1 = qkv + (batches * sequence_length * (qk_head_size + qk_head_size + v_head_size));
|
||||
|
||||
T* scratch1 = qkv + 3 * total_size;
|
||||
T* temp_output = scratch1;
|
||||
if (nullptr != fused_runner && bias != nullptr) {
|
||||
int* sequence_offset = reinterpret_cast<int*>(qkv + 4 * total_size);
|
||||
int* sequence_offset = reinterpret_cast<int*>(qkv + 4 * total_size_qk);
|
||||
LaunchTrtSequenceOffset(sequence_offset, mask_index, batch_size, stream);
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
||||
|
|
@ -147,44 +155,45 @@ Status QkvToContext(
|
|||
T* scratch2 = scratch1 + (bytes / element_size);
|
||||
|
||||
const T* q = qkv;
|
||||
const T* k = q + total_size;
|
||||
const T* v = k + total_size;
|
||||
const T* k = q + (batches * sequence_length * qk_head_size);
|
||||
const T* v = k + (batches * sequence_length * qk_head_size);
|
||||
|
||||
cublasSetStream(cublas, stream);
|
||||
|
||||
// Concat past (2xBxNxS'xH) to present (2xBxNxS*xH):
|
||||
// past_k (BxNxS'xH) + k (BxNxSxH) => present_k (BxNxS*xH)
|
||||
// past_v (BxNxS'xH) + v (BxNxSxH) => present_v (BxNxS*xH)
|
||||
const int present_size_per_batch = all_sequence_length * head_size;
|
||||
const int present_size_per_batch_k = all_sequence_length * qk_head_size;
|
||||
const int present_size_per_batch_v = all_sequence_length * v_head_size;
|
||||
if (nullptr != present) {
|
||||
ORT_RETURN_IF_ERROR(
|
||||
LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, head_size, num_heads,
|
||||
LaunchConcatPastToPresent(stream, all_sequence_length, sequence_length, batch_size, qk_head_size, num_heads,
|
||||
max_threads_per_block, past, k, present));
|
||||
|
||||
// update pointers to present_k and present_v.
|
||||
k = present;
|
||||
v = present + batches * present_size_per_batch;
|
||||
v = present + batches * present_size_per_batch_k;
|
||||
}
|
||||
|
||||
// Raw attention mask could be 2D (BxS) or 3D (BxSxS*) or 4D(Bx1xMxM), where M is the max sequence length.
|
||||
bool use_raw_attention_mask = (nullptr != mask_index && mask_index_dims.size() >= 2);
|
||||
|
||||
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H) and store in scratch1: BxNxSxS*
|
||||
// compute Q*K' (as K'*Q), scaled by 1/sqrt(H_qk) and store in scratch1: BxNxSxS*
|
||||
// Q: BxNxSxH, K (present_k): BxNxS*xH, Q*K': BxNxSxS*
|
||||
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(head_size));
|
||||
const float rsqrt_head_size = 1.f / sqrt(static_cast<float>(qk_head_size));
|
||||
const int temp_matrix_size = sequence_length * all_sequence_length;
|
||||
float one = 1.0f;
|
||||
float zero = 0.f;
|
||||
|
||||
// For raw attention mask, the scalar 1/sqrt(H) is moved to combine with softmax computation.
|
||||
// For raw attention mask, the scalar 1/sqrt(H_qk) is moved to combine with softmax computation.
|
||||
float alpha = use_raw_attention_mask ? one : rsqrt_head_size;
|
||||
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
all_sequence_length, sequence_length, head_size,
|
||||
&alpha, k, head_size, present_size_per_batch,
|
||||
q, head_size, size_per_batch,
|
||||
&zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop));
|
||||
cublas, CUBLAS_OP_T, CUBLAS_OP_N,
|
||||
all_sequence_length, sequence_length, qk_head_size,
|
||||
&alpha, k, qk_head_size, present_size_per_batch_k,
|
||||
q, qk_head_size, size_per_batch_qk,
|
||||
&zero, scratch1, all_sequence_length, temp_matrix_size, batches, prop));
|
||||
|
||||
// apply softmax and store result P to scratch2: BxNxSxS*
|
||||
if (use_raw_attention_mask) { // 2d, 3d or 4d attention mask
|
||||
|
|
@ -214,13 +223,13 @@ Status QkvToContext(
|
|||
temp_output = qkv;
|
||||
CUBLAS_RETURN_IF_ERROR(cublasGemmStridedBatchedHelper(
|
||||
cublas, CUBLAS_OP_N, CUBLAS_OP_N,
|
||||
head_size, sequence_length, all_sequence_length,
|
||||
&one, v, head_size, present_size_per_batch,
|
||||
v_head_size, sequence_length, all_sequence_length,
|
||||
&one, v, v_head_size, present_size_per_batch_v,
|
||||
scratch2, all_sequence_length, temp_matrix_size,
|
||||
&zero, temp_output, head_size, size_per_batch, batches, prop));
|
||||
&zero, temp_output, v_head_size, size_per_batch_v, batches, prop));
|
||||
|
||||
// temp_output is BxNxSxH, transpose to output BxSxNxH
|
||||
return LaunchTransCtx(stream, sequence_length, batch_size, head_size, num_heads,
|
||||
return LaunchTransCtx(stream, sequence_length, batch_size, v_head_size, num_heads,
|
||||
max_threads_per_block, false, temp_output, output);
|
||||
}
|
||||
|
||||
|
|
@ -232,7 +241,7 @@ Status LaunchAttentionKernel(
|
|||
int batch_size,
|
||||
int sequence_length,
|
||||
int num_heads,
|
||||
int head_size,
|
||||
const int qk_head_size,
|
||||
int past_sequence_length,
|
||||
bool is_unidirectional,
|
||||
const void* input,
|
||||
|
|
@ -244,13 +253,14 @@ Status LaunchAttentionKernel(
|
|||
void* workspace,
|
||||
void* output,
|
||||
void* present,
|
||||
void* fused_runner) {
|
||||
void* fused_runner,
|
||||
const int v_head_size) {
|
||||
// For testing, environment variable ORT_TRANSFORMER_OPTIONS=1 could enable persistent softmax used in Torch.
|
||||
const TransformerOptions* options = TransformerOptions::GetInstance();
|
||||
bool use_persistent_softmax = options->IsPrecisionMode() && !options->DisablePersistentSoftmax();
|
||||
|
||||
if (element_size == 2) {
|
||||
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, qk_head_size, element_size,
|
||||
reinterpret_cast<const half*>(input),
|
||||
reinterpret_cast<const half*>(bias),
|
||||
reinterpret_cast<half*>(output),
|
||||
|
|
@ -263,9 +273,10 @@ Status LaunchAttentionKernel(
|
|||
reinterpret_cast<const half*>(extra_add_qk),
|
||||
reinterpret_cast<half*>(present),
|
||||
use_persistent_softmax,
|
||||
reinterpret_cast<MHARunner*>(fused_runner));
|
||||
reinterpret_cast<MHARunner*>(fused_runner),
|
||||
v_head_size);
|
||||
} else {
|
||||
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, head_size, element_size,
|
||||
return QkvToContext(prop, cublas, stream, batch_size, sequence_length, num_heads, qk_head_size, element_size,
|
||||
reinterpret_cast<const float*>(input),
|
||||
reinterpret_cast<const float*>(bias),
|
||||
reinterpret_cast<float*>(output),
|
||||
|
|
@ -278,7 +289,8 @@ Status LaunchAttentionKernel(
|
|||
reinterpret_cast<const float*>(extra_add_qk),
|
||||
reinterpret_cast<float*>(present),
|
||||
use_persistent_softmax,
|
||||
nullptr);
|
||||
nullptr,
|
||||
v_head_size);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -21,10 +21,11 @@ size_t GetAttentionWorkspaceSize(
|
|||
size_t element_size,
|
||||
size_t batchsize,
|
||||
size_t num_heads,
|
||||
size_t head_size,
|
||||
size_t qk_head_size,
|
||||
size_t sequence_length,
|
||||
size_t past_sequence_length,
|
||||
void* fused_runner);
|
||||
void* fused_runner,
|
||||
size_t v_head_size);
|
||||
|
||||
Status LaunchAttentionKernel(
|
||||
const cudaDeviceProp& prop, // Device Properties
|
||||
|
|
@ -34,7 +35,7 @@ Status LaunchAttentionKernel(
|
|||
int batch_size, // Batch size (B)
|
||||
int sequence_length, // Sequence length (S)
|
||||
int num_heads, // Number of attention heads (N)
|
||||
int head_size, // Hidden layer size per head (H)
|
||||
const int qk_head_size, // Hidden layer size per head for q and k (H_qk)
|
||||
int past_sequence_length, // Sequence length in past state
|
||||
bool is_unidirectional, // Whether there is unidirecitonal mask.
|
||||
const void* input, // Input tensor
|
||||
|
|
@ -46,8 +47,8 @@ Status LaunchAttentionKernel(
|
|||
void* workspace, // Temporary buffer
|
||||
void* output, // Output tensor
|
||||
void* present, // Present state output
|
||||
void* fused_runner // Fused multi-head attention
|
||||
);
|
||||
void* fused_runner, // Fused multi-head attention
|
||||
const int v_head_size); // Hidden layer size per head for v (H_v)
|
||||
|
||||
Status LaunchDecoderAttentionKernel(
|
||||
const cudaDeviceProp& prop, // Device Properties
|
||||
|
|
|
|||
|
|
@ -868,6 +868,7 @@ Status LongformerQkvToContext(
|
|||
// The order of qkv space:
|
||||
// Q, K, V, Global_K, Global_V, Global_Q (format 0)
|
||||
// Q, K, V, Global_Q, Global_K, Global_V (format 1)
|
||||
// Assume H_q == H_k == H_v
|
||||
if (format == 1 || max_num_global == 0 || nullptr == global_input) {
|
||||
if (bias == nullptr) {
|
||||
ORT_RETURN_IF_ERROR(LaunchTransQkv(stream, 3, sequence_length, batch_size, head_size, num_heads,
|
||||
|
|
@ -876,7 +877,7 @@ Status LongformerQkvToContext(
|
|||
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
|
||||
sequence_length, num_heads, head_size,
|
||||
input, bias, qkv,
|
||||
use_half4);
|
||||
use_half4, head_size);
|
||||
}
|
||||
|
||||
if (max_num_global > 0 && nullptr != global_input) {
|
||||
|
|
@ -887,20 +888,20 @@ Status LongformerQkvToContext(
|
|||
LaunchAddBiasTranspose(stream, 3, format, max_threads_per_block, batch_size,
|
||||
sequence_length, num_heads, head_size,
|
||||
global_input, global_bias, qkv + 3 * elements,
|
||||
use_half4);
|
||||
use_half4, head_size);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
LaunchAddBiasTranspose(stream, 5, format, max_threads_per_block, batch_size,
|
||||
sequence_length, num_heads, head_size,
|
||||
input, bias, qkv,
|
||||
use_half4);
|
||||
use_half4, head_size);
|
||||
|
||||
compact_global_q = (disable_compact_memory == false);
|
||||
LaunchAddBiasTranspose(stream, 1, format, max_threads_per_block, batch_size,
|
||||
compact_global_q ? max_num_global : sequence_length, num_heads, head_size,
|
||||
global_input + 2 * elements, global_bias, qkv + 5 * elements,
|
||||
use_half4);
|
||||
use_half4, head_size);
|
||||
}
|
||||
CUDA_RETURN_IF_ERROR(cudaGetLastError());
|
||||
|
||||
|
|
|
|||
|
|
@ -119,7 +119,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
|
||||
const auto& bias_shape = bias->Shape();
|
||||
const int hidden_size = SafeInt<int>(bias_shape.GetDims()[0]) / 3;
|
||||
const int head_size = hidden_size / num_heads_;
|
||||
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in quantization
|
||||
const int qkv_head_size[3] = {hidden_size / num_heads_, hidden_size / num_heads_, hidden_size / num_heads_};
|
||||
|
||||
TensorShapeVector output_shape(3);
|
||||
output_shape[0] = shape[0];
|
||||
|
|
@ -166,12 +167,12 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
n));
|
||||
|
||||
int past_sequence_length = 0;
|
||||
Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, head_size,
|
||||
Tensor* present_tensor = GetPresent(context, past_tensor, batch_size, qkv_head_size[1],
|
||||
sequence_length, past_sequence_length);
|
||||
|
||||
void* fused_runner = nullptr; // TODO(tianleiwu): use fused kernel to speed up
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, head_size,
|
||||
sequence_length, past_sequence_length, fused_runner);
|
||||
size_t workSpaceSize = GetAttentionWorkspaceSize(element_size, batch_size, num_heads_, qkv_head_size[0],
|
||||
sequence_length, past_sequence_length, fused_runner, qkv_head_size[2]);
|
||||
|
||||
auto work_space = GetScratchBuffer<void>(workSpaceSize);
|
||||
return LaunchAttentionKernel(
|
||||
|
|
@ -182,7 +183,7 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
batch_size,
|
||||
sequence_length,
|
||||
num_heads_,
|
||||
head_size,
|
||||
qkv_head_size[0],
|
||||
past_sequence_length,
|
||||
is_unidirectional_,
|
||||
reinterpret_cast<const void*>(gemm_buffer.get()),
|
||||
|
|
@ -194,7 +195,8 @@ Status QAttention<T, int8_t>::ComputeInternal(OpKernelContext* context) const {
|
|||
work_space.get(),
|
||||
output->MutableData<T>(),
|
||||
nullptr == present_tensor ? nullptr : present_tensor->MutableData<T>(),
|
||||
fused_runner);
|
||||
fused_runner,
|
||||
qkv_head_size[2]);
|
||||
}
|
||||
|
||||
} // namespace cuda
|
||||
|
|
|
|||
|
|
@ -50,6 +50,7 @@ Status Attention<T>::ComputeInternal(OpKernelContext* context) const {
|
|||
int sequence_length = static_cast<int>(shape[1]);
|
||||
int input_hidden_size = static_cast<int>(shape[2]);
|
||||
|
||||
// Note: Scenario where q_hidden_size == k_hidden_size != v_hidden_size is not supported in ROCM EP
|
||||
// bias shape (3 * hidden_size)
|
||||
const auto& bias_shape = bias->Shape();
|
||||
int hidden_size = static_cast<int>(bias_shape[0]) / 3;
|
||||
|
|
|
|||
|
|
@ -41,12 +41,13 @@ static void RunAttentionTest(
|
|||
bool only_enable_cuda = false,
|
||||
bool only_enable_cpu = false,
|
||||
std::vector<int32_t> qkv_sizes = {},
|
||||
const std::vector<float>& extra_add_data = {}) {
|
||||
const std::vector<float>& extra_add_data = {},
|
||||
const bool disable_rocm = false) {
|
||||
input_hidden_size = (input_hidden_size == 0 ? hidden_size : input_hidden_size); // By default, no pruning.
|
||||
|
||||
int min_cuda_architecture = use_float16 ? 530 : 0;
|
||||
bool enable_cuda = HasCudaEnvironment(min_cuda_architecture) && !is_weights_constant && !only_enable_cpu;
|
||||
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu;
|
||||
bool enable_rocm = (nullptr != DefaultRocmExecutionProvider().get()) && !is_weights_constant && !only_enable_cpu && !disable_rocm;
|
||||
bool enable_cpu = (nullptr != DefaultCpuExecutionProvider().get()) && !use_float16 && !only_enable_cuda;
|
||||
|
||||
int head_size = hidden_size / number_of_heads;
|
||||
|
|
@ -188,17 +189,18 @@ static void RunAttentionTest(
|
|||
bool only_enable_cuda = false,
|
||||
bool only_enable_cpu = false,
|
||||
const std::vector<int32_t> qkv_sizes = {},
|
||||
const std::vector<float>& extra_add_data = {}) {
|
||||
const std::vector<float>& extra_add_data = {},
|
||||
const bool disable_rocm = false) {
|
||||
RunAttentionTest(input_data, weights_data, false, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
|
||||
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
|
||||
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data, disable_rocm);
|
||||
RunAttentionTest(input_data, weights_data, true, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
use_float16, is_unidirectional, use_past_state, past_sequence_length,
|
||||
past_data, present_data, mask_index_type, input_hidden_size, max_sequence_length,
|
||||
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data);
|
||||
only_enable_cuda, only_enable_cpu, qkv_sizes, extra_add_data, disable_rocm);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1) {
|
||||
|
|
@ -267,7 +269,11 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr1) {
|
|||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes);
|
||||
0, true, false, qkv_sizes, {}, true);
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes, {}, true);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
|
||||
|
|
@ -304,7 +310,11 @@ TEST(AttentionTest, AttentionBatch1WithQKVAttr2) {
|
|||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes);
|
||||
0, true, false, qkv_sizes, {}, true);
|
||||
RunAttentionTest(input_data, weight_data, bias_data, mask_index_data, output_data,
|
||||
batch_size, sequence_length, hidden_size, number_of_heads,
|
||||
false, false, false, 0, nullptr, nullptr, kMaskIndexEnd, 0,
|
||||
0, false, true, qkv_sizes, {}, true);
|
||||
}
|
||||
|
||||
TEST(AttentionTest, AttentionBatch1ExtraAdd) {
|
||||
|
|
|
|||
Loading…
Reference in a new issue